Commit 5f5bf53c authored by Lei Wang, committed by LeiWang1999

[Refactor] Improve flash attention example and layout comparison logic (#270)

* [Refactor] Improve flash attention example and layout comparison logic

- Removed unnecessary annotation for `lse_local_split` in the flash attention example to streamline the code.
- Updated the handling of `lse_local_split` to use an explicit `T.Parallel` copy loop for better performance (see the sketch after this list).
- Refactored kernel compilation and profiling logic to enhance clarity and maintainability in the flash attention example.
- Added a condition in `FragmentNode::IsEqual` to handle broadcast cases, improving the robustness of layout comparisons.
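
A minimal NumPy sketch of the split-KV combine math that the `T.Parallel` loops implement; the sizes and array names below are illustrative only and are not taken from the example itself:

import numpy as np

# Hypothetical sizes for illustration only.
block_M, dim, num_split = 64, 64, 4
rng = np.random.default_rng(0)

# Per-split partial outputs and their base-2 log-sum-exp values.
o_partial = rng.standard_normal((num_split, block_M, dim)).astype(np.float32)
lse_split = rng.standard_normal((num_split, block_M)).astype(np.float32)

# Global base-2 log-sum-exp across splits, per query row.
lse_logsum = np.log2(np.exp2(lse_split).sum(axis=0))

# Each split's contribution is rescaled by exp2(lse_split - lse_logsum),
# mirroring scale_local[i] = T.exp2(lse_local_split[i] - lse_logsum_local[i]).
o_accum = np.zeros((block_M, dim), dtype=np.float32)
for k in range(num_split):
    scale = np.exp2(lse_split[k] - lse_logsum)  # shape (block_M,)
    o_accum += scale[:, None] * o_partial[k]

In the diff below, the removed `T.copy(lse_local[k, :], lse_local_split)` is replaced by an explicit elementwise loop, which sidesteps the fragment layout annotation that was dropped for `lse_local_split`.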

* lint fix

* [Enhancement] Add support for shared memory scope in Fill operation

- Introduced handling for `shared.dyn` and `shared` memory scopes in the Fill operation.
- Implemented parallel operation and layout inference for improved performance in shared-memory scenarios (see the sketch after this list).
- Updated thread loop partitioning and vectorization logic to accommodate new memory scope handling.
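
A hypothetical TileLang sketch of what the new scope handling enables: filling a shared-memory buffer directly with `T.fill`. The buffer declaration style and the `T.alloc_shared`/`T.fill` calls are assumptions based on tilelang's public examples and may differ in detail from the real API.

import tilelang
import tilelang.language as T

M, N, dtype = 128, 128, "float16"


@T.prim_func
def fill_shared(A: T.Buffer((M, N), dtype)):
    with T.Kernel(1, threads=128):
        A_shared = T.alloc_shared((M, N), dtype)  # "shared.dyn"/"shared" scope
        # With this change, Fill::Lower partitions the SIMT loop across the
        # thread block and vectorizes it instead of failing with
        # "Unsupported scope".
        T.fill(A_shared, 0)
        T.copy(A_shared, A)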
parent 2abd6ab7
@@ -167,7 +167,6 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
         T.annotate_layout({
             o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i),
-            lse_local_split: T.Fragment(lse_local_split.shape, forward_thread_fn=lambda i: i),
             o_shared: tilelang.layout.make_swizzled_layout(o_shared),
             po_shared: tilelang.layout.make_swizzled_layout(po_shared),
         })
@@ -190,7 +189,9 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
         for k in T.Pipelined(num_split, num_stages=2):
             T.copy(Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :], po_shared)
             T.copy(po_shared, po_local)
-            T.copy(lse_local[k, :], lse_local_split)
+            for i in T.Parallel(block_M):
+                lse_local_split[i] = lse_local[k, i]
+            # T.copy(lse_local[k, :], lse_local_split)
             for i in T.Parallel(block_M):
                 scale_local[i] = T.exp2(lse_local_split[i] - lse_logsum_local[i])
             for i, j in T.Parallel(block_M, dim):
@@ -304,14 +305,15 @@ if __name__ == "__main__":
     BLOCK_N = 64  # if D_HEAD <= 128 else 32
     program = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N)
     ref_program = partial(ref_program, causal=causal)
-    mod = tilelang.compile(program, out_idx=[5], target="cuda", execution_backend="dlpack")
-    profiler = mod.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
+    kernel = tilelang.compile(program, out_idx=[5], target="cuda", execution_backend="dlpack")
+    print(kernel.get_kernel_source())
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
     profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
     print("All checks passed!")
     latency = profiler.do_bench(ref_program, warmup=500)
     print("{:.2f} ms".format(latency))
     print("{:.2f} TFlops".format(total_flops / latency * 1e-9))
-    latency = profiler.do_bench(profiler.mod, n_warmup=10, n_repeat=10, profiler="tvm")
+    latency = profiler.do_bench(n_warmup=10, n_repeat=10)
     print("{:.4f} ms".format(latency))
     print("{:.2f} TFlops".format(total_flops / latency * 1e-9))
@@ -408,6 +408,10 @@ bool FragmentNode::IsEqual(const FragmentNode *other, bool skip_index) const {
   // a[i, j] = b[j, i] in register level.
   bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
+  if (!ret) {
+    // may be broadcast case
+    return true;
+  }
   ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
   ret &= StructuralEqual()(this->ReplicateExtent(), other->ReplicateExtent());
   ret &= StructuralEqual()(this->ThreadExtent(), other->ThreadExtent());
...
@@ -437,6 +437,14 @@ Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     auto init_loop = MakeSIMTLoop(analyzer);
     auto vectorized_thread_loop = VectorizeLoop(init_loop);
     return vectorized_thread_loop;
+  } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared") {
+    auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer));
+    par_op->InferLayout({T.target, T.block_size, T.layout_map},
+                        InferLevel::kFree);
+    auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
+                                     par_op->GetLoopLayout());
+    auto vectorized_thread_loop = VectorizeLoop(thread_loop);
+    return vectorized_thread_loop;
   } else {
     LOG(FATAL) << "Unsupported scope " << dst.scope();
   }
...