Commit 7b66fb19 authored by 徐畅, committed by LeiWang1999

[CI] Add flash_decoding example to CI (#487)

* [CI] Add flash_decoding example to CI

* Add output of ref latency

* format example_gqa_decode.py
parent d4f096ef
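The functional change that makes these examples CI-runnable is a small refactor applied to both scripts below: the body of the `if __name__ == "__main__":` block moves into a `main()` function whose defaults mirror the old argparse defaults, so the test harness can import each module and call `main()` directly while the command-line interface keeps working. The inner prim_funcs are also renamed from `main`/`main_split` to descriptive kernel names, presumably to avoid confusion with the new module-level `main()`. A minimal sketch of the pattern (the tilelang logic is elided; names are taken from the diff below):

    import argparse


    def main(batch: int = 1, heads: int = 32, tune: bool = False):
        # ... build, compile, and profile the kernel here ...
        print(f"batch={batch}, heads={heads}, tune={tune}")


    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument('--batch', type=int, default=1, help='batch size')
        parser.add_argument('--heads', type=int, default=32, help='heads')
        parser.add_argument('--tune', action='store_true', help='tune configs')
        args = parser.parse_args()
        main(args.batch, args.heads, args.tune)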
example_gqa_decode.py:

@@ -240,7 +240,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
             Output[bz, by, i] = o_accum_local[i]

     @T.prim_func
-    def main_split(
+    def flashattn_gqa_decode_split(
             Q: T.Tensor(shape_q, dtype),
             K: T.Tensor(shape_k, dtype),
             V: T.Tensor(shape_v, dtype),
@@ -253,7 +253,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
             combine(glse, Output_partial, Output)

     @T.prim_func
-    def main_no_split(
+    def flashattn_gqa_decode_no_split(
             Q: T.Tensor(shape_q, dtype),
             K: T.Tensor(shape_k, dtype),
             V: T.Tensor(shape_v, dtype),
@@ -265,9 +265,9 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
             flash_attn(Q, K, V, mask, Output)

         if num_split > 1:
-            return main_split
+            return flashattn_gqa_decode_split
         else:
-            return main_no_split
+            return flashattn_gqa_decode_no_split

    if tune:
@@ -410,22 +410,18 @@ def reduce_ref(Q, K, V, mask, glse, Output_partial):
     return o.to(torch.float16)

-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=1, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--groups', type=int, default=8, help='groups')
-    parser.add_argument('--kv_seqlen', type=int, default=8192, help='kv sequence length')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--tune', action='store_true', help='tune configs')
-    args = parser.parse_args()
-    batch, heads, groups, kv_seqlen, dim = args.batch, args.heads, args.groups, args.kv_seqlen, args.dim
+def main(batch: int = 1,
+         heads: int = 32,
+         groups: int = 8,
+         kv_seqlen: int = 8192,
+         dim: int = 128,
+         tune: bool = False):

     qk_flops = 2 * batch * heads * kv_seqlen * dim
     pv_flops = 2 * batch * heads * kv_seqlen * dim
     total_flops = qk_flops + pv_flops

-    if (not args.tune):
+    if not tune:

         def get_heuristic_config() -> dict:
             # Get CUDA device properties
@@ -453,7 +449,7 @@ if __name__ == "__main__":
         }

         config = get_heuristic_config()
-        program = flashattn(batch, heads, groups, kv_seqlen, dim, tune=args.tune)(**config)
+        program = flashattn(batch, heads, groups, kv_seqlen, dim, tune=tune)(**config)
         kernel = tilelang.compile(program, out_idx=[6])
         profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
         profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
@@ -465,10 +461,23 @@ if __name__ == "__main__":
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
-        best_result = flashattn(batch, heads, groups, kv_seqlen, dim, tune=args.tune)
+        best_result = flashattn(batch, heads, groups, kv_seqlen, dim, tune=tune)
         best_latency = best_result.latency
         best_config = best_result.config
         ref_latency = best_result.ref_latency
         print(f"Best latency: {best_latency}")
         print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
         print(f"Best config: {best_config}")
+        print(f"Ref latency: {ref_latency}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch', type=int, default=1, help='batch size')
+    parser.add_argument('--heads', type=int, default=32, help='heads')
+    parser.add_argument('--groups', type=int, default=8, help='groups')
+    parser.add_argument('--kv_seqlen', type=int, default=8192, help='kv sequence length')
+    parser.add_argument('--dim', type=int, default=128, help='dim')
+    parser.add_argument('--tune', action='store_true', help='tune configs')
+    args = parser.parse_args()
+    main(args.batch, args.heads, args.groups, args.kv_seqlen, args.dim, args.tune)
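For reference, the `qk_flops`/`pv_flops` accounting above counts one multiply-add per element of the two GEMMs in decoding attention (Q·Kᵀ and P·V). A quick check with the script's defaults; the latency value here is hypothetical, only to show the unit handling (latency is in milliseconds, so the 1e-9 factor yields TFLOPS):

    batch, heads, kv_seqlen, dim = 1, 32, 8192, 128
    qk_flops = 2 * batch * heads * kv_seqlen * dim  # Q @ K^T: 67,108,864 FLOPs
    pv_flops = 2 * batch * heads * kv_seqlen * dim  # P @ V:   67,108,864 FLOPs
    total_flops = qk_flops + pv_flops               # 134,217,728 FLOPs total
    latency_ms = 0.05                               # hypothetical measured latency
    print(total_flops / latency_ms * 1e-9)          # ~2.68 TFLOPS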
example_mha_inference.py:

@@ -199,7 +199,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N):
         T.copy(o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])

     @T.prim_func
-    def main(
+    def flashattn_mha_inference(
             Q: T.Tensor(shape_q, dtype),
             K: T.Tensor(shape_kv, dtype),
             V: T.Tensor(shape_kv, dtype),
@@ -210,7 +210,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N):
         flash_attn_split(Q, K, V, glse, Output_partial)
         combine(glse, Output_partial, Output)

-    return main
+    return flashattn_mha_inference


 def ref_program(Q, K, V, glse, Output_partial, causal):
@@ -293,7 +293,7 @@ def flash_split_ref(Q, K, V, causal):
                3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)


-if __name__ == "__main__":
+def main():
     BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 8192, 128
     causal = False
     flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD
@@ -303,17 +303,21 @@ if __name__ == "__main__":
     BLOCK_M = 128
     BLOCK_N = 64  # if D_HEAD <= 128 else 32
     program = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N)
-    ref_program = partial(ref_program, causal=causal)
+    ref_fn = partial(ref_program, causal=causal)
     kernel = tilelang.compile(
         program, out_idx=[5], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True})
     print(kernel.get_kernel_source())
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
-    profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
+    profiler.assert_allclose(ref_fn, rtol=0.01, atol=0.01)
     print("All checks passed!")
-    latency = profiler.do_bench(ref_program, warmup=500)
+    latency = profiler.do_bench(ref_fn, warmup=500)
     print("{:.2f} ms".format(latency))
     print("{:.2f} TFlops".format(total_flops / latency * 1e-9))
     latency = profiler.do_bench(n_warmup=10, n_repeat=10)
     print("{:.4f} ms".format(latency))
     print("{:.2f} TFlops".format(total_flops / latency * 1e-9))
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
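One detail in this hunk is easy to miss: renaming `ref_program` to `ref_fn` is required, not cosmetic. At module level, `ref_program = partial(ref_program, causal=causal)` happily rebinds the global; inside the new `main()`, the same assignment would make `ref_program` local to the whole function body, so the read on the right-hand side would raise `UnboundLocalError`. A self-contained illustration of the scoping rule:

    from functools import partial


    def ref_program(q, causal):
        return -q if causal else q


    def main_broken():
        # Raises UnboundLocalError: 'ref_program' is local here because it is
        # assigned in this function, so the right-hand-side read fails.
        ref_program = partial(ref_program, causal=False)
        return ref_program(3)


    def main_fixed():
        ref_fn = partial(ref_program, causal=False)  # reads the module-level name
        return ref_fn(3)


    print(main_fixed())  # -> 3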
New test file (driver for both examples; its path is not shown in this view):

+import tilelang.testing
+
+import example_gqa_decode
+import example_mha_inference
+
+
+def test_example_example_gqa_decode():
+    example_gqa_decode.main()
+
+
+def test_example_example_mha_inference():
+    example_mha_inference.main()
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
\ No newline at end of file
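With both examples exposing `main()`, this file exercises them end to end either under standard pytest discovery of the `test_*` functions or by running the file directly, which hands off to `tilelang.testing.main()`.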