Unverified Commit 667632cc authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents d6dd2ddf a874e4e8
...@@ -12,12 +12,8 @@ from heuristic import num_splits_heuristic ...@@ -12,12 +12,8 @@ from heuristic import num_splits_heuristic
@triton.autotune( @triton.autotune(
configs=[ configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
triton.Config({}, num_warps=num_warps, num_stages=num_stages) key=["BLOCK_H", "BLOCK_N", "BLOCK_D"],
for num_warps in [1, 2, 4]\
for num_stages in [1, 2, 3, 4, 7]
],
key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'],
) )
@triton.jit @triton.jit
def _split_kernel( def _split_kernel(
...@@ -79,16 +75,11 @@ def _split_kernel( ...@@ -79,16 +75,11 @@ def _split_kernel(
loop_range = blocks_per_split loop_range = blocks_per_split
q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h
k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[ k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d
None, :] * stride_k_s + offs_d[:, None] * stride_k_d v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d
v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:,
None] * stride_v_s + offs_d[
None, :] * stride_v_d
mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h
q = tl.load( q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size)
q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d,
mask=offs_h[:, None] < gqa_group_size)
start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks) start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks)
for i in range(loop_range): for i in range(loop_range):
block_idx = tl.load(mask_ptr + (start + i) * stride_mask_s) block_idx = tl.load(mask_ptr + (start + i) * stride_mask_s)
...@@ -119,23 +110,18 @@ def _split_kernel( ...@@ -119,23 +110,18 @@ def _split_kernel(
acc = acc * l_recip acc = acc * l_recip
acc = acc.to(o_partial_ptr.dtype.element_ty) acc = acc.to(o_partial_ptr.dtype.element_ty)
lse_partial_ptr += batch_idx * stride_lse_b + ( lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size) tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size)
o_partial_ptr += batch_idx * stride_o_b + ( o_partial_ptr += (
head_idx_q + batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d )
tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size) tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size)
@triton.autotune( @triton.autotune(
configs=[ configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
triton.Config({}, num_warps=num_warps, num_stages=num_stages) key=["BLOCK_D"],
for num_warps in [1, 2, 4]\
for num_stages in [1, 2, 3, 4, 7]
],
key=['BLOCK_D'],
) )
@triton.jit @triton.jit
def _merge_kernel( def _merge_kernel(
...@@ -163,18 +149,15 @@ def _merge_kernel( ...@@ -163,18 +149,15 @@ def _merge_kernel(
offs_d = tl.arange(0, BLOCK_D) offs_d = tl.arange(0, BLOCK_D)
lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h
lse = tl.load( lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf"))
lse_offsets + offs_splits * lse_partial_stride_split,
mask=offs_splits < num_splits,
other=float("-inf"))
lse_max = tl.max(lse) lse_max = tl.max(lse)
o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h
o_partial = tl.load( o_partial = tl.load(
o_offsets + offs_splits[:, None] * o_partial_stride_split + o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d,
offs_d[None, :] * o_partial_stride_d, mask=offs_splits[:, None] < num_splits,
mask=offs_splits[:, None] < num_splits) )
sumexp_normalized_splitk = tl.exp(lse - lse_max) sumexp_normalized_splitk = tl.exp(lse - lse_max)
sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0)
numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0) numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0)
...@@ -209,19 +192,13 @@ def block_sparse_flash_decode_gqa_indice_triton( ...@@ -209,19 +192,13 @@ def block_sparse_flash_decode_gqa_indice_triton(
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * ( size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2
dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks total_mblocks = batch * heads_kv * num_m_blocks
num_sm = 64 num_sm = 64
# num_sm = self.num_sm # num_sm = self.num_sm
num_splits = num_splits_heuristic( num_splits = num_splits_heuristic(
total_mblocks, total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
num_sm, )
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks) # print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
...@@ -295,24 +272,18 @@ def block_sparse_flash_decode_gqa_indice_triton( ...@@ -295,24 +272,18 @@ def block_sparse_flash_decode_gqa_indice_triton(
return output return output
def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
block_size):
batch, heads, dim = query.shape batch, heads, dim = query.shape
heads_kv = key.shape[2] heads_kv = key.shape[2]
dim_v = value.shape[-1] dim_v = value.shape[-1]
num_head_groups = query.shape[1] // key.shape[2] num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5 scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
query = rearrange( query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum( scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv]
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores) sparse_mask = torch.zeros_like(scores)
# Assign mask values based on block_indices # Assign mask values based on block_indices
...@@ -321,42 +292,33 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache ...@@ -321,42 +292,33 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache
valid_indices = block_indices[b, h] # Extract indices for this batch and head valid_indices = block_indices[b, h] # Extract indices for this batch and head
for idx in valid_indices: for idx in valid_indices:
if idx >= 0: if idx >= 0:
sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float('-inf')) scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1) cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :] pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float('-inf')) scores = scores.masked_fill(pad_mask, float("-inf"))
attention = F.softmax( attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value, out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim]
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out return out
def ref_program_fa(query, key, value, cache_seqlens): def ref_program_fa(query, key, value, cache_seqlens):
# latency reference # latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3 # from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache #fa2 from flash_attn import flash_attn_with_kvcache # fa2
query = query.unsqueeze(1) query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1) output = output.squeeze(1)
return output return output
def main(batch=64, def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
heads=32,
heads_kv=8,
max_cache_seqlen=8192,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32):
batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
sparse_ratio = sparse_ratio sparse_ratio = sparse_ratio
block_size = block_size block_size = block_size
...@@ -369,34 +331,29 @@ def main(batch=64, ...@@ -369,34 +331,29 @@ def main(batch=64,
dtype = torch.float16 dtype = torch.float16
block_H = 64 block_H = 64
Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
# Ensure at least one element equals cache_seqlen # Ensure at least one element equals cache_seqlen
random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index
cache_seqlens[ cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
print("cache_seqlens: ", cache_seqlens) print("cache_seqlens: ", cache_seqlens)
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks) print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_indices with -1 (for padding blocks) # Initialize block_indices with -1 (for padding blocks)
block_indices = torch.full((batch, heads_kv, max_selected_blocks), block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda")
-1,
dtype=torch.int32,
device='cuda')
# Assign valid indices while ensuring no duplicates within each batch-group # Assign valid indices while ensuring no duplicates within each batch-group
for b in range(batch): for b in range(batch):
max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch
if max_valid_block > 0: # Ensure there's at least one valid block if max_valid_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv): for h in range(heads_kv):
valid_indices = torch.randperm( valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks]
max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks] block_indices[b, h, : len(valid_indices)] = valid_indices
block_indices[b, h, :len(valid_indices)] = valid_indices
# Sort indices within each batch-group for consistency # Sort indices within each batch-group for consistency
block_indices, _ = block_indices.sort(dim=-1, descending=True) block_indices, _ = block_indices.sort(dim=-1, descending=True)
...@@ -408,8 +365,7 @@ def main(batch=64, ...@@ -408,8 +365,7 @@ def main(batch=64,
max_num_blocks = torch.max(max_valid_num_blocks).item() max_num_blocks = torch.max(max_valid_num_blocks).item()
print("max_num_blocks: ", max_num_blocks) print("max_num_blocks: ", max_num_blocks)
ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
block_size)
triton_out = block_sparse_flash_decode_gqa_indice_triton( triton_out = block_sparse_flash_decode_gqa_indice_triton(
Q, Q,
...@@ -423,8 +379,7 @@ def main(batch=64, ...@@ -423,8 +379,7 @@ def main(batch=64,
) )
print("max difference: ", torch.max(torch.abs(ref - triton_out))) print("max difference: ", torch.max(torch.abs(ref - triton_out)))
assert torch.allclose( assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
print("Passed the ref test!") print("Passed the ref test!")
# Measure performance # Measure performance
...@@ -466,15 +421,13 @@ def main(batch=64, ...@@ -466,15 +421,13 @@ def main(batch=64,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=64, help='batch size') parser.add_argument("--batch", type=int, default=64, help="batch size")
parser.add_argument('--heads', type=int, default=32, help='heads') parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
parser.add_argument( parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
'--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
parser.add_argument('--dim_v', type=int, default=128, help='dim_v') parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') parser.add_argument("--block_size", type=int, default=32, help="block_size")
parser.add_argument('--block_size', type=int, default=32, help='block_size')
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
args.sparse_ratio, args.block_size)
...@@ -11,12 +11,8 @@ from heuristic import num_splits_heuristic ...@@ -11,12 +11,8 @@ from heuristic import num_splits_heuristic
@triton.autotune( @triton.autotune(
configs=[ configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
triton.Config({}, num_warps=num_warps, num_stages=num_stages) key=["BLOCK_H", "BLOCK_N", "BLOCK_D"],
for num_warps in [1, 2, 4]\
for num_stages in [1, 2, 3, 4, 7]
],
key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'],
) )
@triton.jit @triton.jit
def _split_kernel( def _split_kernel(
...@@ -77,16 +73,11 @@ def _split_kernel( ...@@ -77,16 +73,11 @@ def _split_kernel(
loop_range = blocks_per_split loop_range = blocks_per_split
q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h
k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[ k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d
None, :] * stride_k_s + offs_d[:, None] * stride_k_d v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d
v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:,
None] * stride_v_s + offs_d[
None, :] * stride_v_d
mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h
q = tl.load( q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size)
q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d,
mask=offs_h[:, None] < gqa_group_size)
start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks) start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks)
for block_idx in range(loop_range): for block_idx in range(loop_range):
start_n = (start + block_idx) * BLOCK_N start_n = (start + block_idx) * BLOCK_N
...@@ -117,23 +108,18 @@ def _split_kernel( ...@@ -117,23 +108,18 @@ def _split_kernel(
acc = acc * l_recip acc = acc * l_recip
acc = acc.to(o_partial_ptr.dtype.element_ty) acc = acc.to(o_partial_ptr.dtype.element_ty)
lse_partial_ptr += batch_idx * stride_lse_b + ( lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size) tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size)
o_partial_ptr += batch_idx * stride_o_b + ( o_partial_ptr += (
head_idx_q + batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d )
tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size) tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size)
@triton.autotune( @triton.autotune(
configs=[ configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
triton.Config({}, num_warps=num_warps, num_stages=num_stages) key=["BLOCK_D"],
for num_warps in [1, 2, 4]\
for num_stages in [1, 2, 3, 4, 7]
],
key=['BLOCK_D'],
) )
@triton.jit @triton.jit
def _merge_kernel( def _merge_kernel(
...@@ -161,18 +147,15 @@ def _merge_kernel( ...@@ -161,18 +147,15 @@ def _merge_kernel(
offs_d = tl.arange(0, BLOCK_D) offs_d = tl.arange(0, BLOCK_D)
lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h
lse = tl.load( lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf"))
lse_offsets + offs_splits * lse_partial_stride_split,
mask=offs_splits < num_splits,
other=float("-inf"))
lse_max = tl.max(lse) lse_max = tl.max(lse)
o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h
o_partial = tl.load( o_partial = tl.load(
o_offsets + offs_splits[:, None] * o_partial_stride_split + o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d,
offs_d[None, :] * o_partial_stride_d, mask=offs_splits[:, None] < num_splits,
mask=offs_splits[:, None] < num_splits) )
sumexp_normalized_splitk = tl.exp(lse - lse_max) sumexp_normalized_splitk = tl.exp(lse - lse_max)
sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0)
numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0) numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0)
...@@ -207,19 +190,13 @@ def block_sparse_flash_decode_gqa_mask_triton( ...@@ -207,19 +190,13 @@ def block_sparse_flash_decode_gqa_mask_triton(
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * ( size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2
dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks total_mblocks = batch * heads_kv * num_m_blocks
num_sm = 64 num_sm = 64
# num_sm = self.num_sm # num_sm = self.num_sm
num_splits = num_splits_heuristic( num_splits = num_splits_heuristic(
total_mblocks, total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
num_sm, )
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks) # print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
...@@ -292,24 +269,18 @@ def block_sparse_flash_decode_gqa_mask_triton( ...@@ -292,24 +269,18 @@ def block_sparse_flash_decode_gqa_mask_triton(
return output return output
def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
block_size):
batch, heads, dim = query.shape batch, heads, dim = query.shape
heads_kv = key.shape[2] heads_kv = key.shape[2]
num_head_groups = query.shape[1] // key.shape[2] num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5 scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
query = rearrange( query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum( scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv]
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores) sparse_mask = torch.zeros_like(scores)
# Assign mask values # Assign mask values
...@@ -317,43 +288,34 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se ...@@ -317,43 +288,34 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se
for h in range(heads_kv): for h in range(heads_kv):
for idx in range(num_blocks): for idx in range(num_blocks):
if block_mask[b, h, idx]: if block_mask[b, h, idx]:
sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float('-inf')) scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1) cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :] pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float('-inf')) scores = scores.masked_fill(pad_mask, float("-inf"))
attention = F.softmax( attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value, out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim]
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out return out
def ref_program_fa(query, key, value, cache_seqlens): def ref_program_fa(query, key, value, cache_seqlens):
# latency reference # latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3 # from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache #fa2 from flash_attn import flash_attn_with_kvcache # fa2
query = query.unsqueeze(1) query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1) output = output.squeeze(1)
return output return output
def main(batch=64, def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
heads=32,
heads_kv=8,
max_cache_seqlen=8192,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32):
batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
block_size = block_size block_size = block_size
sparse_ratio = sparse_ratio sparse_ratio = sparse_ratio
...@@ -363,14 +325,13 @@ def main(batch=64, ...@@ -363,14 +325,13 @@ def main(batch=64,
dtype = torch.float16 dtype = torch.float16
Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
# Ensure at least one element equals cache_seqlen # Ensure at least one element equals cache_seqlen
random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index
cache_seqlens[ cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
num_blocks = (max_cache_seqlen + block_size - 1) // block_size num_blocks = (max_cache_seqlen + block_size - 1) // block_size
...@@ -379,7 +340,7 @@ def main(batch=64, ...@@ -379,7 +340,7 @@ def main(batch=64,
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks) print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_mask with false (for padding blocks) # Initialize block_mask with false (for padding blocks)
block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda') block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda")
# Assign valid indices while ensuring no duplicates within each batch-group # Assign valid indices while ensuring no duplicates within each batch-group
for b in range(batch): for b in range(batch):
...@@ -387,11 +348,10 @@ def main(batch=64, ...@@ -387,11 +348,10 @@ def main(batch=64,
valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch
if valid_num_block > 0: # Ensure there's at least one valid block if valid_num_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv): for h in range(heads_kv):
perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block] perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block]
block_mask[b, h, perm] = True block_mask[b, h, perm] = True
ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
block_size)
triton_out = block_sparse_flash_decode_gqa_mask_triton( triton_out = block_sparse_flash_decode_gqa_mask_triton(
Q, Q,
...@@ -404,8 +364,7 @@ def main(batch=64, ...@@ -404,8 +364,7 @@ def main(batch=64,
) )
# print("max difference: ", torch.max(torch.abs(ref - triton_out))) # print("max difference: ", torch.max(torch.abs(ref - triton_out)))
assert torch.allclose( assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
print("Passed the ref test!") print("Passed the ref test!")
# Measure performance # Measure performance
...@@ -448,15 +407,13 @@ def main(batch=64, ...@@ -448,15 +407,13 @@ def main(batch=64,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=64, help='batch size') parser.add_argument("--batch", type=int, default=64, help="batch size")
parser.add_argument('--heads', type=int, default=32, help='heads') parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
parser.add_argument( parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
'--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
parser.add_argument('--dim_v', type=int, default=128, help='dim_v') parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') parser.add_argument("--block_size", type=int, default=32, help="block_size")
parser.add_argument('--block_size', type=int, default=32, help='block_size')
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
args.sparse_ratio, args.block_size)
import math import math
def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local, max_splits):
is_causal_or_local, max_splits):
""" """
Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency. Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency.
......
...@@ -25,26 +25,14 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask(): ...@@ -25,26 +25,14 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
def test_example_triton_sparse_gqa_decode_varlen_indice(): def test_example_triton_sparse_gqa_decode_varlen_indice():
example_triton_sparse_gqa_decode_varlen_indice.main( example_triton_sparse_gqa_decode_varlen_indice.main(
batch=8, batch=8, heads=8, heads_kv=4, max_cache_seqlen=2048, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32
heads=8, )
heads_kv=4,
max_cache_seqlen=2048,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32)
def test_example_triton_sparse_gqa_decode_varlen_mask(): def test_example_triton_sparse_gqa_decode_varlen_mask():
example_triton_sparse_gqa_decode_varlen_mask.main( example_triton_sparse_gqa_decode_varlen_mask.main(
batch=16, batch=16, heads=16, heads_kv=8, max_cache_seqlen=1024, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32
heads=16, )
heads_kv=8,
max_cache_seqlen=1024,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -19,8 +19,7 @@ parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M") ...@@ -19,8 +19,7 @@ parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N") parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)") parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)")
parser.add_argument( parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune")
"--use_autotune", action="store_true", default=False, help="Whether to use autotune")
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
M, N, K = args.m, args.n, args.k M, N, K = args.m, args.n, args.k
...@@ -41,17 +40,19 @@ def get_configs(): ...@@ -41,17 +40,19 @@ def get_configs():
thread_num = [128, 256] thread_num = [128, 256]
enable_rasterization = [True, False] enable_rasterization = [True, False]
_configs = list( _configs = list(itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization))
itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization))
return [{ return [
"block_M": c[0], {
"block_N": c[1], "block_M": c[0],
"block_K": c[2], "block_N": c[1],
"num_stages": c[3], "block_K": c[2],
"thread_num": c[4], "num_stages": c[3],
"enable_rasteration": c[5], "thread_num": c[4],
} for c in _configs] "enable_rasteration": c[5],
}
for c in _configs
]
def ref_program(A, B, BlockMask, block_M, block_N, block_K): def ref_program(A, B, BlockMask, block_M, block_N, block_K):
...@@ -61,12 +62,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K): ...@@ -61,12 +62,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device) accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
for k in range(K // block_K): for k in range(K // block_K):
if BlockMask[i, j, k]: if BlockMask[i, j, k]:
accu += ( accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[
A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to( k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N
torch.float32) @ B[k * block_K:(k + 1) * block_K, ].to(torch.float32)
j * block_N:(j + 1) * block_N].to(torch.float32)) ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16)
ref_c[i * block_M:(i + 1) * block_M,
j * block_N:(j + 1) * block_N] = accu.to(torch.float16)
return ref_c return ref_c
...@@ -89,28 +88,21 @@ def supply_program(params: List[KernelParam]): ...@@ -89,28 +88,21 @@ def supply_program(params: List[KernelParam]):
return input_tensors return input_tensors
@tilelang.autotune(configs=get_configs(),) @tilelang.autotune(
configs=get_configs(),
)
@tilelang.jit(out_idx=[-1]) @tilelang.jit(out_idx=[-1])
def blocksparse_matmul(M, def blocksparse_matmul(
N, M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32
K, ):
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float"):
block_mask_shape = (M // block_M, N // block_N, K // block_K) block_mask_shape = (M // block_M, N // block_N, K // block_K)
@T.prim_func @T.prim_func
def block_sparse_matmul( def block_sparse_matmul(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype), B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"), BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
...@@ -134,7 +126,6 @@ def blocksparse_matmul(M, ...@@ -134,7 +126,6 @@ def blocksparse_matmul(M,
def main(): def main():
# Initialize input matrices A and B on the GPU with half precision # Initialize input matrices A and B on the GPU with half precision
a = torch.randn(M, K).cuda().half() a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half() b = torch.randn(K, N).cuda().half()
...@@ -147,8 +138,7 @@ def main(): ...@@ -147,8 +138,7 @@ def main():
best_config = kernel.config best_config = kernel.config
best_latency = kernel.latency best_latency = kernel.latency
block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config[ block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config["block_K"]
"block_K"]
print(f"Best Config: {best_config}") print(f"Best Config: {best_config}")
print(f"Sparsity Ratio: {sparsity}") print(f"Sparsity Ratio: {sparsity}")
...@@ -163,10 +153,10 @@ def main(): ...@@ -163,10 +153,10 @@ def main():
block_K=DEFAULT_BLOCK_K, block_K=DEFAULT_BLOCK_K,
num_stages=DEFAULT_NUM_STAGES, num_stages=DEFAULT_NUM_STAGES,
thread_num=DEFAULT_THREAD_NUM, thread_num=DEFAULT_THREAD_NUM,
enable_rasteration=DEFAULT_ENABLE_RASTERIZATION) enable_rasteration=DEFAULT_ENABLE_RASTERIZATION,
)
block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})") print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")
# Create block mask with desired sparsity # Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K) mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity block_mask = torch.rand(mask_shape).cuda() > sparsity
......
...@@ -5,8 +5,8 @@ from typing import Tuple ...@@ -5,8 +5,8 @@ from typing import Tuple
from tilelang.utils.tensor import torch_assert_close from tilelang.utils.tensor import torch_assert_close
# support bfloat16, float, float16 # support bfloat16, float, float16
dtype = "bfloat16" dtype = T.bfloat16
accum_dtype = "float" accum_dtype = T.float32
@tilelang.jit(out_idx=[2, 3]) @tilelang.jit(out_idx=[2, 3])
...@@ -16,11 +16,13 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m): ...@@ -16,11 +16,13 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
fp8_max = 448.0 fp8_max = 448.0
@T.prim_func @T.prim_func
def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor( def group_per_split_token_cast(
(BG,), "int32"), X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"), X_amax: T.Tensor( X: T.Tensor((M, N), dtype),
(BG, M_max, T.ceildiv(N, group_size)), accum_dtype)): batch_sizes: T.Tensor((BG,), T.int32),
with T.Kernel( X_fp8: T.Tensor((BG, M_max, N), T.float8_e4m3fn),
T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz): X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)), accum_dtype),
):
with T.Kernel(T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz):
row = bx row = bx
row_g_id = by row_g_id = by
bg = bz bg = bz
...@@ -28,39 +30,35 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m): ...@@ -28,39 +30,35 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
y_amax_local = T.alloc_fragment((blk_m,), accum_dtype) y_amax_local = T.alloc_fragment((blk_m,), accum_dtype)
y_s_local = T.alloc_fragment((blk_m,), accum_dtype) y_s_local = T.alloc_fragment((blk_m,), accum_dtype)
y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype) y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3") y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn)
row_offset = T.alloc_fragment((1,), "int32") row_offset = T.alloc_fragment((1,), T.int32)
T.annotate_layout({ T.annotate_layout(
y_local: {
T.Fragment( y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
y_local.shape, }
forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), )
})
row_offset[0] = 0 row_offset[0] = 0
for i in T.serial(bg): for i in T.serial(bg):
row_offset[0] += batch_sizes[i] row_offset[0] += batch_sizes[i]
T.copy( T.copy(
X[row_offset[0] + row * blk_m:row_offset[0] + (row + 1) * blk_m, X[row_offset[0] + row * blk_m : row_offset[0] + (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size],
row_g_id * group_size:(row_g_id + 1) * group_size], y_local) y_local,
)
T.reduce_absmax(y_local, y_amax_local, dim=1) T.reduce_absmax(y_local, y_amax_local, dim=1)
for i in T.Parallel(blk_m): for i in T.Parallel(blk_m):
y_amax_local[i] = T.max(y_amax_local[i], 1e-4) y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_amax_local[i] / fp8_max, 0)
y_amax_local[i] / fp8_max, 0)
for i, j in T.Parallel(blk_m, group_size): for i, j in T.Parallel(blk_m, group_size):
y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max) y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max)
T.copy(y_q_local, y_q_local_fp8) T.copy(y_q_local, y_q_local_fp8)
for i, j in T.Parallel(blk_m, group_size): for i, j in T.Parallel(blk_m, group_size):
y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_q_local[i, j], 0)
y_q_local[i, j], 0)
for i in T.Parallel(blk_m): for i in T.Parallel(blk_m):
X_amax[bg, row * blk_m + i, row_g_id] = y_s_local[i] X_amax[bg, row * blk_m + i, row_g_id] = y_s_local[i]
T.copy( T.copy(y_q_local_fp8, X_fp8[bg, row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size])
y_q_local_fp8, X_fp8[bg, row * blk_m:(row + 1) * blk_m,
row_g_id * group_size:(row_g_id + 1) * group_size])
return group_per_split_token_cast return group_per_split_token_cast
...@@ -127,8 +125,7 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: ...@@ -127,8 +125,7 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
return x.squeeze(0) if remove_dim else x return x.squeeze(0) if remove_dim else x
# Normal layout requires transposing # Normal layout requires transposing
aligned_x = torch.transpose( aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
aligned_x[:, :m, :] = x aligned_x[:, :m, :] = x
aligned_x = aligned_x[:, :m, :] aligned_x = aligned_x[:, :m, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x return aligned_x.squeeze(0) if remove_dim else aligned_x
...@@ -146,15 +143,17 @@ def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens ...@@ -146,15 +143,17 @@ def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous() x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous()
return x_fp8, (x_amax / 448.0).view(m, -1) return x_fp8, (x_amax / 448.0).view(m, -1)
def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
Tuple[torch.Tensor, torch.Tensor]: def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# assert x.shape[0] == batch_sizes.sum() # assert x.shape[0] == batch_sizes.sum()
M_max = ceil_div(batch_sizes.max(), 128) * 128 M_max = ceil_div(batch_sizes.max(), 128) * 128
split_x = torch.split(x, batch_sizes.tolist(), dim=0) split_x = torch.split(x, batch_sizes.tolist(), dim=0)
padded_x = [torch.nn.functional.pad(t, (0, 0, 0, M_max - t.shape[0])) for t in split_x] padded_x = [torch.nn.functional.pad(t, (0, 0, 0, M_max - t.shape[0])) for t in split_x]
num_groups, m, n = batch_sizes.shape[0], M_max, x.shape[1] num_groups, m, n = batch_sizes.shape[0], M_max, x.shape[1]
x_fp8 = (torch.empty((num_groups, m, n), device='cuda', dtype=torch.float8_e4m3fn), x_fp8 = (
torch.empty((num_groups, m, n // 128), device='cuda', dtype=torch.float)) torch.empty((num_groups, m, n), device="cuda", dtype=torch.float8_e4m3fn),
torch.empty((num_groups, m, n // 128), device="cuda", dtype=torch.float),
)
for i in range(num_groups): for i in range(num_groups):
x_fp8[0][i], x_fp8[1][i] = ref_per_token_cast_to_fp8(padded_x[i]) x_fp8[0][i], x_fp8[1][i] = ref_per_token_cast_to_fp8(padded_x[i])
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
...@@ -164,11 +163,11 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \ ...@@ -164,11 +163,11 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None): def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
if batch_sizes is None: if batch_sizes is None:
batch_sizes = [2048, 6144] batch_sizes = [2048, 6144]
if dtype == "float": if dtype == T.float:
x = torch.randn(M, N, device="cuda", dtype=torch.float32) x = torch.randn(M, N, device="cuda", dtype=torch.float32)
elif dtype == "float16": elif dtype == T.float16:
x = torch.randn(M, N, device="cuda", dtype=torch.float16) x = torch.randn(M, N, device="cuda", dtype=torch.float16)
elif dtype == "bfloat16": elif dtype == T.bfloat16:
x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16) x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
else: else:
raise ValueError(f"Unsupported dtype: {dtype}") raise ValueError(f"Unsupported dtype: {dtype}")
......
...@@ -7,14 +7,15 @@ from tilelang.utils.tensor import torch_assert_close ...@@ -7,14 +7,15 @@ from tilelang.utils.tensor import torch_assert_close
@tilelang.jit(out_idx=[1, 2]) @tilelang.jit(out_idx=[1, 2])
def per_token_cast_to_fp8(M, N, blk_m): def per_token_cast_to_fp8(M, N, blk_m):
dtype = "float" dtype = T.float
group_size = 128 group_size = 128
fp8_min = -448.0 fp8_min = -448.0
fp8_max = 448.0 fp8_max = 448.0
@T.prim_func @T.prim_func
def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"), def per_token_cast(
X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)): X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), T.float8_e4m3fn), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)
):
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by): with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by):
row = bx row = bx
row_g_id = by row_g_id = by
...@@ -22,18 +23,15 @@ def per_token_cast_to_fp8(M, N, blk_m): ...@@ -22,18 +23,15 @@ def per_token_cast_to_fp8(M, N, blk_m):
y_amax_local = T.alloc_fragment((blk_m,), dtype) y_amax_local = T.alloc_fragment((blk_m,), dtype)
y_s_local = T.alloc_fragment((blk_m,), dtype) y_s_local = T.alloc_fragment((blk_m,), dtype)
y_q_local = T.alloc_fragment((blk_m, group_size), dtype) y_q_local = T.alloc_fragment((blk_m, group_size), dtype)
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3") y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn)
T.annotate_layout({ T.annotate_layout(
y_local: {
T.Fragment( y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
y_local.shape, }
forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), )
})
T.copy(X[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], y_local)
T.copy(
X[row * blk_m:(row + 1) * blk_m, row_g_id * group_size:(row_g_id + 1) * group_size],
y_local)
T.reduce_absmax(y_local, y_amax_local, dim=1) T.reduce_absmax(y_local, y_amax_local, dim=1)
for i in T.Parallel(blk_m): for i in T.Parallel(blk_m):
y_amax_local[i] = T.max(y_amax_local[i], 1e-4) y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
...@@ -43,9 +41,7 @@ def per_token_cast_to_fp8(M, N, blk_m): ...@@ -43,9 +41,7 @@ def per_token_cast_to_fp8(M, N, blk_m):
T.copy(y_q_local, y_q_local_fp8) T.copy(y_q_local, y_q_local_fp8)
for i in T.Parallel(blk_m): for i in T.Parallel(blk_m):
X_amax[row * blk_m + i, row_g_id] = y_s_local[i] X_amax[row * blk_m + i, row_g_id] = y_s_local[i]
T.copy( T.copy(y_q_local_fp8, X_fp8[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size])
y_q_local_fp8, X_fp8[row * blk_m:(row + 1) * blk_m,
row_g_id * group_size:(row_g_id + 1) * group_size])
return per_token_cast return per_token_cast
...@@ -105,8 +101,7 @@ def main(M=8192, N=8192, blk_m=8): ...@@ -105,8 +101,7 @@ def main(M=8192, N=8192, blk_m=8):
from example_triton_cast_to_fp8 import per_token_group_quant_fp8 from example_triton_cast_to_fp8 import per_token_group_quant_fp8
def run_triton(): def run_triton():
x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8( x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False)
x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False)
return x_fp8_triton_, x_amax_triton_ return x_fp8_triton_, x_amax_triton_
x_fp8_triton, x_amax_triton = run_triton() x_fp8_triton, x_amax_triton = run_triton()
......
...@@ -128,9 +128,7 @@ def per_token_group_quant_fp8( ...@@ -128,9 +128,7 @@ def per_token_group_quant_fp8(
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization. scaling factor for quantization.
""" """
assert (x.shape[-1] % assert x.shape[-1] % group_size == 0, f"the last dimension of `x` {x.shape[-1]} must be divisible by `group_size` {group_size}"
group_size == 0), (f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}")
assert x.stride(-1) == 1, "`x` groups must be contiguous" assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype) finfo = torch.finfo(dtype)
......
...@@ -4,8 +4,7 @@ import example_per_token_cast_to_fp8 ...@@ -4,8 +4,7 @@ import example_per_token_cast_to_fp8
def test_example_group_per_split_token_cast_to_fp8(): def test_example_group_per_split_token_cast_to_fp8():
example_group_per_split_token_cast_to_fp8.main( example_group_per_split_token_cast_to_fp8.main(M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])
M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])
def test_example_per_token_cast_to_fp8(): def test_example_per_token_cast_to_fp8():
......
...@@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): ...@@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
"warnings", "warnings",
"error", "error",
} }
if (sum( if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0:
len(terminalreporter.stats.get(k, []))
for k in known_types.difference({"skipped", "deselected"})) == 0):
terminalreporter.write_sep( terminalreporter.write_sep(
"!", "!",
(f"Error: No tests were collected. " (f"Error: No tests were collected. {dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
f"{dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
) )
pytest.exit("No tests were collected.", returncode=5) pytest.exit("No tests were collected.", returncode=5)
...@@ -14,7 +14,6 @@ def check_hopper(): ...@@ -14,7 +14,6 @@ def check_hopper():
def ref_program(stride, padding, dilation): def ref_program(stride, padding, dilation):
def main(A, B): def main(A, B):
A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W
B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W
...@@ -26,38 +25,21 @@ def ref_program(stride, padding, dilation): ...@@ -26,38 +25,21 @@ def ref_program(stride, padding, dilation):
@tilelang.jit(out_idx=[2]) @tilelang.jit(out_idx=[2])
def convolution(N, def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32):
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
threads,
dtype="float16",
accum_dtype="float"):
KH, KW = K, K KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
dtype = "float16" dtype = T.float16
accum_dtype = "float" accum_dtype = T.float32
is_hopper = check_hopper() is_hopper = check_hopper()
@T.prim_func @T.prim_func
def main( def main(
data: T.Tensor((N, H, W, C), dtype), data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype), kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype), out: T.Tensor((N, OH, OW, F), dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype) data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...@@ -66,11 +48,13 @@ def convolution(N, ...@@ -66,11 +48,13 @@ def convolution(N,
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout({ T.annotate_layout(
out_shared: tilelang.layout.make_swizzled_layout(out_shared), {
data_shared: tilelang.layout.make_swizzled_layout(data_shared), out_shared: tilelang.layout.make_swizzled_layout(out_shared),
kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), data_shared: tilelang.layout.make_swizzled_layout(data_shared),
}) kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
}
)
T.clear(out_local) T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
...@@ -82,10 +66,8 @@ def convolution(N, ...@@ -82,10 +66,8 @@ def convolution(N,
m = by * block_M + i m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P access_w = m % OW * S + k // C % KW * D - P
in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
(access_w < W)) data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
data_shared[i, j] = T.if_then_else(
in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local) T.gemm(data_shared, kernel_shared, out_local)
...@@ -97,15 +79,15 @@ def convolution(N, ...@@ -97,15 +79,15 @@ def convolution(N,
def main(argv=None): def main(argv=None):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--n', type=int, default=128, help='n') parser.add_argument("--n", type=int, default=128, help="n")
parser.add_argument('--c', type=int, default=128, help='c') parser.add_argument("--c", type=int, default=128, help="c")
parser.add_argument('--h', type=int, default=64, help='h') parser.add_argument("--h", type=int, default=64, help="h")
parser.add_argument('--w', type=int, default=64, help='w') parser.add_argument("--w", type=int, default=64, help="w")
parser.add_argument('--f', type=int, default=128, help='f') parser.add_argument("--f", type=int, default=128, help="f")
parser.add_argument('--k', type=int, default=3, help='k') parser.add_argument("--k", type=int, default=3, help="k")
parser.add_argument('--s', type=int, default=1, help='s') parser.add_argument("--s", type=int, default=1, help="s")
parser.add_argument('--d', type=int, default=1, help='d') parser.add_argument("--d", type=int, default=1, help="d")
parser.add_argument('--p', type=int, default=1, help='p') parser.add_argument("--p", type=int, default=1, help="p")
args = parser.parse_args(argv) args = parser.parse_args(argv)
N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p
......
...@@ -14,7 +14,6 @@ def check_hopper(): ...@@ -14,7 +14,6 @@ def check_hopper():
def ref_program(stride, padding, dilation): def ref_program(stride, padding, dilation):
def main(A, B): def main(A, B):
A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W
B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W
...@@ -40,7 +39,8 @@ def get_configs(): ...@@ -40,7 +39,8 @@ def get_configs():
num_stages, num_stages,
thread_num, thread_num,
enable_rasterization, enable_rasterization,
)) )
)
configs = [ configs = [
{ {
...@@ -50,7 +50,8 @@ def get_configs(): ...@@ -50,7 +50,8 @@ def get_configs():
"num_stages": c[3], "num_stages": c[3],
"thread_num": c[4], "thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat "enable_rasteration": c[5], # keep param name for backward-compat
} for c in _configs }
for c in _configs
] ]
return configs return configs
...@@ -64,69 +65,32 @@ def get_heuristic_config() -> dict: ...@@ -64,69 +65,32 @@ def get_heuristic_config() -> dict:
sm_version = sm_major * 10 + sm_minor sm_version = sm_major * 10 + sm_minor
print(f"CUDA device capability: {sm_version}") print(f"CUDA device capability: {sm_version}")
if sm_version in {80}: if sm_version in {80}:
return { return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True}
"block_M": 128,
"block_N": 256,
"block_K": 32,
"num_stages": 2,
"thread_num": 128,
"enable_rasteration": True
}
elif sm_version in {90}: elif sm_version in {90}:
return { return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True}
"block_M": 128,
"block_N": 256,
"block_K": 64,
"num_stages": 3,
"thread_num": 256,
"enable_rasteration": True
}
else: else:
return { return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True}
"block_M": 128,
"block_N": 256,
"block_K": 32,
"num_stages": 0,
"thread_num": 128,
"enable_rasteration": True
}
@tilelang.autotune(configs=get_configs()) @tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[2]) @tilelang.jit(out_idx=[2])
def convolution(N, def convolution(
C, N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32
H, ):
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float"):
KH, KW = K, K KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
dtype = "float16" dtype = T.float16
accum_dtype = "float" accum_dtype = T.float32
is_hopper = check_hopper() is_hopper = check_hopper()
@T.prim_func @T.prim_func
def main( def main(
data: T.Tensor((N, H, W, C), dtype), data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype), kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype), out: T.Tensor((N, OH, OW, F), dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=thread_num) as (bx, by):
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=thread_num) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype) data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...@@ -136,9 +100,11 @@ def convolution(N, ...@@ -136,9 +100,11 @@ def convolution(N,
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
if is_hopper: if is_hopper:
T.annotate_layout({ T.annotate_layout(
out_shared: tilelang.layout.make_swizzled_layout(out_shared), {
}) out_shared: tilelang.layout.make_swizzled_layout(out_shared),
}
)
T.clear(out_local) T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
...@@ -150,10 +116,8 @@ def convolution(N, ...@@ -150,10 +116,8 @@ def convolution(N,
m = by * block_M + i m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P access_w = m % OW * S + k // C % KW * D - P
in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
(access_w < W)) data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
data_shared[i, j] = T.if_then_else(
in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local) T.gemm(data_shared, kernel_shared, out_local)
...@@ -166,17 +130,19 @@ def convolution(N, ...@@ -166,17 +130,19 @@ def convolution(N,
return main return main
def main(n: int = 128, def main(
c: int = 128, n: int = 128,
h: int = 64, c: int = 128,
w: int = 64, h: int = 64,
f: int = 128, w: int = 64,
k: int = 3, f: int = 128,
s: int = 1, k: int = 3,
d: int = 1, s: int = 1,
p: int = 1, d: int = 1,
use_autotune: bool = False, p: int = 1,
with_roller: bool = True): use_autotune: bool = False,
with_roller: bool = True,
):
N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p
ref_prog = ref_program(S, P, D) ref_prog = ref_program(S, P, D)
...@@ -196,25 +162,16 @@ def main(n: int = 128, ...@@ -196,25 +162,16 @@ def main(n: int = 128,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument('--n', type=int, default=128, help='n') parser.add_argument("--n", type=int, default=128, help="n")
parser.add_argument('--c', type=int, default=128, help='c') parser.add_argument("--c", type=int, default=128, help="c")
parser.add_argument('--h', type=int, default=64, help='h') parser.add_argument("--h", type=int, default=64, help="h")
parser.add_argument('--w', type=int, default=64, help='w') parser.add_argument("--w", type=int, default=64, help="w")
parser.add_argument('--f', type=int, default=128, help='f') parser.add_argument("--f", type=int, default=128, help="f")
parser.add_argument('--k', type=int, default=3, help='k') parser.add_argument("--k", type=int, default=3, help="k")
parser.add_argument('--s', type=int, default=1, help='s') parser.add_argument("--s", type=int, default=1, help="s")
parser.add_argument('--d', type=int, default=1, help='d') parser.add_argument("--d", type=int, default=1, help="d")
parser.add_argument('--p', type=int, default=1, help='p') parser.add_argument("--p", type=int, default=1, help="p")
parser.add_argument( parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs")
"--use_autotune", parser.add_argument("--with_roller", action="store_true", default=True, help="Whether to enable BitBLAS roller for search space")
action="store_true",
default=False,
help="Whether to use autotune for matmul configs")
parser.add_argument(
"--with_roller",
action="store_true",
default=True,
help="Whether to enable BitBLAS roller for search space")
args = parser.parse_args() args = parser.parse_args()
main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, args.with_roller)
args.with_roller)
...@@ -20,11 +20,11 @@ def tl_gemm( ...@@ -20,11 +20,11 @@ def tl_gemm(
accum_dtype, accum_dtype,
): ):
assert in_dtype in [ assert in_dtype in [
"float8_e4m3", T.float8_e4m3fn,
], "Currently only float8_e4m3 is supported" ], "Currently only float8_e4m3 is supported"
assert out_dtype in [ assert out_dtype in [
"bfloat16", T.bfloat16,
"float32", T.float32,
], "Currently only float16 and float32 are supported" ], "Currently only float16 and float32 are supported"
group_size = 128 group_size = 128
...@@ -41,18 +41,17 @@ def tl_gemm( ...@@ -41,18 +41,17 @@ def tl_gemm(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
scales_a: T.Tensor(Scales_A_shape, "float32"), scales_a: T.Tensor(Scales_A_shape, T.float32),
scales_b: T.Tensor(Scales_B_shape, "float32"), scales_b: T.Tensor(Scales_B_shape, T.float32),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_shared = T.alloc_shared(C_shared_shape, out_dtype) C_shared = T.alloc_shared(C_shared_shape, out_dtype)
Scale_C_shared = T.alloc_shared((block_M), "float32") Scale_C_shared = T.alloc_shared((block_M), T.float32)
C_local = T.alloc_fragment(C_shared_shape, accum_dtype) C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype) C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)
...@@ -93,21 +92,18 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: ...@@ -93,21 +92,18 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
m, n = x.shape m, n = x.shape
x_view = x.view(m, -1, 128) x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view( return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
m, n), (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 assert x.dim() == 2
m, n = x.shape m, n = x.shape
x_padded = torch.zeros( x_padded = torch.zeros(ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device)
ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device)
x_padded[:m, :n] = x x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view( return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
x_view.size(0), x_view.size(2))
def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype): def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
...@@ -127,13 +123,14 @@ def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype): ...@@ -127,13 +123,14 @@ def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
c_acc.zero_() c_acc.zero_()
for k in range(ceildiv(K, 128)): for k in range(ceildiv(K, 128)):
c = torch._scaled_mm( c = torch._scaled_mm(
A_fp8[i * 128:(i + 1) * 128, k * 128:(k + 1) * 128], A_fp8[i * 128 : (i + 1) * 128, k * 128 : (k + 1) * 128],
B_fp8[j * 128:(j + 1) * 128, k * 128:(k + 1) * 128].T, B_fp8[j * 128 : (j + 1) * 128, k * 128 : (k + 1) * 128].T,
scale_a=A_scales[i, k].view(128, 1).contiguous(), scale_a=A_scales[i, k].view(128, 1).contiguous(),
scale_b=B_scales[j, k].view(1, 128).contiguous(), scale_b=B_scales[j, k].view(1, 128).contiguous(),
out_dtype=torch.bfloat16) out_dtype=torch.bfloat16,
)
c_acc += c.to(torch.float32) c_acc += c.to(torch.float32)
C[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = c_acc.to(out_dtype) C[i * 128 : (i + 1) * 128, j * 128 : (j + 1) * 128] = c_acc.to(out_dtype)
return C return C
...@@ -179,11 +176,11 @@ def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtyp ...@@ -179,11 +176,11 @@ def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtyp
def main(): def main():
assert_tl_gemm_correctness(1024, 1024, 8192, 128, "float8_e4m3", "bfloat16", "float32") assert_tl_gemm_correctness(1024, 1024, 8192, 128, T.float8_e4m3fn, T.bfloat16, T.float32)
if __name__ == "__main__": if __name__ == "__main__":
for dtype in ["float8_e4m3"]: for dtype in [T.float8_e4m3fn]:
for out_dtype in ["bfloat16", "float32"]: for out_dtype in [T.bfloat16, T.float32]:
for block_N in [16, 32, 64, 128]: for block_N in [16, 32, 64, 128]:
assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, "float32") assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, T.float32)
...@@ -8,6 +8,7 @@ import argparse ...@@ -8,6 +8,7 @@ import argparse
def get_configs(): def get_configs():
import itertools import itertools
BLOCK_N = [16, 32, 64, 128] BLOCK_N = [16, 32, 64, 128]
BLOCK_H = [16, 32, 64, 128] BLOCK_H = [16, 32, 64, 128]
num_split = [1, 2, 4, 8, 16, 32] num_split = [1, 2, 4, 8, 16, 32]
...@@ -15,43 +16,39 @@ def get_configs(): ...@@ -15,43 +16,39 @@ def get_configs():
_configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, threads)) _configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, threads))
return [{ return [
"block_N": c[0], {
"block_H": c[1], "block_N": c[0],
"num_split": c[2], "block_H": c[1],
"threads": c[3], "num_split": c[2],
} for c in _configs] "threads": c[3],
}
for c in _configs
]
@tilelang.autotune(configs=get_configs()) @tilelang.autotune(configs=get_configs())
@tilelang.jit( @tilelang.jit(
out_idx=[6], pass_configs={ out_idx=[6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
def flashmla_decode(batch, )
heads, def flashmla_decode(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, threads=128):
kv_head_num, scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e)
seqlen_kv, dtype = T.float16
dim, accum_dtype = T.float32
pe_dim,
block_N,
block_H,
num_split,
threads=128):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // kv_head_num kv_group_num = heads // kv_head_num
VALID_BLOCK_H = min(block_H, kv_group_num) VALID_BLOCK_H = min(block_H, kv_group_num)
assert kv_head_num == 1, "kv_head_num must be 1" assert kv_head_num == 1, "kv_head_num must be 1"
@T.macro @T.macro
def flash_attn( def flash_attn(
Q: T.Tensor([batch, heads, dim], dtype), Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype), Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype),
): ):
with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=threads) as (bx, by): with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=threads) as (bx, by):
Q_local = T.alloc_fragment([block_H, dim], dtype) Q_local = T.alloc_fragment([block_H, dim], dtype)
...@@ -70,27 +67,24 @@ def flashmla_decode(batch, ...@@ -70,27 +67,24 @@ def flashmla_decode(batch,
cur_kv_head = by // (kv_group_num // block_H) cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10) T.use_swizzle(10)
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_local) T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_local)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_local) T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(seqlen_kv, block_N) loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=0): for k in T.Pipelined(loop_range, num_stages=0):
T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared) T.copy(KV[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared) T.copy(K_pe[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.clear(acc_s) T.clear(acc_s)
T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.gemm( T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
Q_pe_local,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
...@@ -105,20 +99,18 @@ def flashmla_decode(batch, ...@@ -105,20 +99,18 @@ def flashmla_decode(batch,
T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) T.copy(acc_o, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :])
@T.macro @T.macro
def flash_attn_split( def flash_attn_split(
Q: T.Tensor([batch, heads, dim], dtype), Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype), Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
): ):
with T.Kernel( with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=threads) as (bx, by, bz):
batch, heads // min(block_H, kv_group_num), num_split,
threads=threads) as (bx, by, bz):
Q_local = T.alloc_fragment([block_H, dim], dtype) Q_local = T.alloc_fragment([block_H, dim], dtype)
Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype) Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype)
KV_shared = T.alloc_shared([block_N, dim], dtype) KV_shared = T.alloc_shared([block_N, dim], dtype)
...@@ -134,8 +126,8 @@ def flashmla_decode(batch, ...@@ -134,8 +126,8 @@ def flashmla_decode(batch,
cur_kv_head = by // (kv_group_num // block_H) cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10) T.use_swizzle(10)
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_local) T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_local)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_local) T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -148,15 +140,12 @@ def flashmla_decode(batch, ...@@ -148,15 +140,12 @@ def flashmla_decode(batch,
T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared) T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
T.clear(acc_s) T.clear(acc_s)
T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.gemm( T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
Q_pe_local,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
...@@ -172,14 +161,14 @@ def flashmla_decode(batch, ...@@ -172,14 +161,14 @@ def flashmla_decode(batch,
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz]) T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz])
T.copy(acc_o, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :]) T.copy(acc_o, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :])
@T.macro @T.macro
def combine( def combine(
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype),
): ):
with T.Kernel(heads, batch, threads=128) as (by, bz): with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim], dtype) po_local = T.alloc_fragment([dim], dtype)
...@@ -189,9 +178,11 @@ def flashmla_decode(batch, ...@@ -189,9 +178,11 @@ def flashmla_decode(batch,
lse_max_local = T.alloc_local([1], accum_dtype) lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({ T.annotate_layout(
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), {
}) lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.clear(lse_logsum_local) T.clear(lse_logsum_local)
T.clear(o_accum_local) T.clear(o_accum_local)
...@@ -214,26 +205,26 @@ def flashmla_decode(batch, ...@@ -214,26 +205,26 @@ def flashmla_decode(batch,
@T.prim_func @T.prim_func
def main_split( def main_split(
Q: T.Tensor([batch, heads, dim], dtype), Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype), Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype),
): ):
flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
combine(glse, Output_partial, Output) combine(glse, Output_partial, Output)
@T.prim_func @T.prim_func
def main_no_split( def main_no_split(
Q: T.Tensor([batch, heads, dim], dtype), Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype), Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype),
): ):
flash_attn(Q, Q_pe, KV, K_pe, Output) flash_attn(Q, Q_pe, KV, K_pe, Output)
...@@ -258,43 +249,36 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): ...@@ -258,43 +249,36 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim = q.shape[-1] dim = q.shape[-1]
pe_dim = q_pe.shape[-1] pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2] num_head_groups = q.shape[1] // kv.shape[2]
scale = (dim + pe_dim)**0.5 scale = (dim + pe_dim) ** 0.5
q = rearrange( q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
q_pe = rearrange( q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
q_pe, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim]
query = torch.concat([q, q_pe], dim=-1) query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1)
scores = einsum( scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv]
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax( attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, kv, out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim]
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out return out
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size') parser.add_argument("--batch", type=int, default=128, help="batch size")
parser.add_argument('--heads', type=int, default=128, help='q heads number') parser.add_argument("--heads", type=int, default=128, help="q heads number")
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
parser.add_argument('--dim', type=int, default=512, help='head dim') parser.add_argument("--dim", type=int, default=512, help="head dim")
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
parser.add_argument('--autotune', action='store_true', help='auto tune') parser.add_argument("--autotune", action="store_true", help="auto tune")
args = parser.parse_args() args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
enable_autotune = args.autotune enable_autotune = args.autotune
...@@ -310,17 +294,7 @@ if __name__ == "__main__": ...@@ -310,17 +294,7 @@ if __name__ == "__main__":
if enable_autotune: if enable_autotune:
kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim) kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
else: else:
kernel = flashmla_decode( kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, threads=threads)
batch,
heads,
kv_heads,
kv_ctx,
dim,
pe_dim,
BLOCK_N,
BLOCK_H,
num_split,
threads=threads)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
input_tensors = profiler._get_inputs() input_tensors = profiler._get_inputs()
tilelang_output = kernel(*input_tensors) tilelang_output = kernel(*input_tensors)
......
...@@ -32,8 +32,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): ...@@ -32,8 +32,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@torch.inference_mode() @torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv] blocked_v = blocked_k[..., :dv]
def ref_mla(): def ref_mla():
...@@ -94,8 +93,7 @@ def _mla_attn_kernel( ...@@ -94,8 +93,7 @@ def _mla_attn_kernel(
offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)
offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[ offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :]
None, :]
q_nope = tl.load(Q_nope + offs_q_nope) q_nope = tl.load(Q_nope + offs_q_nope)
offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) offs_d_kpe = tl.arange(0, HEAD_DIM_KPE)
...@@ -141,9 +139,7 @@ def _mla_attn_kernel( ...@@ -141,9 +139,7 @@ def _mla_attn_kernel(
e_sum = e_sum * re_scale + tl.sum(p, 1) e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max e_max = n_e_max
offs_o = cur_batch * stride_o_b + cur_head[:, offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :]
None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[
None, :]
tl.store(O + offs_o, acc / e_sum[:, None]) tl.store(O + offs_o, acc / e_sum[:, None])
offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV
tl.store(O + offs_o_1, e_max + tl.log(e_sum)) tl.store(O + offs_o_1, e_max + tl.log(e_sum))
...@@ -309,24 +305,30 @@ def mla_decode_triton( ...@@ -309,24 +305,30 @@ def mla_decode_triton(
@torch.inference_mode() @torch.inference_mode()
def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv] blocked_v = blocked_k[..., :dv]
assert d > dv, "mla with rope dim should be larger than no rope dim" assert d > dv, "mla with rope dim should be larger than no rope dim"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
dv:].contiguous()
def flash_mla_triton(): def flash_mla_triton():
num_kv_splits = 32 num_kv_splits = 32
o = torch.empty([b * s_q, h_q, dv]) o = torch.empty([b * s_q, h_q, dv])
attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1])
mla_decode_triton( mla_decode_triton(
q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv), q_nope.view(-1, h_q, dv),
blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits, q_pe.view(-1, h_q, d - dv),
num_kv_splits, 1 / math.sqrt(d), block_size) blocked_k_nope.view(-1, dv),
blocked_k_pe.view(-1, d - dv),
o,
block_table,
cache_seqlens,
attn_logits,
num_kv_splits,
1 / math.sqrt(d),
block_size,
)
return o.view([b, s_q, h_q, dv]) return o.view([b, s_q, h_q, dv])
out_flash = flash_mla_triton() out_flash = flash_mla_triton()
...@@ -362,14 +364,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal ...@@ -362,14 +364,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
q = torch.randn(b, s_q, h_q, d) q = torch.randn(b, s_q, h_q, d)
block_size = 64 block_size = 64
block_table = torch.arange( block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, out_a, lse_a, perf_a = baseline_func(
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, )
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) out_b, lse_b, perf_b = target_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out"
if target not in ["flash_mla_triton"]: if target not in ["flash_mla_triton"]:
...@@ -377,21 +380,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal ...@@ -377,21 +380,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
torch.finfo(dtype).bits // 8) print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s")
print( print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s"
)
print(
f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s"
)
return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b
def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
print( print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}"
)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
device = torch.device("cuda:0") device = torch.device("cuda:0")
torch.set_default_device(device) torch.set_default_device(device)
...@@ -408,19 +404,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): ...@@ -408,19 +404,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
q = torch.randn(b, s_q, h_q, d) q = torch.randn(b, s_q, h_q, d)
block_size = 64 block_size = 64
block_table = torch.arange( block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, out_b, lse_b, perf_b = target_func(
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
torch.finfo(dtype).bits // 8) print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
print(
f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s"
)
return bytes / 10**6 / perf_b return bytes / 10**6 / perf_b
...@@ -429,26 +422,22 @@ available_targets = [ ...@@ -429,26 +422,22 @@ available_targets = [
"flash_mla_triton", "flash_mla_triton",
] ]
shape_configs = [{ shape_configs = [
"b": {
batch, "b": batch,
"s_q": "s_q": 1,
1, "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"),
"cache_seqlens": "h_q": head,
torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), "h_kv": 1,
"h_q": "d": 512 + 64,
head, "dv": 512,
"h_kv": "causal": True,
1, "dtype": torch.float16,
"d": }
512 + 64, for batch in [128]
"dv": for seqlen in [1024, 2048, 4096, 8192, 16384]
512, for head in [128]
"causal": ]
True,
"dtype":
torch.float16
} for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 16384] for head in [128]]
def get_args(): def get_args():
...@@ -470,26 +459,54 @@ if __name__ == "__main__": ...@@ -470,26 +459,54 @@ if __name__ == "__main__":
for shape in shape_configs: for shape in shape_configs:
if args.all: if args.all:
for target in available_targets: for target in available_targets:
perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], perf = compare_a(
shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], target,
shape["causal"], shape["dtype"]) shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write( fout.write(
f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
) )
elif args.compare: elif args.compare:
perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], perfa, prefb = compare_ab(
shape["cache_seqlens"], shape["h_q"], shape["h_kv"], args.baseline,
shape["d"], shape["dv"], shape["causal"], shape["dtype"]) args.target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write( fout.write(
f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n' f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n"
) )
fout.write( fout.write(
f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n' f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n"
) )
elif args.one: elif args.one:
perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], perf = compare_a(
shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], args.target,
shape["causal"], shape["dtype"]) shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write( fout.write(
f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
) )
...@@ -29,8 +29,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): ...@@ -29,8 +29,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@torch.inference_mode() @torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv] blocked_v = blocked_k[..., :dv]
def ref_mla(): def ref_mla():
...@@ -91,8 +90,7 @@ def _mla_attn_kernel( ...@@ -91,8 +90,7 @@ def _mla_attn_kernel(
offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)
offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[ offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :]
None, :]
q_nope = tl.load(Q_nope + offs_q_nope) q_nope = tl.load(Q_nope + offs_q_nope)
offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) offs_d_kpe = tl.arange(0, HEAD_DIM_KPE)
...@@ -138,9 +136,7 @@ def _mla_attn_kernel( ...@@ -138,9 +136,7 @@ def _mla_attn_kernel(
e_sum = e_sum * re_scale + tl.sum(p, 1) e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max e_max = n_e_max
offs_o = cur_batch * stride_o_b + cur_head[:, offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :]
None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[
None, :]
tl.store(O + offs_o, acc / e_sum[:, None]) tl.store(O + offs_o, acc / e_sum[:, None])
offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV
tl.store(O + offs_o_1, e_max + tl.log(e_sum)) tl.store(O + offs_o_1, e_max + tl.log(e_sum))
...@@ -306,24 +302,30 @@ def mla_decode_triton( ...@@ -306,24 +302,30 @@ def mla_decode_triton(
@torch.inference_mode() @torch.inference_mode()
def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv] blocked_v = blocked_k[..., :dv]
assert d > dv, "mla with rope dim should be larger than no rope dim" assert d > dv, "mla with rope dim should be larger than no rope dim"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
dv:].contiguous()
def flash_mla_triton(): def flash_mla_triton():
num_kv_splits = 32 num_kv_splits = 32
o = torch.empty([b * s_q, h_q, dv]) o = torch.empty([b * s_q, h_q, dv])
attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1])
mla_decode_triton( mla_decode_triton(
q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv), q_nope.view(-1, h_q, dv),
blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits, q_pe.view(-1, h_q, d - dv),
num_kv_splits, 1 / math.sqrt(d), block_size) blocked_k_nope.view(-1, dv),
blocked_k_pe.view(-1, d - dv),
o,
block_table,
cache_seqlens,
attn_logits,
num_kv_splits,
1 / math.sqrt(d),
block_size,
)
return o.view([b, s_q, h_q, dv]) return o.view([b, s_q, h_q, dv])
out_flash = flash_mla_triton() out_flash = flash_mla_triton()
...@@ -359,14 +361,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal ...@@ -359,14 +361,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
q = torch.randn(b, s_q, h_q, d) q = torch.randn(b, s_q, h_q, d)
block_size = 64 block_size = 64
block_table = torch.arange( block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, out_a, lse_a, perf_a = baseline_func(
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, )
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) out_b, lse_b, perf_b = target_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out"
if target not in ["flash_mla_triton"]: if target not in ["flash_mla_triton"]:
...@@ -374,21 +377,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal ...@@ -374,21 +377,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
torch.finfo(dtype).bits // 8) print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s")
print( print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s"
)
print(
f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s"
)
return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b
def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
print( print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}"
)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
device = torch.device("cuda:0") device = torch.device("cuda:0")
torch.set_default_device(device) torch.set_default_device(device)
...@@ -405,19 +401,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): ...@@ -405,19 +401,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
q = torch.randn(b, s_q, h_q, d) q = torch.randn(b, s_q, h_q, d)
block_size = 64 block_size = 64
block_table = torch.arange( block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, out_b, lse_b, perf_b = target_func(
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
torch.finfo(dtype).bits // 8) print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
print(
f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s"
)
return bytes / 10**6 / perf_b return bytes / 10**6 / perf_b
...@@ -426,26 +419,22 @@ available_targets = [ ...@@ -426,26 +419,22 @@ available_targets = [
"flash_mla_triton", "flash_mla_triton",
] ]
shape_configs = [{ shape_configs = [
"b": {
batch, "b": batch,
"s_q": "s_q": 1,
1, "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"),
"cache_seqlens": "h_q": head,
torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), "h_kv": 1,
"h_q": "d": 512 + 64,
head, "dv": 512,
"h_kv": "causal": True,
1, "dtype": torch.float16,
"d": }
512 + 64, for batch in [64, 128]
"dv": for seqlen in [1024, 2048, 4096, 8192, 16384]
512, for head in [128]
"causal": ]
True,
"dtype":
torch.float16
} for batch in [64, 128] for seqlen in [1024, 2048, 4096, 8192, 16384] for head in [128]]
def get_args(): def get_args():
...@@ -467,26 +456,54 @@ if __name__ == "__main__": ...@@ -467,26 +456,54 @@ if __name__ == "__main__":
for shape in shape_configs: for shape in shape_configs:
if args.all: if args.all:
for target in available_targets: for target in available_targets:
perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], perf = compare_a(
shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], target,
shape["causal"], shape["dtype"]) shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write( fout.write(
f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
) )
elif args.compare: elif args.compare:
perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], perfa, prefb = compare_ab(
shape["cache_seqlens"], shape["h_q"], shape["h_kv"], args.baseline,
shape["d"], shape["dv"], shape["causal"], shape["dtype"]) args.target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write( fout.write(
f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n' f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n"
) )
fout.write( fout.write(
f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n' f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n"
) )
elif args.one: elif args.one:
perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], perf = compare_a(
shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], args.target,
shape["causal"], shape["dtype"]) shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write( fout.write(
f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
) )
...@@ -33,8 +33,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): ...@@ -33,8 +33,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@torch.inference_mode() @torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv] blocked_v = blocked_k[..., :dv]
def ref_mla(): def ref_mla():
...@@ -61,8 +60,7 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, ...@@ -61,8 +60,7 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
@torch.inference_mode() @torch.inference_mode()
def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
h_kv, d, dv, causal, dtype):
from flash_mla import flash_mla_with_kvcache, get_mla_metadata from flash_mla import flash_mla_with_kvcache, get_mla_metadata
blocked_v = blocked_k[..., :dv] blocked_v = blocked_k[..., :dv]
...@@ -87,14 +85,13 @@ def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, ...@@ -87,14 +85,13 @@ def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
@torch.inference_mode() @torch.inference_mode()
def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
h_q, h_kv, d, dv, causal, dtype):
# pip install flashinfer-python # pip install flashinfer-python
import flashinfer import flashinfer
assert d > dv, "mla with rope dim should be larger than no rope dim" assert d > dv, "mla with rope dim should be larger than no rope dim"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
dv:].contiguous()
kv_indptr = [0] kv_indptr = [0]
kv_indices = [] kv_indices = []
...@@ -111,8 +108,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q ...@@ -111,8 +108,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3")
torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3")
mla_wrapper.plan( mla_wrapper.plan(
q_indptr, q_indptr,
kv_indptr, kv_indptr,
...@@ -129,12 +125,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q ...@@ -129,12 +125,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q
) )
def flashinfer(): def flashinfer():
output, lse = mla_wrapper.run( output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope, blocked_k_pe, return_lse=True)
q_nope.view(-1, h_q, dv),
q_pe.view(-1, h_q, d - dv),
blocked_k_nope,
blocked_k_pe,
return_lse=True)
return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1) return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1)
out_flash, lse_flash = flashinfer() out_flash, lse_flash = flashinfer()
...@@ -177,8 +168,7 @@ def _mla_attn_kernel( ...@@ -177,8 +168,7 @@ def _mla_attn_kernel(
offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)
offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[ offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :]
None, :]
q_nope = tl.load(Q_nope + offs_q_nope) q_nope = tl.load(Q_nope + offs_q_nope)
offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) offs_d_kpe = tl.arange(0, HEAD_DIM_KPE)
...@@ -224,9 +214,7 @@ def _mla_attn_kernel( ...@@ -224,9 +214,7 @@ def _mla_attn_kernel(
e_sum = e_sum * re_scale + tl.sum(p, 1) e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max e_max = n_e_max
offs_o = cur_batch * stride_o_b + cur_head[:, offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :]
None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[
None, :]
tl.store(O + offs_o, acc / e_sum[:, None]) tl.store(O + offs_o, acc / e_sum[:, None])
offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV
tl.store(O + offs_o_1, e_max + tl.log(e_sum)) tl.store(O + offs_o_1, e_max + tl.log(e_sum))
...@@ -393,24 +381,30 @@ def mla_decode_triton( ...@@ -393,24 +381,30 @@ def mla_decode_triton(
@torch.inference_mode() @torch.inference_mode()
def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv] blocked_v = blocked_k[..., :dv]
assert d > dv, "mla with rope dim should be larger than no rope dim" assert d > dv, "mla with rope dim should be larger than no rope dim"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
dv:].contiguous()
def flash_mla_triton(): def flash_mla_triton():
num_kv_splits = 32 num_kv_splits = 32
o = torch.empty([b * s_q, h_q, dv]) o = torch.empty([b * s_q, h_q, dv])
attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1])
mla_decode_triton( mla_decode_triton(
q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv), q_nope.view(-1, h_q, dv),
blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits, q_pe.view(-1, h_q, d - dv),
num_kv_splits, 1 / math.sqrt(d), block_size) blocked_k_nope.view(-1, dv),
blocked_k_pe.view(-1, d - dv),
o,
block_table,
cache_seqlens,
attn_logits,
num_kv_splits,
1 / math.sqrt(d),
block_size,
)
return o.view([b, s_q, h_q, dv]) return o.view([b, s_q, h_q, dv])
out_flash = flash_mla_triton() out_flash = flash_mla_triton()
...@@ -419,13 +413,10 @@ def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, ...@@ -419,13 +413,10 @@ def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size,
@torch.inference_mode() @torch.inference_mode()
def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
assert d > dv, "mla with rope dim should be larger than no rope dim" assert d > dv, "mla with rope dim should be larger than no rope dim"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
dv:].contiguous()
dpe = d - dv dpe = d - dv
num_kv_splits = 1 num_kv_splits = 1
...@@ -434,8 +425,7 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size ...@@ -434,8 +425,7 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size
out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size)
num_kv_splits, block_size)
def flash_mla_tilelang(): def flash_mla_tilelang():
out = kernel( out = kernel(
...@@ -486,38 +476,31 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal ...@@ -486,38 +476,31 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
q = torch.randn(b, s_q, h_q, d) q = torch.randn(b, s_q, h_q, d)
block_size = 64 block_size = 64
block_table = torch.arange( block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, out_a, lse_a, perf_a = baseline_func(
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, )
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) out_b, lse_b, perf_b = target_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out"
if target not in ["flashinfer", "flash_mla_triton", "tilelang" if target not in ["flashinfer", "flash_mla_triton", "tilelang"] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]:
] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]:
# flashinfer has a different lse return value # flashinfer has a different lse return value
# flash_mla_triton and flash_mla_tilelang doesn't return lse # flash_mla_triton and flash_mla_tilelang doesn't return lse
torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
torch.finfo(dtype).bits // 8) print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s")
print( print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s"
)
print(
f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s"
)
return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b
def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
print( print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}"
)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
device = torch.device("cuda:0") device = torch.device("cuda:0")
torch.set_default_device(device) torch.set_default_device(device)
...@@ -534,19 +517,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): ...@@ -534,19 +517,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
q = torch.randn(b, s_q, h_q, d) q = torch.randn(b, s_q, h_q, d)
block_size = 64 block_size = 64
block_table = torch.arange( block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, out_b, lse_b, perf_b = target_func(
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
torch.finfo(dtype).bits // 8) print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
print(
f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s"
)
return bytes / 10**6 / perf_b return bytes / 10**6 / perf_b
...@@ -558,26 +538,22 @@ available_targets = [ ...@@ -558,26 +538,22 @@ available_targets = [
"flash_mla_triton", "flash_mla_triton",
] ]
shape_configs = [{ shape_configs = [
"b": {
batch, "b": batch,
"s_q": "s_q": 1,
1, "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"),
"cache_seqlens": "h_q": head,
torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), "h_kv": 1,
"h_q": "d": 512 + 64,
head, "dv": 512,
"h_kv": "causal": True,
1, "dtype": torch.float16,
"d": }
512 + 64, for batch in [128]
"dv": for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]
512, for head in [128]
"causal": ]
True,
"dtype":
torch.float16
} for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] for head in [128]]
def get_args(): def get_args():
...@@ -599,26 +575,54 @@ if __name__ == "__main__": ...@@ -599,26 +575,54 @@ if __name__ == "__main__":
for shape in shape_configs: for shape in shape_configs:
if args.all: if args.all:
for target in available_targets: for target in available_targets:
perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], perf = compare_a(
shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], target,
shape["causal"], shape["dtype"]) shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write( fout.write(
f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
) )
elif args.compare: elif args.compare:
perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], perfa, prefb = compare_ab(
shape["cache_seqlens"], shape["h_q"], shape["h_kv"], args.baseline,
shape["d"], shape["dv"], shape["causal"], shape["dtype"]) args.target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write( fout.write(
f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n' f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n"
) )
fout.write( fout.write(
f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n' f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n"
) )
elif args.one: elif args.one:
perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], perf = compare_a(
shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], args.target,
shape["causal"], shape["dtype"]) shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write( fout.write(
f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
) )
...@@ -8,25 +8,26 @@ import argparse ...@@ -8,25 +8,26 @@ import argparse
@tilelang.jit( @tilelang.jit(
out_idx=[6], pass_configs={ out_idx=[6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, )
softmax_scale): def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale):
scale = float(softmax_scale * 1.44269504) # log2(e) scale = float(softmax_scale * 1.44269504) # log2(e)
dtype = "float16" dtype = T.float16
accum_dtype = "float" accum_dtype = T.float32
kv_group_num = heads // kv_head_num kv_group_num = heads // kv_head_num
VALID_BLOCK_H = min(block_H, kv_group_num) VALID_BLOCK_H = min(block_H, kv_group_num)
assert kv_head_num == 1, "kv_head_num must be 1" assert kv_head_num == 1, "kv_head_num must be 1"
@T.macro @T.macro
def flash_attn( def flash_attn(
Q: T.Tensor([batch, heads, dim], dtype), Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype), Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype),
): ):
with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid): with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid):
Q_shared = T.alloc_shared([block_H, dim], dtype) Q_shared = T.alloc_shared([block_H, dim], dtype)
...@@ -44,36 +45,29 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -44,36 +45,29 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
logsum = T.alloc_fragment([block_H], accum_dtype) logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = hid // (kv_group_num // block_H) cur_kv_head = hid // (kv_group_num // block_H)
T.annotate_layout({ T.annotate_layout(
O_shared: tilelang.layout.make_swizzled_layout(O_shared), {
}) O_shared: tilelang.layout.make_swizzled_layout(O_shared),
}
)
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(seqlen_kv, block_N) loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2): for k in T.Pipelined(loop_range, num_stages=2):
T.copy(KV[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared) T.copy(KV[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared) T.copy(K_pe[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.gemm( T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
Q_shared, T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
clear_accum=True)
T.gemm(
Q_pe_shared,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
...@@ -88,20 +82,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -88,20 +82,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :]) T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :])
@T.macro @T.macro
def flash_attn_split( def flash_attn_split(
Q: T.Tensor([batch, heads, dim], dtype), Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype), Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
): ):
with T.Kernel( with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bid, hid, bz):
batch, heads // min(block_H, kv_group_num), num_split,
threads=256) as (bid, hid, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype) Q_shared = T.alloc_shared([block_H, dim], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype) S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
...@@ -119,13 +111,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -119,13 +111,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
cur_kv_head = hid // (kv_group_num // block_H) cur_kv_head = hid // (kv_group_num // block_H)
T.use_swizzle(10) T.use_swizzle(10)
T.annotate_layout({ T.annotate_layout(
O_shared: tilelang.layout.make_swizzled_layout(O_shared), {
S_shared: tilelang.layout.make_swizzled_layout(S_shared), O_shared: tilelang.layout.make_swizzled_layout(O_shared),
}) S_shared: tilelang.layout.make_swizzled_layout(S_shared),
}
)
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -137,17 +131,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -137,17 +131,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared)
T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared) T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
T.clear(acc_s) T.clear(acc_s)
T.gemm( T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
Q_pe_shared,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
...@@ -164,16 +154,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -164,16 +154,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, bz]) T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz])
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, :])
bz, :])
@T.macro @T.macro
def combine( def combine(
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype),
): ):
with T.Kernel(heads, batch, threads=128) as (hid, bz): with T.Kernel(heads, batch, threads=128) as (hid, bz):
po_local = T.alloc_fragment([dim], dtype) po_local = T.alloc_fragment([dim], dtype)
...@@ -183,9 +172,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -183,9 +172,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
lse_max_local = T.alloc_local([1], accum_dtype) lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({ T.annotate_layout(
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), {
}) lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.clear(lse_logsum_local) T.clear(lse_logsum_local)
T.clear(o_accum_local) T.clear(o_accum_local)
...@@ -208,26 +199,26 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -208,26 +199,26 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@T.prim_func @T.prim_func
def main_split( def main_split(
Q: T.Tensor([batch, heads, dim], dtype), Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype), Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype),
): ):
flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
combine(glse, Output_partial, Output) combine(glse, Output_partial, Output)
@T.prim_func @T.prim_func
def main_no_split( def main_no_split(
Q: T.Tensor([batch, heads, dim], dtype), Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype), Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype),
): ):
flash_attn(Q, Q_pe, KV, K_pe, Output) flash_attn(Q, Q_pe, KV, K_pe, Output)
def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
    """Plain-PyTorch reference for MLA grouped-query attention.

    q:    [b, heads, dim]            content part of the query
    q_pe: [b, heads, pe_dim]         positional part of the query
    kv:   [b, seqlen, h_kv, dim]     shared KV cache (also used as the values)
    k_pe: [b, seqlen, h_kv, pe_dim]  positional part of the keys
    glse / Output_partial are unused here; they only mirror the kernel signature.
    Returns [b, heads, dim].
    """
    b, heads, dim = q.shape
    pe_dim = q_pe.shape[-1]
    h_kv = kv.shape[2]
    groups = heads // h_kv  # query heads per KV head
    scale = (dim + pe_dim) ** 0.5

    # b (h g) d -> b g h d  (h = h_kv is the outer factor, g = groups the inner)
    q = q.reshape(b, h_kv, groups, dim).transpose(1, 2)
    q_pe = q_pe.reshape(b, h_kv, groups, pe_dim).transpose(1, 2)
    # b n h d -> b h n d
    kv = kv.transpose(1, 2)
    k_pe = k_pe.transpose(1, 2)

    query = torch.concat([q, q_pe], dim=-1)  # [b, g, h, dim + pe_dim]
    key = torch.concat([kv, k_pe], dim=-1)   # [b, h, s, dim + pe_dim]

    scores = torch.einsum("bghd,bhsd->bghs", query, key)
    attention = torch.softmax(scores / scale, dim=-1)

    # Values are the content cache only (no positional part).
    out = torch.einsum("bghs,bhsd->bghd", attention, kv)
    # b g h d -> b (h g) d
    return out.transpose(1, 2).reshape(b, heads, dim)
