Unverified Commit d7164abf authored by Lei Wang, committed by GitHub

[Language][Reshape] Improve variable handling and ensure correctness during Layout Reshape (#1248)

* fix

* Refactor tensor reshaping in fp8_lighting_indexer.py

- Replaced the separate `T.alloc_fragment` allocation of `s_reshaped` with a `T.reshape` view of `s`, removing a redundant fragment and making the aliasing between the flat and 3-D layouts explicit.
- Updated the computation to index the reshaped view (`s_reshaped[bn_i, bq_i, h_i]`) instead of a manually flattened index into `s`, so the scale-and-clamp pass and the following `T.reduce_sum` operate on the same buffer (see the index-equivalence sketch below).
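
The identity behind this refactor, as a minimal standalone sketch in plain C++ (names are illustrative): under row-major layout, indexing a `[N, Q, H]` view at `[bn, bq, h]` reaches the same element as indexing the flat `[N, Q * H]` buffer at `[bn, bq * H + h]`, which is why the reshape can be a view instead of a fresh allocation plus copy.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const int N = 2, Q = 3, H = 4;
  std::vector<float> s(static_cast<size_t>(N) * Q * H);
  for (size_t i = 0; i < s.size(); ++i) s[i] = static_cast<float>(i);

  // 2-D view [N, Q*H]: element (bn, f) lives at bn * (Q * H) + f.
  auto at2d = [&](int bn, int f) { return s[bn * (Q * H) + f]; };
  // 3-D view [N, Q, H]: element (bn, bq, h) lives at ((bn * Q) + bq) * H + h.
  auto at3d = [&](int bn, int bq, int h) { return s[((bn * Q) + bq) * H + h]; };

  for (int bn = 0; bn < N; ++bn)
    for (int bq = 0; bq < Q; ++bq)
      for (int h = 0; h < H; ++h)
        // Flat column bq * H + h addresses exactly the element (bq, h).
        assert(at2d(bn, bq * H + h) == at3d(bn, bq, h));
  return 0;
}
```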

* Refactor analyzer usage in Layout and Fragment reshaping

- Consolidated the duplicated analyzer branches in the `Reshape` methods of `LayoutNode` and `FragmentNode` into a single provided-or-fallback analyzer, improving code clarity and removing a potential null-dereference path (see the sketch after this list).
- Updated variable binding and simplification to go through the selected analyzer consistently, making shape validation and index computation more robust.
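
A minimal sketch of that provided-or-fallback pattern, assuming TVM's `arith::Analyzer` (the free function and its name are illustrative, not code from this commit):

```cpp
#include <tvm/arith/analyzer.h>
#include <tvm/ir/expr.h>

// Use the caller's analyzer when one is supplied; otherwise fall back to a
// stack-local instance, so the equality proof never dereferences a null
// pointer.
bool ProductsProvablyEqual(const tvm::PrimExpr &lhs, const tvm::PrimExpr &rhs,
                           tvm::arith::Analyzer *analyzer) {
  tvm::arith::Analyzer fallback_analyzer;
  tvm::arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
  return az->CanProveEqual(lhs, rhs);
}
```

A caller-supplied analyzer carries the bindings accumulated by the surrounding pass, so its proofs can be stronger; the local fallback only sees what is bound here, but the check can never crash.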
parent c1398550
@@ -127,7 +127,7 @@ def mqa_attn_return_logits(
         index_k_shared = T.alloc_shared([block_N, index_dim], dtype)
         index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype)
         s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype)
-        s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype)
+        s_reshaped = T.reshape(s, (block_N, block_Q, heads))
         logits = T.alloc_fragment([block_N, block_Q], accum_dtype)
         weights = T.alloc_fragment([block_Q, heads], accum_dtype)
@@ -165,7 +165,7 @@ def mqa_attn_return_logits(
             for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
                 s_reshaped[bn_i, bq_i,
-                           h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) *
+                           h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) *
                                   weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]
             T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
@@ -313,20 +313,21 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
     shape_product *= dim;
   }
-  if (analyzer) {
-    ICHECK(analyzer->CanProveEqual(input_shape_product, shape_product))
-        << "InputShape() = " << InputShape() << " shape = " << shape;
-  } else {
-    arith::Analyzer local_analyzer;
-    ICHECK(local_analyzer.CanProveEqual(input_shape_product, shape_product))
-        << "InputShape() = " << InputShape() << " shape = " << shape;
-  }
+  // Use provided analyzer if present, otherwise a local fallback to avoid
+  // potential null dereference paths flagged by static analysis.
+  arith::Analyzer fallback_analyzer;
+  arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
+  ICHECK(az->CanProveEqual(input_shape_product, shape_product))
+      << "InputShape() = " << InputShape() << " shape = " << shape;
   // Step 2. Create new forward indices by reshaping
   // For each dimension in the new shape, we create a placeholder variable
   Array<Var> new_vars;
   new_vars.reserve(shape.size());
   for (size_t i = 0; i < shape.size(); ++i) {
-    new_vars.push_back(InputPlaceholder(i));
+    auto var = Var(std::string("n_") + std::to_string(i), shape[i].dtype());
+    az->Bind(var, Range(0, shape[i]));
+    new_vars.push_back(var);
   }
   // Step 3. Compute the flat index from new shape indices
   // flat_index = k0 * (s1 * s2 * ...) + k1 * (s2 * s3 * ...) + ... + kn
@@ -362,7 +363,11 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
       substituted =
           Substitute(substituted, {{InputPlaceholder(i), original_indices[i]}});
     }
-    new_forward_index.push_back(substituted);
+    new_forward_index.push_back(az->Simplify(substituted));
   }
+  for (size_t i = 0; i < new_vars.size(); ++i) {
+    new_forward_index =
+        Substitute(new_forward_index, {{new_vars[i], InputPlaceholder(i)}});
+  }
   return Layout(shape, new_forward_index);
 }
@@ -382,21 +387,25 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
   for (const auto &d : shape)
     shape_prod *= d;
-  if (analyzer) {
-    ICHECK(analyzer->CanProveEqual(input_prod, shape_prod))
-        << "InputShape() = " << InputShape() << " shape = " << shape
-        << " input fragment layout is = " << DebugOutput();
-  } else {
-    arith::Analyzer local_analyzer;
-    ICHECK(local_analyzer.CanProveEqual(input_prod, shape_prod))
-        << "InputShape() = " << InputShape() << " shape = " << shape;
-  }
+  // Use provided analyzer if present, otherwise a local fallback.
+  arith::Analyzer fallback_analyzer;
+  arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
+  ICHECK(az->CanProveEqual(input_prod, shape_prod))
+      << "InputShape() = " << InputShape() << " shape = " << shape
+      << " input fragment layout is = " << DebugOutput();
   // 2) Build flat index from new-shape indices
   Array<Var> new_vars;
   new_vars.reserve(shape.size());
-  for (size_t i = 0; i < shape.size(); ++i)
-    new_vars.push_back(InputPlaceholder(i));
+  for (size_t i = 0; i < shape.size(); ++i) {
+    // Cannot use InputPlaceholder(i) here, because it would cause name capture
+    // (variable capture) with InputPlaceholder(i) in upper scopes. Therefore,
+    // we must create a fresh variable here to avoid confusion when
+    // substituting.
+    auto var = Var(std::string("n_") + std::to_string(i), shape[i].dtype());
+    az->Bind(var, Range(0, shape[i]));
+    new_vars.push_back(var);
+  }
   PrimExpr flat = Integer(0);
   for (size_t i = 0; i < shape.size(); ++i) {
@@ -405,7 +414,6 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
       stride = stride * shape[j];
     flat = flat + new_vars[i] * stride;
   }
-
   // 3) Recover original indices from flat index
   Array<PrimExpr> orig_indices;
   PrimExpr remain = flat;
@@ -416,7 +424,6 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
     orig_indices.push_back(floordiv(remain, stride));
     remain = floormod(remain, stride);
   }
-
   // 4) Substitute old placeholders with expressions of new indices
   Array<PrimExpr> new_forward_index;
   for (const auto &e : forward_index_) {
@@ -424,15 +431,22 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
     for (size_t i = 0; i < InputShape().size(); ++i) {
       cur = Substitute(cur, {{InputPlaceholder(i), orig_indices[i]}});
     }
+    cur = az->Simplify(cur);
     new_forward_index.push_back(cur);
   }
   PrimExpr new_forward_thread = forward_thread_;
   for (size_t i = 0; i < InputShape().size(); ++i) {
     new_forward_thread = Substitute(new_forward_thread,
                                     {{InputPlaceholder(i), orig_indices[i]}});
   }
+  new_forward_thread = az->Simplify(new_forward_thread);
+  for (size_t i = 0; i < new_vars.size(); ++i) {
+    auto var = new_vars[i];
+    new_forward_index =
+        Substitute(new_forward_index, {{var, InputPlaceholder(i)}});
+    new_forward_thread =
+        Substitute(new_forward_thread, {{var, InputPlaceholder(i)}});
+  }
   Fragment reshaped(shape, new_forward_index, new_forward_thread,
                     ReplicateExtent(), std::nullopt);
   if (thread_range_.defined()) {
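
The name-capture comment inside `FragmentNode::Reshape` is the subtle part of this change. A hedged sketch of the bind/simplify/substitute-back round trip, using the same TVM calls as the patch (`Bind`, `Simplify`, `Substitute`); the helper below is illustrative, not code from this commit:

```cpp
#include <tvm/arith/analyzer.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/var.h>

using tvm::PrimExpr;
using tvm::Range;
using tvm::arith::Analyzer;
using tvm::tir::Substitute;
using tvm::tir::Var;

// Simplify an expression over a placeholder variable by (1) moving it onto a
// fresh, range-bound variable, (2) simplifying while that bound is in scope,
// and (3) substituting the placeholder back. Reusing the placeholder itself
// in step (1) would capture any same-named variable from an enclosing scope;
// the fresh "n_i" names in the patch avoid exactly that.
PrimExpr SimplifyOverPlaceholder(PrimExpr expr, const Var &placeholder,
                                 const PrimExpr &extent, Analyzer *az) {
  Var fresh("n_0", placeholder.dtype());
  az->Bind(fresh, Range(0, extent));  // the bound enables e.g. floormod folding
  expr = Substitute(expr, {{placeholder, fresh}});
  expr = az->Simplify(expr);
  return Substitute(expr, {{fresh, placeholder}});
}
```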