Unverified commit 242b43bb authored by Zhengju Tang, committed by GitHub

[BugFix] Fix split kernel layout bug of GQA decode (#1386)

* [BugFix] Fix split kernel layout bug of GQA decode

* [BugFix] Avoid local with Parallel; use robust fragment instead
parent d933d65b
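
The diff below touches the repository ignore rules and the GQA decode example, whose kernel follows the split-KV ("flash decoding") scheme: each split attends over its own slice of the KV cache and emits a locally normalized partial output plus a base-2 log-sum-exp (LSE), and a combine kernel rescales and sums the splits. The following is an illustrative PyTorch sketch of that scheme, not the TileLang kernel from this commit; the sizes and names (`q`, `k`, `v`, `sm_scale`, `LOG2E`, etc.) are assumptions made up for the example.

```python
import torch

torch.manual_seed(0)
LOG2E = 1.4426950408889634  # log2(e): flash-style kernels keep the LSE in base 2

# Hypothetical sizes for the sketch (not taken from the example file).
batch, kv_heads, groups, seqlen_kv, dim, num_split = 1, 8, 4, 2048, 128, 4
heads = kv_heads * groups
sm_scale = dim ** -0.5

q = torch.randn(batch, heads, dim)            # one decode-step query per head
k = torch.randn(batch, kv_heads, seqlen_kv, dim)
v = torch.randn(batch, kv_heads, seqlen_kv, dim)

# GQA: every KV head serves `groups` query heads.
k_e = k.repeat_interleave(groups, dim=1)      # [B, H, S, D]
v_e = v.repeat_interleave(groups, dim=1)

logits = torch.einsum("bhd,bhsd->bhs", q, k_e) * sm_scale * LOG2E   # base-2 logits
o_full = torch.einsum("bhs,bhsd->bhd", torch.softmax(logits / LOG2E, dim=-1), v_e)

# Split path: each split emits a locally normalized partial output and its log2-sum-exp.
partials, lses = [], []
for idx in torch.chunk(torch.arange(seqlen_kv), num_split):
    t = logits[:, :, idx]                                       # this split's logits
    m = t.max(dim=-1, keepdim=True).values
    p = torch.exp2(t - m)
    partials.append(torch.einsum("bhs,bhsd->bhd", p, v_e[:, :, idx]) / p.sum(-1, keepdim=True))
    lses.append(torch.log2(p.sum(dim=-1)) + m.squeeze(-1))      # lse_k = log2(sum_j 2^t_j)

glse = torch.stack(lses, dim=-1)                # [B, H, num_split], cf. glse[bz, by, k]
output_partial = torch.stack(partials, dim=-2)  # [B, H, num_split, D], cf. Output_partial[bz, by, k, i]

# Combine: rescale each split by exp2(lse_k - LSE) and sum -- the math of the fixed loop below.
lse_max = glse.max(dim=-1, keepdim=True).values
lse_logsum = torch.log2(torch.exp2(glse - lse_max).sum(-1, keepdim=True)) + lse_max
o_split = (output_partial * torch.exp2(glse - lse_logsum).unsqueeze(-1)).sum(dim=-2)

assert torch.allclose(o_split, o_full, atol=1e-4)
```

Keeping the LSE in base 2 (exp2/log2) mirrors the kernel's use of `T.exp2`/`T.log2`, so the per-split scale factors can be computed directly from the stored LSE values without renormalizing the partial sums.
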
@@ -111,3 +111,9 @@ cmake-build-*/
 # host checks logs
 maint/host_checks/logs/*
+# ncu
+*.ncu-rep
+# csv
+*.csv
\ No newline at end of file
@@ -15,7 +15,7 @@ torch.random.manual_seed(0)
 def get_configs():
     block_N = [64, 128]
     block_H = [64]
-    num_split = [2, 4, 8]
+    num_split = [1, 2, 4, 8]
     num_stages = [1, 2, 3]
     threads = [128]
     _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
@@ -42,7 +42,7 @@ def get_heuristic_config() -> Tuple[Dict, int]:
     if sm_version == 89:
         cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=0, threads=128)
     else:
-        cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=2, threads=128)
+        cfg = dict(block_N=128, block_H=64, num_split=8, num_stages=2, threads=128)
     return cfg, sm_version
@@ -229,10 +229,9 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
             po_local = T.alloc_fragment([dim], dtype)
             o_accum_local = T.alloc_fragment([dim], accum_dtype)
             lse_local = T.alloc_fragment([num_split, 128], dtype)
-            lse_local_split = T.alloc_local([1], accum_dtype)
-            lse_logsum_local = T.alloc_local([1], accum_dtype)
+            lse_logsum_local = T.alloc_fragment([128], accum_dtype)
             lse_max_local = T.alloc_fragment([128], accum_dtype)
-            scale_local = T.alloc_local([1], accum_dtype)
+            scale_local = T.alloc_fragment([128], accum_dtype)

             T.annotate_layout({
                 lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
...@@ -246,17 +245,19 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -246,17 +245,19 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
for k, j in T.Parallel(num_split, 128): for k, j in T.Parallel(num_split, 128):
lse_local[k, j] = glse[bz, by, k] lse_local[k, j] = glse[bz, by, k]
T.reduce_max(lse_local, lse_max_local, dim=0, clear=True) T.reduce_max(lse_local, lse_max_local, dim=0, clear=True)
for k in T.Pipelined(num_split, num_stages=1): for k in T.serial(num_split):
lse_local_split[0] = glse[bz, by, k] for j in T.Parallel(128):
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) lse_logsum_local[j] += T.exp2(lse_local[k, j] - lse_max_local[j])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] for j in T.Parallel(128):
lse_logsum_local[j] = T.log2(lse_logsum_local[j]) + lse_max_local[j]
for k in T.serial(num_split): for k in T.serial(num_split):
for i in T.Parallel(dim): for i in T.Parallel(dim):
po_local[i] = Output_partial[bz, by, k, i] po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k] for j in T.Parallel(128):
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) scale_local[j] = T.exp2(lse_local[k, j] - lse_logsum_local[j])
# Note: Pay attention to dim and the number of threads in Parallel
for i in T.Parallel(dim): for i in T.Parallel(dim):
o_accum_local[i] += po_local[i] * scale_local[0] o_accum_local[i] += po_local[i] * scale_local[i]
for i in T.Parallel(dim): for i in T.Parallel(dim):
Output[bz, by, i] = o_accum_local[i] Output[bz, by, i] = o_accum_local[i]
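
At tensor level, the fixed combine loop above is equivalent to the sketch below. Buffer shapes are inferred from the indexing `glse[bz, by, k]` and `Output_partial[bz, by, k, i]`; the function name is made up for illustration, and broadcasting stands in for the 128-lane `T.Parallel` loops.

```python
import torch

def combine_splits_reference(glse: torch.Tensor, output_partial: torch.Tensor) -> torch.Tensor:
    """Tensor-level sketch of the combine loop above (not the TileLang kernel itself).

    Assumed shapes, inferred from the diff's indexing:
      glse:            [batch, heads, num_split]       -- per-split log2-sum-exp
      output_partial:  [batch, heads, num_split, dim]  -- per-split partial outputs
    Returns:           [batch, heads, dim]
    """
    lse_max = glse.max(dim=-1, keepdim=True).values                   # cf. T.reduce_max over splits
    lse_logsum = torch.log2(torch.exp2(glse - lse_max).sum(-1, keepdim=True)) + lse_max
    scale = torch.exp2(glse - lse_logsum)                             # cf. scale_local[j]
    return (output_partial * scale.unsqueeze(-1)).sum(dim=-2)         # cf. o_accum_local accumulation
```
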
@@ -474,7 +475,7 @@ def main(batch: int = 1,
         print(o_ref)
         assert_similar(o, o_ref, name="o_ref")
-        assert_similar(o_ref_split, o_ref, name="o_ref_split")
+        assert_similar(o, o_ref_split, name="o_ref_split")
         print("All checks pass.")
     latency = profiler.do_bench(ref_program, warmup=500)