Commit 0fd82ed5 authored by Lei Wang, committed by LeiWang1999

[Bugfix] Fix layout conflict issue for gqa decoding examples (#314)

* Remove a logging statement from the LoopVectorizerDynamic::Substitute method for cleaner output.

* Refactor flashattn example to improve CUDA configuration handling

- Updated the `flashattn` function in `example_gqa_decode.py` to choose a heuristic configuration based on the CUDA device's compute capability, improving compatibility across architectures (a standalone sketch of the heuristic follows the commit header below).
- Replaced single-element local allocations with scalar variables, trimmed the conflicting layout annotations, and removed unnecessary logging statements for cleaner output.
- Adjusted the `do_bench` method call to streamline performance profiling.

* lint fix
parent c30904ea
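The capability heuristic described in the first bullet can be exercised on its own. The sketch below mirrors the `get_heuristic_config` helper added in the diff further down; the `sm_version == 89` branch corresponds to Ada-generation GPUs (e.g. RTX 4090), where the example disables software pipelining with `num_stages=0`. The function name `pick_decode_config` is hypothetical, and the specific values are the ones from this commit, not a general tuning recommendation:

    import torch

    def pick_decode_config() -> dict:
        # Hypothetical standalone version of the example's get_heuristic_config:
        # map the device's compute capability (major, minor) to one integer
        # (e.g. (8, 9) -> 89) and choose kernel tile parameters from it.
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
        sm_version = major * 10 + minor
        # sm_89 (Ada) uses num_stages=0; other architectures keep num_stages=2.
        num_stages = 0 if sm_version == 89 else 2
        return {"block_N": 128, "block_H": 64, "num_split": 8,
                "num_stages": num_stages, "threads": 128}

    if __name__ == "__main__":
        print(pick_decode_config())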
example_gqa_decode.py
@@ -203,34 +203,29 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
             po_local = T.alloc_fragment([dim], dtype)
             o_accum_local = T.alloc_fragment([dim], accum_dtype)
             lse_local = T.alloc_fragment([num_split, 128], dtype)
-            lse_local_split = T.alloc_local([1], accum_dtype)
+            lse_local_split = T.alloc_var(accum_dtype)
             lse_logsum_local = T.alloc_local([1], accum_dtype)
             lse_max_local = T.alloc_fragment([128], accum_dtype)
             scale_local = T.alloc_local([1], accum_dtype)

             T.annotate_layout({
-                lse_logsum_local:
-                    T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
-                lse_max_local:
-                    T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i),
-                lse_local:
-                    T.Fragment(lse_local.shape, forward_thread_fn=lambda i, j: j),
+                lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i),
             })

-            T.clear(lse_logsum_local)
+            # T.clear(lse_logsum_local)
             T.clear(o_accum_local)
             for k in T.Parallel(num_split):
                 lse_local[k, 0] = glse[bz, by, k]
             T.reduce_max(lse_local, lse_max_local, dim=0, clear=True)
             for k in T.Pipelined(num_split, num_stages=1):
-                lse_local_split[0] = glse[bz, by, k]
-                lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
+                lse_local_split = glse[bz, by, k]
+                lse_logsum_local[0] += T.exp2(lse_local_split - lse_max_local[0])
             lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
             for k in T.serial(num_split):
                 for i in T.Parallel(dim):
                     po_local[i] = Output_partial[bz, by, k, i]
-                lse_local_split[0] = glse[bz, by, k]
-                scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
+                lse_local_split = glse[bz, by, k]
+                scale_local[0] = T.exp2(lse_local_split - lse_logsum_local[0])
                 for i in T.Parallel(dim):
                     o_accum_local[i] += po_local[i] * scale_local[0]
             for i in T.Parallel(dim):
@@ -427,9 +422,34 @@ if __name__ == "__main__":
     total_flops = qk_flops + pv_flops

     if (not args.tune):
-        program = flashattn(
-            batch, heads, groups, kv_seqlen, dim, tune=args.tune)(
-                block_N=128, block_H=64, num_split=8, num_stages=2, threads=128)
+
+        def get_heuristic_config() -> dict:
+            # Get CUDA device properties
+            if not torch.cuda.is_available():
+                raise RuntimeError("CUDA is not available")
+            device = torch.cuda.current_device()
+            sm_major, sm_minor = torch.cuda.get_device_capability(device)
+            sm_version = sm_major * 10 + sm_minor
+            print(f"CUDA device capability: {sm_version}")
+            if sm_version == 89:
+                return {
+                    "block_N": 128,
+                    "block_H": 64,
+                    "num_split": 8,
+                    "num_stages": 0,
+                    "threads": 128
+                }
+            else:
+                return {
+                    "block_N": 128,
+                    "block_H": 64,
+                    "num_split": 8,
+                    "num_stages": 2,
+                    "threads": 128
+                }
+
+        config = get_heuristic_config()
+        program = flashattn(batch, heads, groups, kv_seqlen, dim, tune=args.tune)(**config)
         kernel = tilelang.compile(program, out_idx=[6])
         profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
         profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
@@ -437,7 +457,7 @@ if __name__ == "__main__":
         latency = profiler.do_bench(ref_program, warmup=500)
         print("Ref: {:.2f} ms".format(latency))
         print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
-        latency = profiler.do_bench(kernel.rt_module, warmup=500, profiler="auto")
+        latency = profiler.do_bench(warmup=500)
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
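For reference, the throughput lines printed in the hunk above convert a millisecond latency into TFLOP/s via total_flops / latency * 1e-9. A tiny self-contained check of that conversion, using made-up numbers (the FLOP count and latency below are hypothetical, not measured):

    # flops / (latency in ms) * 1e-9  ==  flops / (latency * 1e-3 s) * 1e-12,  i.e. TFLOP/s
    total_flops = 4.0e9   # hypothetical qk_flops + pv_flops for a small decode workload
    latency = 0.05        # hypothetical latency in milliseconds, as returned by do_bench
    tflops = total_flops / latency * 1e-9
    print("{:.2f} TFlops".format(tflops))  # -> 80.00 TFlops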
@@ -340,7 +340,6 @@ class LoopVectorizerDynamic : public IRMutatorWithAnalyzer {
 public:
   static Stmt Substitute(Stmt stmt) {
     arith::Analyzer analyzer;
-    LOG(INFO) << "LoopVectorizerDynamic Substitute";
     LoopVectorizerDynamic substituter(&analyzer);
     stmt = substituter.VisitStmt(stmt);
     return stmt;