Commit 804735bf authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Refactor] Improve tensor shape compatibility checks in AutoTuner (#590)

- Simplified the shape comparison logic in the AutoTuner class to enhance readability and maintainability.
- Ensured that the shape compatibility checks are more concise while preserving functionality, contributing to overall code clarity.
parent cce6aed8
...@@ -336,11 +336,10 @@ class AutoTuner: ...@@ -336,11 +336,10 @@ class AutoTuner:
continue continue
# Check tensor compatibility using generator expression # Check tensor compatibility using generator expression
if len(params) == len(self.jit_input_tensors):
def shape_equal(a, b): def shape_equal(a, b):
if len(a.shape) != len(b.shape): return all(
return False a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var)
return all(a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var) for a_dim, b_dim in zip(a.shape, b.shape)) for a_dim, b_dim in zip(a.shape, b.shape))
if p.dtype != c.dtype or not shape_equal(p, c): if p.dtype != c.dtype or not shape_equal(p, c):
logger.warning( logger.warning(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment