Unverified commit d7164abf, authored by Lei Wang and committed by GitHub

[Language][Reshape] Improve variable handling and ensure correctness during Layout Reshape (#1248)

* fix

* Refactor tensor reshaping in fp8_lighting_indexer.py

- Replaced the separate allocation of `s_reshaped` with a reshape of `s` to improve clarity and performance.
- Updated the computation of `s_reshaped` to read the reshaped view directly instead of re-indexing the flat tensor `s`.

* Refactor analyzer usage in Layout and Fragment reshaping

- Consolidated the analyzer handling in the `Reshape` methods of `LayoutNode` and `FragmentNode` into a single path that falls back to a local analyzer when none is provided, improving clarity and removing a potential null dereference.
- Bound the fresh per-dimension variables on the selected analyzer and used it consistently for shape validation and for simplifying the resulting index expressions.
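
The consolidation amounts to a small, reusable pattern. A minimal sketch of it in isolation (the helper name CanProveSameVolume is illustrative and not part of this change; the tvm::arith::Analyzer calls mirror those in the diff below):

#include <tvm/arith/analyzer.h>

using tvm::PrimExpr;
namespace arith = tvm::arith;

// Illustrative helper: prove two shape volumes equal, preferring the
// caller's analyzer and falling back to a local one when none is given,
// so both cases share a single proof/check site.
bool CanProveSameVolume(const PrimExpr &lhs, const PrimExpr &rhs,
                        arith::Analyzer *analyzer /* may be nullptr */) {
  arith::Analyzer fallback_analyzer;
  arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
  return az->CanProveEqual(lhs, rhs);
}
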
parent c1398550
@@ -127,7 +127,7 @@ def mqa_attn_return_logits(
         index_k_shared = T.alloc_shared([block_N, index_dim], dtype)
         index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype)
         s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype)
-        s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype)
+        s_reshaped = T.reshape(s, (block_N, block_Q, heads))
         logits = T.alloc_fragment([block_N, block_Q], accum_dtype)
         weights = T.alloc_fragment([block_Q, heads], accum_dtype)
@@ -165,7 +165,7 @@ def mqa_attn_return_logits(
         for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
             s_reshaped[bn_i, bq_i,
-                       h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) *
+                       h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) *
                                weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]
         T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
@@ -313,20 +313,21 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
     shape_product *= dim;
   }
-  if (analyzer) {
-    ICHECK(analyzer->CanProveEqual(input_shape_product, shape_product))
-        << "InputShape() = " << InputShape() << " shape = " << shape;
-  } else {
-    arith::Analyzer local_analyzer;
-    ICHECK(local_analyzer.CanProveEqual(input_shape_product, shape_product))
-        << "InputShape() = " << InputShape() << " shape = " << shape;
-  }
+  // Use provided analyzer if present, otherwise a local fallback to avoid
+  // potential null dereference paths flagged by static analysis.
+  arith::Analyzer fallback_analyzer;
+  arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
+  ICHECK(az->CanProveEqual(input_shape_product, shape_product))
+      << "InputShape() = " << InputShape() << " shape = " << shape;
 
   // Step 2. Create new forward indices by reshaping
   // For each dimension in the new shape, we create a placeholder variable
   Array<Var> new_vars;
+  new_vars.reserve(shape.size());
   for (size_t i = 0; i < shape.size(); ++i) {
-    new_vars.push_back(InputPlaceholder(i));
+    auto var = Var(std::string("n_") + std::to_string(i), shape[i].dtype());
+    az->Bind(var, Range(0, shape[i]));
+    new_vars.push_back(var);
   }
 
   // Step 3. Compute the flat index from new shape indices
   // flat_index = k0 * (s1 * s2 * ...) + k1 * (s2 * s3 * ...) + ... + kn
@@ -362,7 +363,11 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
       substituted =
          Substitute(substituted, {{InputPlaceholder(i), original_indices[i]}});
     }
-    new_forward_index.push_back(substituted);
+    new_forward_index.push_back(az->Simplify(substituted));
+  }
+  for (size_t i = 0; i < new_vars.size(); ++i) {
+    new_forward_index =
+        Substitute(new_forward_index, {{new_vars[i], InputPlaceholder(i)}});
   }
   return Layout(shape, new_forward_index);
 }
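
The flat-index round trip that Step 3 and the recovery loop rely on can be sanity-checked with plain integers. A minimal standalone sketch (the concrete shapes and indices are illustrative only):

#include <cassert>
#include <cstddef>
#include <vector>

// Round-trip check for the reshape index math: flatten a multi-dimensional
// index under the new shape, then recover indices under the old shape with
// floordiv/floormod, as the Reshape methods do symbolically with PrimExprs.
int main() {
  std::vector<int> new_shape = {2, 3, 4};  // e.g. a 24-element layout
  std::vector<int> old_shape = {4, 6};
  std::vector<int> idx = {1, 2, 3};        // one index under the new shape

  // flat = k0 * (s1 * s2) + k1 * s2 + k2
  int flat = 0;
  for (std::size_t i = 0; i < new_shape.size(); ++i) {
    int stride = 1;
    for (std::size_t j = i + 1; j < new_shape.size(); ++j) stride *= new_shape[j];
    flat += idx[i] * stride;
  }
  assert(flat == 1 * 12 + 2 * 4 + 3);  // 23

  // Recover the old-shape indices from the flat index.
  std::vector<int> orig;
  int remain = flat;
  for (std::size_t i = 0; i < old_shape.size(); ++i) {
    int stride = 1;
    for (std::size_t j = i + 1; j < old_shape.size(); ++j) stride *= old_shape[j];
    orig.push_back(remain / stride);  // floordiv
    remain = remain % stride;         // floormod
  }
  assert(orig[0] == 3 && orig[1] == 5);  // element 23 of a 4x6 buffer
  return 0;
}

The symbolic version in Reshape performs the same arithmetic on PrimExprs and then lets the selected analyzer simplify the resulting floordiv/floormod chains.
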
@@ -382,21 +387,25 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
   for (const auto &d : shape)
     shape_prod *= d;
 
-  if (analyzer) {
-    ICHECK(analyzer->CanProveEqual(input_prod, shape_prod))
-        << "InputShape() = " << InputShape() << " shape = " << shape
-        << " input fragment layout is = " << DebugOutput();
-  } else {
-    arith::Analyzer local_analyzer;
-    ICHECK(local_analyzer.CanProveEqual(input_prod, shape_prod))
-        << "InputShape() = " << InputShape() << " shape = " << shape;
-  }
+  // Use provided analyzer if present, otherwise a local fallback.
+  arith::Analyzer fallback_analyzer;
+  arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
+  ICHECK(az->CanProveEqual(input_prod, shape_prod))
+      << "InputShape() = " << InputShape() << " shape = " << shape
+      << " input fragment layout is = " << DebugOutput();
 
   // 2) Build flat index from new-shape indices
   Array<Var> new_vars;
   new_vars.reserve(shape.size());
-  for (size_t i = 0; i < shape.size(); ++i)
-    new_vars.push_back(InputPlaceholder(i));
+  for (size_t i = 0; i < shape.size(); ++i) {
+    // Cannot use InputPlaceholder(i) here, because it would cause name capture
+    // (variable capture) with InputPlaceholder(i) in upper scopes. Therefore,
+    // we must create a fresh variable here to avoid confusion when
+    // substituting.
+    auto var = Var(std::string("n_") + std::to_string(i), shape[i].dtype());
+    az->Bind(var, Range(0, shape[i]));
+    new_vars.push_back(var);
+  }
 
   PrimExpr flat = Integer(0);
   for (size_t i = 0; i < shape.size(); ++i) {
@@ -405,7 +414,6 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
       stride = stride * shape[j];
     flat = flat + new_vars[i] * stride;
   }
-
   // 3) Recover original indices from flat index
   Array<PrimExpr> orig_indices;
   PrimExpr remain = flat;
@@ -416,7 +424,6 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
     orig_indices.push_back(floordiv(remain, stride));
     remain = floormod(remain, stride);
   }
-
   // 4) Substitute old placeholders with expressions of new indices
   Array<PrimExpr> new_forward_index;
   for (const auto &e : forward_index_) {
@@ -424,15 +431,22 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
     for (size_t i = 0; i < InputShape().size(); ++i) {
       cur = Substitute(cur, {{InputPlaceholder(i), orig_indices[i]}});
     }
+    cur = az->Simplify(cur);
     new_forward_index.push_back(cur);
   }
-
   PrimExpr new_forward_thread = forward_thread_;
   for (size_t i = 0; i < InputShape().size(); ++i) {
     new_forward_thread = Substitute(new_forward_thread,
                                     {{InputPlaceholder(i), orig_indices[i]}});
   }
-
+  new_forward_thread = az->Simplify(new_forward_thread);
+  for (size_t i = 0; i < new_vars.size(); ++i) {
+    auto var = new_vars[i];
+    new_forward_index =
+        Substitute(new_forward_index, {{var, InputPlaceholder(i)}});
+    new_forward_thread =
+        Substitute(new_forward_thread, {{var, InputPlaceholder(i)}});
+  }
   Fragment reshaped(shape, new_forward_index, new_forward_thread,
                     ReplicateExtent(), std::nullopt);
   if (thread_range_.defined()) {