Commit 8dec14e0 authored by yyttt6, committed by LeiWang1999

[CI] Add Analyzer and blocksparse_attention examples to CI (#472)



* yes

* [Bugfix] fix the unexpected keyword error of autotune

* format

* test

* [CI] Add Analyzer and blocksparse_attention examples to CI

* format

* try

* try

* try

* try

* t

* format

* d

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 66dba763
@@ -47,7 +47,7 @@ def kernel(N,
     is_hopper = check_hopper()

     @T.prim_func
-    def main(
+    def conv(
             data: T.Tensor((N, H, W, C), dtype),
             kernel: T.Tensor((KH, KW, C, F), dtype),
             out: T.Tensor((N, OH, OW, F), dtype),
@@ -89,11 +89,16 @@ def kernel(N,
             T.copy(out_local, out_shared)
             T.copy(out_shared, out_flat[by * block_M, bx * block_N])

-    return main
+    return conv


-my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
-cuda_device = CUDA("cuda")
-result = Analyzer.analysis(my_func, cuda_device)
-print(result)
-print(f"Analyzed FLOPs: {result.total_flops}")
+def main():
+    my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
+    cuda_device = CUDA("cuda")
+    result = Analyzer.analysis(my_func, cuda_device)
+    print(result)
+    print(f"Analyzed FLOPs: {result.total_flops}")
+
+
+if __name__ == "__main__":
+    main()
@@ -17,7 +17,7 @@ def kernel(
     accum_dtype = "float"

     @T.prim_func
-    def main(
+    def matmul(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
@@ -41,13 +41,18 @@ def kernel(
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

-    return main
+    return matmul


-my_func = kernel(128, 128, 32, 3, 128, True)
-cuda_device = CUDA("cuda")
-result = Analyzer.analysis(my_func, cuda_device)
-print(f"Analyzed FLOPs: {result.total_flops}")
-print(f"Expected FLOPs: {2 * M * N * K}")
+def main():
+    my_func = kernel(128, 128, 32, 3, 128, True)
+    cuda_device = CUDA("cuda")
+    result = Analyzer.analysis(my_func, cuda_device)
+    print(f"Analyzed FLOPs: {result.total_flops}")
+    print(f"Expected FLOPs: {2 * M * N * K}")
+
+
+if __name__ == "__main__":
+    main()
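For reference, the "Expected FLOPs" printed above is the usual dense-GEMM count: two floating-point operations (one multiply, one add) per output element per reduction step. A minimal, standalone sketch of that check follows; the concrete M, N, K values here are illustrative, not taken from the example.

# Minimal sketch: dense GEMM FLOP count behind the "Expected FLOPs" print.
# The sizes below are illustrative defaults, not the example's own.
def expected_gemm_flops(M: int, N: int, K: int) -> int:
    # One multiply and one add per (m, n, k) triple.
    return 2 * M * N * K

assert expected_gemm_flops(1024, 1024, 1024) == 2_147_483_648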
import tilelang.testing

import example_gemm_analyze
import example_conv_analyze


def test_example_gemm_analyze():
    example_gemm_analyze.main()


def test_example_conv_analyze():
    example_conv_analyze.main()


if __name__ == "__main__":
    tilelang.testing.main()
@@ -377,6 +377,10 @@ def test_topk_sparse_attention_qlt_kl():
     print("Pass topk sparse attention test with qlen < klen")


-if __name__ == "__main__":
+def main():
     test_topk_sparse_attention()
     test_topk_sparse_attention_qlt_kl()
+
+
+if __name__ == "__main__":
+    main()
@@ -116,7 +116,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                     acc_o[i, j] *= scores_scale[i]

        @T.prim_func
-        def main(
+        def blocksparse_flashattn(
                Q: T.Tensor(shape, dtype),
                K: T.Tensor(shape, dtype),
                V: T.Tensor(shape, dtype),
@@ -163,7 +163,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

-        return main
+        return blocksparse_flashattn

    return kernel_func(block_M, block_N, num_stages, threads)
@@ -217,5 +217,9 @@ def test_topk_sparse_attention():
     print("Pass topk sparse attention test with qlen == klen")


-if __name__ == "__main__":
+def main():
     test_topk_sparse_attention()
+
+
+if __name__ == "__main__":
+    main()
@@ -220,6 +220,7 @@ class SparseFlashAttn(torch.nn.Module):
         heads = self.heads
         heads_kv = self.heads_kv
         dim_v = self.dim_v
+        dim = self.dim
         block_size = self.block_size
         max_selected_blocks = block_indices.shape[-1]
@@ -394,30 +395,20 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
        print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")


-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
-    parser.add_argument(
-        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
-    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
-    parser.add_argument('--block_size', type=int, default=32, help='block_size')
-    args = parser.parse_args()
-    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v
-    sparse_ratio = args.sparse_ratio
-    block_size = args.block_size
-
-    qk_flops = 2 * batch * heads * max_cache_seqlen * dim
-    pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v
-    total_flops = qk_flops + pv_flops
+def main(batch=8,
+         heads=32,
+         heads_kv=8,
+         max_cache_seqlen=8192,
+         dim=128,
+         dim_v=128,
+         sparse_ratio=0.8,
+         block_size=32):
+    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
+    sparse_ratio = sparse_ratio
+    block_size = block_size

     max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
     print("max_selected_blocks: ", max_selected_blocks)
     dtype = torch.float16
-    block_H = 64

     Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
     K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
@@ -491,3 +482,19 @@ if __name__ == "__main__":
     out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
     torch.cuda.synchronize()
     print("sparse time: ", (time.time() - start) / 100 * 1000)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch', type=int, default=8, help='batch size')
+    parser.add_argument('--heads', type=int, default=32, help='heads')
+    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
+    parser.add_argument(
+        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
+    parser.add_argument('--dim', type=int, default=128, help='dim')
+    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
+    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
+    parser.add_argument('--block_size', type=int, default=32, help='block_size')
+    args = parser.parse_args()
+    main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
+         args.sparse_ratio, args.block_size)
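The refactor above, and the analogous ones in the remaining examples below, all follow the same shape: the benchmark body moves into a main() with keyword defaults so the CI test modules can import the example and call main() directly, while command-line runs still parse flags and forward them. A minimal sketch of that pattern (parameter names here are illustrative, not the examples' full set):

# Sketch of the main()-with-defaults pattern used across these examples.
import argparse


def main(batch: int = 8, heads: int = 32) -> None:
    # Benchmark / test body goes here; callable without any CLI parsing.
    print(f"running with batch={batch}, heads={heads}")


if __name__ == "__main__":
    # CLI entry point: parse flags and forward them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=8)
    parser.add_argument("--heads", type=int, default=32)
    args = parser.parse_args()
    main(args.batch, args.heads)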
@@ -206,6 +206,7 @@ class SparseFlashAttn(torch.nn.Module):
         heads = self.heads
         heads_kv = self.heads_kv
         dim_v = self.dim_v
+        dim = self.dim
         block_size = self.block_size
         block_H = self.block_H
         max_cache_seqlen = key.shape[1]
@@ -367,30 +368,20 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
        print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")


-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
-    parser.add_argument(
-        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
-    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
-    parser.add_argument('--block_size', type=int, default=32, help='block_size')
-    args = parser.parse_args()
-    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v
-    sparse_ratio = args.sparse_ratio
-    block_size = args.block_size
-
-    qk_flops = 2 * batch * heads * max_cache_seqlen * dim
-    pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v
-    total_flops = qk_flops + pv_flops
+def main(batch=8,
+         heads=32,
+         heads_kv=8,
+         max_cache_seqlen=8192,
+         dim=128,
+         dim_v=128,
+         sparse_ratio=0.8,
+         block_size=32):
+    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
+    sparse_ratio = sparse_ratio
+    block_size = block_size

     max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
     print("max_selected_blocks: ", max_selected_blocks)
     dtype = torch.float16
-    block_H = 64

     Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
     K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
@@ -454,3 +445,19 @@ if __name__ == "__main__":
     out = model(Q, K, V, block_mask, cache_seqlens)
     torch.cuda.synchronize()
     print("sparse time: ", (time.time() - start) / 100 * 1000)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch', type=int, default=8, help='batch size')
+    parser.add_argument('--heads', type=int, default=32, help='heads')
+    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
+    parser.add_argument(
+        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
+    parser.add_argument('--dim', type=int, default=128, help='dim')
+    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
+    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
+    parser.add_argument('--block_size', type=int, default=32, help='block_size')
+    args = parser.parse_args()
+    main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
+         args.sparse_ratio, args.block_size)
@@ -348,22 +348,18 @@ def ref_program_fa(query, key, value, cache_seqlens):
     return output


-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=64, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
-    parser.add_argument(
-        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
-    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
-    parser.add_argument('--block_size', type=int, default=32, help='block_size')
-    args = parser.parse_args()
-    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v
-    sparse_ratio = args.sparse_ratio
-    block_size = args.block_size
+def main(batch=64,
+         heads=32,
+         heads_kv=8,
+         max_cache_seqlen=8192,
+         dim=128,
+         dim_v=128,
+         sparse_ratio=0.8,
+         block_size=32):
+
+    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
+    sparse_ratio = sparse_ratio
+    block_size = block_size

     qk_flops = 2 * batch * heads * max_cache_seqlen * dim
     pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v
     total_flops = qk_flops + pv_flops
@@ -464,3 +460,19 @@ if __name__ == "__main__":
     print(f"Average time of ref: {avg_time_ref:.6f} seconds")
     print(f"Speedup: {avg_time_ref / avg_time:.2f}x")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch', type=int, default=64, help='batch size')
+    parser.add_argument('--heads', type=int, default=32, help='heads')
+    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
+    parser.add_argument(
+        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
+    parser.add_argument('--dim', type=int, default=128, help='dim')
+    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
+    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
+    parser.add_argument('--block_size', type=int, default=32, help='block_size')
+    args = parser.parse_args()
+    main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
+         args.sparse_ratio, args.block_size)
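With the defaults kept above (batch=64, heads=32, max_cache_seqlen=8192, dim=dim_v=128), the dense-equivalent FLOP count this benchmark reports against works out as follows; the snippet is plain arithmetic reproduced only as a sanity check of those context lines.

# Dense-equivalent FLOP count at the benchmark's default sizes.
batch, heads, max_cache_seqlen, dim, dim_v = 64, 32, 8192, 128, 128
qk_flops = 2 * batch * heads * max_cache_seqlen * dim    # 4,294,967,296
pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v  # 4,294,967,296
total_flops = qk_flops + pv_flops                        # 8,589,934,592, i.e. about 8.6 GFLOP
print(total_flops)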
@@ -345,28 +345,23 @@ def ref_program_fa(query, key, value, cache_seqlens):
     return output


-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=64, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
-    parser.add_argument(
-        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
-    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
-    parser.add_argument('--block_size', type=int, default=32, help='block_size')
-    args = parser.parse_args()
-    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v
-    block_size = args.block_size
-    sparse_ratio = args.sparse_ratio
+def main(batch=64,
+         heads=32,
+         heads_kv=8,
+         max_cache_seqlen=8192,
+         dim=128,
+         dim_v=128,
+         sparse_ratio=0.8,
+         block_size=32):
+
+    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
+    block_size = block_size
+    sparse_ratio = sparse_ratio

     qk_flops = 2 * batch * heads * max_cache_seqlen * dim
     pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v
     total_flops = qk_flops + pv_flops

     dtype = torch.float16
-    block_H = 64

     Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
     K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
@@ -432,6 +427,7 @@ if __name__ == "__main__":
     avg_time = elapsed_time / 1000
     avg_flops = total_flops / avg_time
     print(f"Average time: {avg_time:.6f} seconds")
+    print(f"Average flops: {avg_flops:.2f} GFLOPS")

     # Measure performance of reference implementation
     start = time.time()
@@ -443,5 +439,22 @@ if __name__ == "__main__":
     avg_time_ref = elapsed_time_ref / 1000
     avg_flops_ref = total_flops / avg_time_ref
     print(f"Average time of ref: {avg_time_ref:.6f} seconds")
+    print(f"Average flops of ref: {avg_flops_ref:.2f} GFLOPS")
     print(f"Speedup: {avg_time_ref / avg_time:.2f}x")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch', type=int, default=64, help='batch size')
+    parser.add_argument('--heads', type=int, default=32, help='heads')
+    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
+    parser.add_argument(
+        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
+    parser.add_argument('--dim', type=int, default=128, help='dim')
+    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
+    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
+    parser.add_argument('--block_size', type=int, default=32, help='block_size')
+    args = parser.parse_args()
+    main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
+         args.sparse_ratio, args.block_size)
import tilelang.testing

import block_sparse_attn_triton
import example_tilelang_block_sparse_attn
import example_tilelang_sparse_gqa_decode_varlen_indice
import example_tilelang_sparse_gqa_decode_varlen_mask
import example_triton_sparse_gqa_decode_varlen_indice
import example_triton_sparse_gqa_decode_varlen_mask


def test_block_sparse_attn_triton():
    block_sparse_attn_triton.main()


def test_example_tilelang_block_sparse_attn():
    example_tilelang_block_sparse_attn.main()


def test_example_tilelang_sparse_gqa_decode_varlen_indice():
    example_tilelang_sparse_gqa_decode_varlen_indice.main()


def test_example_tilelang_sparse_gqa_decode_varlen_mask():
    example_tilelang_sparse_gqa_decode_varlen_mask.main()


def test_example_triton_sparse_gqa_decode_varlen_indice():
    example_triton_sparse_gqa_decode_varlen_indice.main()


def test_example_triton_sparse_gqa_decode_varlen_mask():
    example_triton_sparse_gqa_decode_varlen_mask.main()


if __name__ == "__main__":
    tilelang.testing.main()
@@ -25,3 +25,8 @@ tabulate
 wheel
 setuptools
 einops
+attrs
+decorator
+flash-attn
+scipy
+tornado
\ No newline at end of file