...@@ -294,10 +278,9 @@ def main( ...@@ -294,10 +278,9 @@ def main(
BLOCK_N = 64 BLOCK_N = 64
BLOCK_H = min(64, heads // kv_heads) BLOCK_H = min(64, heads // kv_heads)
num_split = 1 num_split = 1
softmax_scale = (dim + pe_dim)**-0.5 softmax_scale = (dim + pe_dim) ** -0.5
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale)
softmax_scale)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4) profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4)
latency = profiler.do_bench(warmup=500) latency = profiler.do_bench(warmup=500)
...@@ -307,12 +290,12 @@ def main( ...@@ -307,12 +290,12 @@ def main(
if __name__ == "__main__":
    # CLI entry: every knob is an integer; defaults describe an MLA decode workload.
    parser = argparse.ArgumentParser()
    for flag, default, desc in (
        ("--batch", 132, "batch size"),
        ("--heads", 128, "q heads number"),
        ("--kv_heads", 1, "kv heads number"),
        ("--kv_ctx", 8192, "kv context length"),
        ("--dim", 512, "head dim"),
        ("--pe_dim", 64, "pe head dim"),
    ):
        parser.add_argument(flag, type=int, default=default, help=desc)
    args = parser.parse_args()
    batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
    main(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
...@@ -8,25 +8,17 @@ import math ...@@ -8,25 +8,17 @@ import math
@tilelang.jit( @tilelang.jit(
out_idx=[8], pass_configs={ out_idx=[8],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
def mla_decode_tilelang(batch, )
h_q, def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split, block_size, softmax_scale=None):
h_kv,
max_seqlen_pad,
dv,
dpe,
block_N,
block_H,
num_split,
block_size,
softmax_scale=None):
if softmax_scale is None: if softmax_scale is None:
softmax_scale = (dv + dpe)**-0.5 softmax_scale = (dv + dpe) ** -0.5
scale = float(softmax_scale * 1.44269504) # log2(e) scale = float(softmax_scale * 1.44269504) # log2(e)
dtype = "float16" dtype = T.float16
accum_dtype = "float" accum_dtype = T.float32
kv_group_num = h_q // h_kv kv_group_num = h_q // h_kv
VALID_BLOCK_H = min(block_H, kv_group_num) VALID_BLOCK_H = min(block_H, kv_group_num)
assert h_kv == 1, "h_kv must be 1" assert h_kv == 1, "h_kv must be 1"
...@@ -34,13 +26,13 @@ def mla_decode_tilelang(batch, ...@@ -34,13 +26,13 @@ def mla_decode_tilelang(batch,
@T.macro @T.macro
def flash_mla_kernel( def flash_mla_kernel(
Q: T.Tensor([batch, h_q, dv], dtype), Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype), Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], T.int32),
CACHE_SEQLENS: T.Tensor([batch], "int32"), CACHE_SEQLENS: T.Tensor([batch], T.int32),
Output: T.Tensor([batch, h_q, dv], dtype), Output: T.Tensor([batch, h_q, dv], dtype),
): ):
with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by): with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by):
Q_shared = T.alloc_shared([block_H, dv], dtype) Q_shared = T.alloc_shared([block_H, dv], dtype)
...@@ -59,13 +51,15 @@ def mla_decode_tilelang(batch, ...@@ -59,13 +51,15 @@ def mla_decode_tilelang(batch,
cur_kv_head = by // (kv_group_num // block_H) cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10) T.use_swizzle(10)
T.annotate_layout({ T.annotate_layout(
O_shared: tilelang.layout.make_swizzled_layout(O_shared), {
S_shared: tilelang.layout.make_swizzled_layout(S_shared), O_shared: tilelang.layout.make_swizzled_layout(O_shared),
}) S_shared: tilelang.layout.make_swizzled_layout(S_shared),
}
)
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -73,26 +67,20 @@ def mla_decode_tilelang(batch, ...@@ -73,26 +67,20 @@ def mla_decode_tilelang(batch,
loop_range = T.ceildiv(CACHE_SEQLENS[bx], block_N) loop_range = T.ceildiv(CACHE_SEQLENS[bx], block_N)
for kr in T.Pipelined(loop_range, num_stages=2): for kr in T.Pipelined(loop_range, num_stages=2):
k = loop_range - 1 - kr k = loop_range - 1 - kr
kv_start = BLOCK_TABLE[bx, (k * block_N) // kv_start = BLOCK_TABLE[bx, (k * block_N) // block_size] * block_size + (k * block_N) % block_size
block_size] * block_size + (k * block_N) % block_size T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared)
T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared) T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared)
T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared)
T.clear(acc_s) T.clear(acc_s)
T.gemm( T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
Q_pe_shared,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
if kr == 0: if kr == 0:
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j])
-T.infinity(accum_dtype), acc_s[i, j])
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
...@@ -107,21 +95,20 @@ def mla_decode_tilelang(batch, ...@@ -107,21 +95,20 @@ def mla_decode_tilelang(batch,
for i, j in T.Parallel(block_H, dv): for i, j in T.Parallel(block_H, dv):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :])
@T.macro @T.macro
def flash_mla_split_kv_kernel( def flash_mla_split_kv_kernel(
Q: T.Tensor([batch, h_q, dv], dtype), Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype), Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], T.int32),
CACHE_SEQLENS: T.Tensor([batch], "int32"), CACHE_SEQLENS: T.Tensor([batch], T.int32),
glse: T.Tensor([batch, h_q, num_split], dtype), glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
): ):
with T.Kernel( with T.Kernel(batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dv], dtype) Q_shared = T.alloc_shared([block_H, dv], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype) S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, dpe], dtype) Q_pe_shared = T.alloc_shared([block_H, dpe], dtype)
...@@ -139,13 +126,15 @@ def mla_decode_tilelang(batch, ...@@ -139,13 +126,15 @@ def mla_decode_tilelang(batch,
cur_kv_head = by // (kv_group_num // block_H) cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10) T.use_swizzle(10)
T.annotate_layout({ T.annotate_layout(
O_shared: tilelang.layout.make_swizzled_layout(O_shared), {
S_shared: tilelang.layout.make_swizzled_layout(S_shared), O_shared: tilelang.layout.make_swizzled_layout(O_shared),
}) S_shared: tilelang.layout.make_swizzled_layout(S_shared),
}
)
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -153,29 +142,23 @@ def mla_decode_tilelang(batch, ...@@ -153,29 +142,23 @@ def mla_decode_tilelang(batch,
total_blocks = T.ceildiv(CACHE_SEQLENS[bx], block_N) total_blocks = T.ceildiv(CACHE_SEQLENS[bx], block_N)
blocks_per_split = T.floordiv(total_blocks, num_split) blocks_per_split = T.floordiv(total_blocks, num_split)
remaining_blocks = T.floormod(total_blocks, num_split) remaining_blocks = T.floormod(total_blocks, num_split)
loop_range = (blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0)) loop_range = blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0)
start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N
for k in T.Pipelined(loop_range, num_stages=2): for k in T.Pipelined(loop_range, num_stages=2):
kv_start = BLOCK_TABLE[bx, (start + k * block_N) // kv_start = BLOCK_TABLE[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size
block_size] * block_size + (k * block_N) % block_size T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared)
T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared) T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared)
T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared)
T.clear(acc_s) T.clear(acc_s)
T.gemm( T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
Q_pe_shared,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j])
-T.infinity(accum_dtype), acc_s[i, j])
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
...@@ -192,15 +175,15 @@ def mla_decode_tilelang(batch, ...@@ -192,15 +175,15 @@ def mla_decode_tilelang(batch,
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz]) T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz])
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :]) T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :])
@T.macro @T.macro
def combine( def combine(
glse: T.Tensor([batch, h_q, num_split], dtype), glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
Output: T.Tensor([batch, h_q, dv], dtype), Output: T.Tensor([batch, h_q, dv], dtype),
): ):
with T.Kernel(h_q, batch, threads=128) as (by, bz): with T.Kernel(h_q, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dv], dtype) po_local = T.alloc_fragment([dv], dtype)
...@@ -210,9 +193,11 @@ def mla_decode_tilelang(batch, ...@@ -210,9 +193,11 @@ def mla_decode_tilelang(batch,
lse_max_local = T.alloc_local([1], accum_dtype) lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({ T.annotate_layout(
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), {
}) lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.clear(lse_logsum_local) T.clear(lse_logsum_local)
T.clear(o_accum_local) T.clear(o_accum_local)
...@@ -235,31 +220,30 @@ def mla_decode_tilelang(batch, ...@@ -235,31 +220,30 @@ def mla_decode_tilelang(batch,
@T.prim_func @T.prim_func
def main_split( def main_split(
Q: T.Tensor([batch, h_q, dv], dtype), Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype), Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32),
cache_seqlens: T.Tensor([batch], "int32"), cache_seqlens: T.Tensor([batch], T.int32),
glse: T.Tensor([batch, h_q, num_split], dtype), glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
Output: T.Tensor([batch, h_q, dv], dtype), Output: T.Tensor([batch, h_q, dv], dtype),
): ):
flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, Output_partial)
Output_partial)
combine(glse, Output_partial, Output) combine(glse, Output_partial, Output)
@T.prim_func @T.prim_func
def main_no_split( def main_no_split(
Q: T.Tensor([batch, h_q, dv], dtype), Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype), Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32),
cache_seqlens: T.Tensor([batch], "int32"), cache_seqlens: T.Tensor([batch], T.int32),
glse: T.Tensor([batch, h_q, num_split], dtype), glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
Output: T.Tensor([batch, h_q, dv], dtype), Output: T.Tensor([batch, h_q, dv], dtype),
): ):
flash_mla_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, Output) flash_mla_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, Output)
...@@ -280,8 +264,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): ...@@ -280,8 +264,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
s_q = query.shape[-2] s_q = query.shape[-2]
s_k = key.shape[-2] s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device) attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device)
temp_mask = torch.ones( temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q)
s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype) attn_bias.to(query.dtype)
attn_weight += attn_bias attn_weight += attn_bias
...@@ -291,8 +274,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): ...@@ -291,8 +274,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@torch.inference_mode() @torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
h_kv, d, dv, causal, dtype):
# q: [b, s_q, h_q, d] # q: [b, s_q, h_q, d]
# block_table: [b, max_seqlen_pad // block_size] # block_table: [b, max_seqlen_pad // block_size]
# blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d] # blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d]
...@@ -321,13 +303,10 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, ...@@ -321,13 +303,10 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
return out_torch return out_torch
def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
h_q, h_kv, d, dv, causal, dtype):
assert d > dv, "mla with rope dim should be larger than no rope dim" assert d > dv, "mla with rope dim should be larger than no rope dim"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
dv:].contiguous()
dpe = d - dv dpe = d - dv
num_kv_splits = 1 num_kv_splits = 1
...@@ -337,8 +316,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s ...@@ -337,8 +316,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size, softmax_scale)
num_kv_splits, block_size, softmax_scale)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
def flash_mla_tilelang(): def flash_mla_tilelang():
...@@ -356,8 +334,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s ...@@ -356,8 +334,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
out_flash = flash_mla_tilelang() out_flash = flash_mla_tilelang()
t = do_bench(flash_mla_tilelang) t = do_bench(flash_mla_tilelang)
out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
torch.testing.assert_close(out_flash, out_ref, rtol=0.01, atol=0.01) torch.testing.assert_close(out_flash, out_ref, rtol=0.01, atol=0.01)
print("All close") print("All close")
return out_flash, t return out_flash, t
...@@ -365,12 +342,12 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s ...@@ -365,12 +342,12 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size') parser.add_argument("--batch", type=int, default=128, help="batch size")
parser.add_argument('--h_q', type=int, default=128, help='q heads number') parser.add_argument("--h_q", type=int, default=128, help="q heads number")
parser.add_argument('--h_kv', type=int, default=1, help='kv heads number') parser.add_argument("--h_kv", type=int, default=1, help="kv heads number")
parser.add_argument('--cache_seqlen', type=int, default=8192, help='kv cache context length') parser.add_argument("--cache_seqlen", type=int, default=8192, help="kv cache context length")
parser.add_argument('--d', type=int, default=576, help='query/key head dim, d = dv + dpe') parser.add_argument("--d", type=int, default=576, help="query/key head dim, d = dv + dpe")
parser.add_argument('--dv', type=int, default=512, help='value head dim') parser.add_argument("--dv", type=int, default=512, help="value head dim")
args = parser.parse_args() args = parser.parse_args()
b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv
...@@ -379,9 +356,7 @@ if __name__ == "__main__": ...@@ -379,9 +356,7 @@ if __name__ == "__main__":
s_q = 1 # for decode, s_q = 1 s_q = 1 # for decode, s_q = 1
block_size = 64 block_size = 64
cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], dtype=torch.int32, device=device)
dtype=torch.int32,
device=device)
dpe = d - dv dpe = d - dv
causal = True causal = True
...@@ -393,12 +368,11 @@ if __name__ == "__main__": ...@@ -393,12 +368,11 @@ if __name__ == "__main__":
total_flops = s_q * total_seqlens * h_q * d * 2 total_flops = s_q * total_seqlens * h_q * d * 2
q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device) q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device)
block_table = torch.arange( block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32, device=device).view(b, max_seqlen_pad // block_size)
b * max_seqlen_pad // block_size, dtype=torch.int32,
device=device).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device)
out_flash, latency = run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, out_flash, latency = run_tilelang_mla(
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
...@@ -9,13 +9,15 @@ import argparse ...@@ -9,13 +9,15 @@ import argparse
@tilelang.jit( @tilelang.jit(
out_idx=[6], pass_configs={ out_idx=[6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split): def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e)
dtype = "float16" dtype = T.float16
accum_dtype = "float" accum_dtype = T.float32
kv_group_num = heads // kv_head_num kv_group_num = heads // kv_head_num
VALID_BLOCK_H = min(block_H, kv_group_num) VALID_BLOCK_H = min(block_H, kv_group_num)
assert kv_head_num == 1, "kv_head_num must be 1" assert kv_head_num == 1, "kv_head_num must be 1"
...@@ -23,13 +25,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -23,13 +25,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@T.prim_func @T.prim_func
def main_split_persistent( def main_split_persistent(
Q: T.Tensor([batch, heads, dim], dtype), Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype), Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype),
): ):
with T.Kernel(sm_num, threads=256) as (block_id): with T.Kernel(sm_num, threads=256) as (block_id):
Q_shared = T.alloc_shared([block_H, dim], dtype) Q_shared = T.alloc_shared([block_H, dim], dtype)
...@@ -53,11 +55,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -53,11 +55,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
lse_max_local = T.alloc_local([1], accum_dtype) lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({ T.annotate_layout(
# O_shared: tilelang.layout.make_swizzled_layout(O_shared), {
S_shared: tilelang.layout.make_swizzled_layout(S_shared), # O_shared: tilelang.layout.make_swizzled_layout(O_shared),
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), S_shared: tilelang.layout.make_swizzled_layout(S_shared),
}) lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.use_swizzle(10) T.use_swizzle(10)
total_tiles = batch * (heads // min(block_H, kv_group_num)) * num_split total_tiles = batch * (heads // min(block_H, kv_group_num)) * num_split
...@@ -70,8 +74,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -70,8 +74,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
cur_kv_head = hid // (kv_group_num // block_H) cur_kv_head = hid // (kv_group_num // block_H)
if bid < batch and hid * VALID_BLOCK_H < heads and sid < num_split: if bid < batch and hid * VALID_BLOCK_H < heads and sid < num_split:
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -83,24 +87,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -83,24 +87,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared)
T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared) T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
T.clear(acc_s) T.clear(acc_s)
T.gemm( T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
Q_shared, T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.gemm(
Q_pe_shared,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
scores_max[i] * scale) for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1) T.reduce_sum(acc_s, scores_sum, dim=1)
...@@ -115,11 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -115,11 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, sid]) T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid])
# T.copy(acc_o, O_shared) # T.copy(acc_o, O_shared)
T.copy( T.copy(acc_o, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid, :])
acc_o, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
sid, :])
T.sync_grid() T.sync_grid()
waves = T.ceildiv(heads * batch, sm_num) waves = T.ceildiv(heads * batch, sm_num)
...@@ -165,42 +158,35 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): ...@@ -165,42 +158,35 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim = q.shape[-1] dim = q.shape[-1]
pe_dim = q_pe.shape[-1] pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2] num_head_groups = q.shape[1] // kv.shape[2]
scale = (dim + pe_dim)**0.5 scale = (dim + pe_dim) ** 0.5
q = rearrange( q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
q_pe = rearrange( q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
q_pe, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim]
query = torch.concat([q, q_pe], dim=-1) query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1)
scores = einsum( scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv]
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax( attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, kv, out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim]
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out return out
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size') parser.add_argument("--batch", type=int, default=128, help="batch size")
parser.add_argument('--heads', type=int, default=128, help='q heads number') parser.add_argument("--heads", type=int, default=128, help="q heads number")
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
parser.add_argument('--dim', type=int, default=512, help='head dim') parser.add_argument("--dim", type=int, default=512, help="head dim")
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
args = parser.parse_args() args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment