Unverified commit 29051439, authored by Lei Wang and committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
@@ -3,8 +3,7 @@ from typing import Dict, List, Tuple

TokensText = Tuple[List[int], str]


def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText], name_0: str, name_1: str):
    """
    Compare the two sequences generated by different models,
    which should be equal.
@@ -15,19 +14,14 @@ def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[Tok
        output_ids_0, output_str_0 = outputs_0
        output_ids_1, output_str_1 = outputs_1

        assert output_str_0 == output_str_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
        assert output_ids_0 == output_ids_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"


TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]


def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str):
    """
    Compare the logprobs of two sequences generated by different models,
    which should be similar but not necessarily equal.
@@ -41,16 +35,11 @@ def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
        # Loop through generated tokens.
        for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
            # If generated tokens don't match, then
            if output_id_0 != output_id_1:
                # Each predicted token must be in top N logprobs of the other
                assert output_id_0 in logprobs_1[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
                assert output_id_1 in logprobs_0[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"

                # Break out since sequences will now diverge.
                break
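For context, a minimal usage sketch of the two helpers above. The inputs and the names "model_a"/"model_b" are made up for illustration; only the helper signatures come from the file in this hunk.

# Hypothetical usage of check_outputs_equal / check_logprobs_close.
from typing import Dict, List, Tuple

TokensText = Tuple[List[int], str]
TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]

# Exact-match comparison: token ids and decoded text must be identical.
outputs_a: List[TokensText] = [([1, 2, 3], "a b c")]
outputs_b: List[TokensText] = [([1, 2, 3], "a b c")]
check_outputs_equal(outputs_a, outputs_b, name_0="model_a", name_1="model_b")

# Logprob comparison: the generations diverge at step 1 (token 4 vs 5), but each
# token appears in the other model's top-k logprobs, so the check passes and stops comparing.
lp_a: List[TokensTextLogprobs] = [([1, 4], "x y", [{1: -0.1}, {4: -0.5, 5: -0.7}])]
lp_b: List[TokensTextLogprobs] = [([1, 5], "x z", [{1: -0.1}, {5: -0.4, 4: -0.9}])]
check_logprobs_close(lp_a, lp_b, name_0="model_a", name_1="model_b")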
@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
    bsz, num_head, downsample_len, _ = x.shape
    # N_CTX = downsample_len * BLOCK
    sparse_index = torch.topk(x, topk, dim=-1).indices
    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
    dense_mask.scatter_(-1, sparse_index, True)
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
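A quick sketch of what this helper produces; the shapes and the final return of the boolean mask are assumptions, since the tail of the function is collapsed in the diff. Each block row of the mask keeps exactly the top-k scoring key blocks.

# Assumed usage of get_sparse_attn_mask_from_topk.
import torch

bsz, num_head, downsample_len, topk = 1, 2, 8, 3
x = torch.randn(bsz, num_head, downsample_len, downsample_len)
block_mask = get_sparse_attn_mask_from_topk(x, topk=topk)

# Each block row selects exactly `topk` key blocks.
assert block_mask.shape == (bsz, num_head, downsample_len, downsample_len)
assert bool((block_mask.sum(dim=-1) == topk).all())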
@@ -56,7 +53,6 @@ def _fwd_kernel_inner(
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
    # print
@@ -73,8 +69,7 @@ def _fwd_kernel_inner(
        # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
        if LAST_K_BLOCK:
            qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))

        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk -= m_ij[:, None]
@@ -154,7 +149,7 @@ def _fwd_kernel(
    v_ptrs = V + off_v
    mask_ptrs = block_mask_ptr + start_m * stride_bmm

    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
@@ -192,24 +187,12 @@ def _fwd_kernel(
    acc = acc * l_recip
    acc = acc.to(Out.dtype.element_ty)
    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)
def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
    assert q.shape[-1] == k.shape[-1] == v.shape[-1]
    assert k.shape[2] == v.shape[2]
    o = out if out is not None else torch.empty_like(q).contiguous()
@@ -254,7 +237,6 @@ def _forward(ctx,


class _sparse_attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
        # shape constraints
@@ -278,9 +260,9 @@ def test_topk_sparse_attention():
    torch.manual_seed(0)

    # Create inputs
    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    sm_scale = 1.0 / (D_HEAD**0.5)

    # Create sparse mask (downsampled to block level)
@@ -288,9 +270,7 @@ def test_topk_sparse_attention():
    downsample_len = math.ceil(SEQ_LEN / downsample_factor)
    print("downsample_len", downsample_len)
    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
    x_ds[:, :, :, 0] = 100
    print("x_ds.shape", x_ds.shape)
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
@@ -302,22 +282,21 @@ def test_topk_sparse_attention():
    # Compute reference
    # Expand block mask to full attention matrix
    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
    full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
    full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal

    # PyTorch reference implementation
    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
    attn = attn.masked_fill(~full_mask, float("-inf"))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)

    # print("ref_output", ref_output)
    # print("triton_output", triton_output)

    # Verify accuracy
    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference"
    print("Pass topk sparse attention test with qlen == klen")
@@ -329,9 +308,9 @@ def test_topk_sparse_attention_qlt_kl():
    torch.manual_seed(0)

    # Create inputs.
    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)

    # softmax scale
    sm_scale = 1.0 / (D_HEAD**0.5)
@@ -339,8 +318,7 @@ def test_topk_sparse_attention_qlt_kl():
    print("downsample_factor", downsample_factor)
    downsample_len = math.ceil(K_LEN / downsample_factor)  # number of blocks along one dimension
    print("downsample_len", downsample_len)
    x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16)
    # Force the first column to be high so that the first block is always selected.
    x_ds[:, :, :, 0] = 100
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
@@ -351,26 +329,25 @@ def test_topk_sparse_attention_qlt_kl():
    past_len = K_LEN - Q_LEN

    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale

    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool()
    full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]
    effective_mask = full_mask_full[..., past_len:K_LEN, :]  # shape: (B, H, Q_LEN, K_LEN)

    i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)  # shape: (Q_LEN, 1)
    j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)  # shape: (1, K_LEN)
    causal_mask = j_global <= i_global  # shape: (Q_LEN, K_LEN)

    final_mask = effective_mask & causal_mask  # shape: (B, H, Q_LEN, K_LEN)
    attn = attn.masked_fill(~final_mask, float("-inf"))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)

    # Verify accuracy.
    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen"
    print("Pass topk sparse attention test with qlen < klen")
...
@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
    bsz, num_head, downsample_len, _ = x.shape
    # N_CTX = downsample_len * BLOCK
    sparse_index = torch.topk(x, topk, dim=-1).indices
    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
    dense_mask.scatter_(-1, sparse_index, True)
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
@@ -30,15 +27,17 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F


@tilelang.jit(
    out_idx=[4],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
    block_M = 64
    block_N = 64
    num_stages = 1
    threads = 128
    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
    shape = [batch, heads, seq_len, dim]
    block_mask_shape = [batch, heads, downsample_len, downsample_len]
@@ -47,7 +46,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
    block_mask_dtype = "bool"

    def kernel_func(block_M, block_N, num_stages, threads):
        @T.macro
        def MMA0(
            K: T.Tensor(shape, dtype),
@@ -59,11 +57,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
            by: T.int32,
            bz: T.int32,
        ):
            T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
            if is_causal:
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
            else:
                T.clear(acc_s)
            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -78,18 +75,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
            by: T.int32,
            bz: T.int32,
        ):
            T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
            T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def Softmax(
            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
            scores_max: T.FragmentBuffer([block_M], accum_dtype),
            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
            logsum: T.FragmentBuffer([block_M], accum_dtype),
        ):
            T.copy(scores_max, scores_max_prev)
            T.fill(scores_max, -T.infinity(accum_dtype))
@@ -113,22 +110,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
        @T.macro
        def Rescale(
            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
        ):
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] *= scores_scale[i]

        @T.prim_func
        def blocksparse_flashattn(
            Q: T.Tensor(shape, dtype),
            K: T.Tensor(shape, dtype),
            V: T.Tensor(shape, dtype),
            BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
            Output: T.Tensor(shape, dtype),
        ):
            with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
@@ -143,7 +139,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                logsum = T.alloc_fragment([block_M], accum_dtype)
                block_mask = T.alloc_local([downsample_len], block_mask_dtype)

                T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))
@@ -152,20 +148,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                    block_mask[vj] = BlockSparseMask[bz, by, bx, vj]

                loop_range = (
                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
                )

                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    if block_mask[k] != 0:
                        MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                        Rescale(acc_o, scores_scale)
                        MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] /= logsum[i]
                T.copy(acc_o, O_shared)
                T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])

        return blocksparse_flashattn
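The Softmax and Rescale macros above express the usual streaming-softmax update of flash attention. Below is an illustrative dense PyTorch sketch of one update step; the function name, shapes, and float32 dtypes are assumptions, and `scale` is expected to already include the log2(e) factor so exp2 reproduces the natural-base softmax.

import torch

def online_softmax_step(acc_o, logsum, scores_max, acc_s, v_block, scale):
    # acc_s: [block_M, block_N] raw QK^T scores for one key block; v_block: [block_N, dim]
    # All tensors assumed float32 here. Initial state: acc_o = 0, logsum = 0, scores_max = -inf.
    scores_max_prev = scores_max
    scores_max = torch.maximum(scores_max_prev, acc_s.max(dim=-1).values)
    scores_scale = torch.exp2(scores_max_prev * scale - scores_max * scale)
    p = torch.exp2(acc_s * scale - scores_max[:, None] * scale)
    acc_o = acc_o * scores_scale[:, None] + p @ v_block   # rescale old partial output, add this block
    logsum = logsum * scores_scale + p.sum(dim=-1)        # running softmax denominator
    return acc_o, logsum, scores_max

After the last selected key block, dividing acc_o by logsum row-wise gives the softmax-weighted output over the selected blocks, which is what the kernel does before the final T.copy to Output.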
@@ -180,18 +175,16 @@ def test_topk_sparse_attention():
    torch.manual_seed(0)

    # Create inputs
    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    sm_scale = 1.0 / (D_HEAD**0.5)

    # Create sparse mask (downsampled to block level)
    downsample_factor = BLOCK
    downsample_len = math.ceil(SEQ_LEN / downsample_factor)
    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
    x_ds[:, :, :, 0] = 100
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
@@ -202,15 +195,15 @@ def test_topk_sparse_attention():
    # Compute reference
    # Expand block mask to full attention matrix
    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
    full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
    full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal

    # PyTorch reference implementation
    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
    attn = attn.masked_fill(~full_mask, float("-inf"))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)

    print("ref_output", ref_output)
    print("tilelang_output", tilelang_output)
...
@@ -13,17 +13,20 @@ from heuristic import num_splits_heuristic


def flashattn(batch, heads, heads_kv, dim, dim_v):
    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = heads // heads_kv

    @tilelang.jit(
        out_idx=[-1],
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
        },
    )
    def kernel_func(
        block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages, max_num_blocks_per_seq, max_selected_blocks
    ):
        shape_q = [batch, heads, dim]
        shape_k = [num_pages, page_block_size, heads_kv, dim]
        shape_v = [num_pages, page_block_size, heads_kv, dim_v]
@@ -37,17 +40,16 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
        @T.macro
        def flash_attn_split(
            Q: T.Tensor(shape_q, dtype),
            K: T.Tensor(shape_k, dtype),
            V: T.Tensor(shape_v, dtype),
            block_indices: T.Tensor(shape_indices, "int32"),
            cache_seqlens: T.Tensor([batch], "int32"),
            block_table: T.Tensor(shape_block_table, "int32"),
            glse: T.Tensor([batch, heads, num_split], accum_dtype),
            Output_partial: T.Tensor(part_shape, accum_dtype),
        ):
            with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_H, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim_v], dtype)
@@ -67,7 +69,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                sid = bz
                cur_kv_head = hid // (kv_group_num // valid_block_H)

                T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))
@@ -75,7 +77,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                num_blocks = max_selected_blocks
                blocks_per_split = T.floordiv(num_blocks, num_split)
                remaining_blocks = T.floormod(num_blocks, num_split)
                loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)
                start = blocks_per_split * sid + T.min(sid, remaining_blocks)
                has_valid_block = False
                for k in T.Pipelined(loop_range, num_stages=num_stages):
@@ -85,29 +87,20 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                        block_table_idx = T.floordiv(logical_block_idx, block_ratio)
                        block_tile_idx = T.floormod(logical_block_idx, block_ratio)
                        physical_block_idx = block_table[bid, block_table_idx]
                        T.copy(K[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared)
                        T.clear(acc_s)
                        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                        if k == 0:  # assume block_indices is sorted in reverse order, otherwise, remove this if condition
                            for i, j in T.Parallel(block_H, block_N):
                                acc_s[i, j] = T.if_then_else(
                                    logical_block_idx * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j]
                                )
                        T.copy(scores_max, scores_max_prev)
                        T.fill(scores_max, -T.infinity(accum_dtype))
                        T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                        for i in T.Parallel(block_H):
                            scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                            scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                        for i, j in T.Parallel(block_H, block_N):
                            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                        T.reduce_sum(acc_s, scores_sum, dim=1)
@@ -116,10 +109,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                        T.copy(acc_s, acc_s_cast)
                        for i, j in T.Parallel(block_H, dim_v):
                            acc_o[i, j] *= scores_scale[i]
                        T.copy(V[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared)
                        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                if has_valid_block:
                    for i, j in T.Parallel(block_H, dim_v):
@@ -138,9 +128,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
        @T.macro
        def combine(
            glse: T.Tensor([batch, heads, num_split], accum_dtype),
            Output_partial: T.Tensor(part_shape, accum_dtype),
            Output: T.Tensor(shape_o, dtype),
        ):
            with T.Kernel(heads, batch, threads=128) as (by, bz):
                po_local = T.alloc_fragment([dim_v], accum_dtype)
@@ -151,17 +141,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                scale_local = T.alloc_local([1], accum_dtype)
                max_split = T.alloc_local([1], "int32")

                T.annotate_layout(
                    {
                        lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
                    }
                )
                T.clear(lse_logsum_local)
                T.clear(o_accum_local)
                lse_max_local[0] = -T.infinity(accum_dtype)
                for k in T.serial(num_split):
                    lse_local_split[0] = glse[bz, by, k]
                    if lse_local_split[0] != 0:
                        max_split[0] = k
                        lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
@@ -183,18 +174,17 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
        @T.prim_func
        def main(
            Q: T.Tensor(shape_q, dtype),
            K: T.Tensor(shape_k, dtype),
            V: T.Tensor(shape_v, dtype),
            block_indices: T.Tensor(shape_indices, "int32"),
            cache_seqlens: T.Tensor([batch], "int32"),
            block_table: T.Tensor(shape_block_table, "int32"),
            glse: T.Tensor([batch, heads, num_split], accum_dtype),
            Output_partial: T.Tensor(part_shape, accum_dtype),
            Output: T.Tensor(shape_o, dtype),
        ):
            flash_attn_split(Q, K, V, block_indices, cache_seqlens, block_table, glse, Output_partial)
            combine(glse, Output_partial, Output)

        return main
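The split-KV pattern here (flash_attn_split writes per-split partial outputs plus a log-sum-exp value, combine merges them) can be summarized with a dense PyTorch sketch of the usual flash-decoding style merge. This is a hedged illustration, not lifted from the kernel: it assumes glse holds per-split log2-sum-exp values with 0 marking an empty split and Output_partial holds per-split normalized outputs.

import torch

def combine_splits(glse, output_partial):
    # glse: [batch, heads, num_split]; output_partial: [batch, heads, num_split, dim_v]
    lse_max = glse.max(dim=-1, keepdim=True).values
    weights = torch.exp2(glse - lse_max)  # base-2, matching the kernel's exp2/log2 arithmetic
    weights = torch.where(glse == 0, torch.zeros_like(weights), weights)  # skip empty splits
    denom = weights.sum(dim=-1, keepdim=True)
    out = (weights.unsqueeze(-1) * output_partial).sum(dim=-2) / denom
    return out  # [batch, heads, dim_v]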
@@ -203,7 +193,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):


class SparseFlashAttn(torch.nn.Module):
    def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_pages):
        super(SparseFlashAttn, self).__init__()
        self.batch = batch
@@ -249,18 +238,11 @@ class SparseFlashAttn(torch.nn.Module):
        num_sm = self.num_sm

        num_split = num_splits_heuristic(
            total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
        )

        glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
        output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")

        output = self.kernel(
            query,
@@ -275,14 +257,13 @@ class SparseFlashAttn(torch.nn.Module):
        return output


def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens, block_table, page_block_size, block_size):
    """
    Paged version of sparse attention reference implementation.

    Args:
        query: [batch, heads, dim]
        key_cache: [num_pages, page_block_size, heads_kv, dim]
        value_cache: [num_pages, page_block_size, heads_kv, dim_v]
        block_indices: [batch, heads_kv, max_selected_blocks] - logical block indices
        cache_seqlens: [batch] - actual sequence lengths
@@ -298,12 +279,8 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_

    # Reconstruct the full key and value tensors from paged cache
    max_cache_seqlen = max(cache_seqlens).item()
    key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim), dtype=key_cache.dtype, device=key_cache.device)
    value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v), dtype=value_cache.dtype, device=value_cache.device)

    # Reconstruct full tensors from paged cache using block_table
    for b in range(batch):
@@ -319,20 +296,14 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
            actual_block_size = end_token - start_token

            # Copy from paged cache to full tensors
            key_full[b, :, start_token:end_token, :] = key_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1)
            value_full[b, :, start_token:end_token, :] = value_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1)

    # Reshape query for grouped attention
    query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]

    # Compute attention scores
    scores = einsum(query, key_full, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, heads_kv, seqlen_kv]

    # Create sparse mask based on block_indices
    sparse_mask = torch.zeros_like(scores)
@@ -348,24 +319,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
                    sparse_mask[b, :, h, start_pos:end_pos] = 1

    # Apply sparse mask
    scores = scores.masked_fill(sparse_mask == 0, float("-inf"))

    # Apply causal mask based on actual sequence lengths
    range_len = torch.arange(scores.shape[-1], device=scores.device).unsqueeze(0)
    cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
    pad_mask = range_len >= cache_seqlens_expanded
    pad_mask = pad_mask[:, None, None, :]
    scores = scores.masked_fill(pad_mask, float("-inf"))

    # Compute attention weights
    attention = F.softmax(scores / scale, dim=-1)

    # Apply attention to values
    out = einsum(attention, value_full, "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, heads_kv, dim]

    # Reshape output back to original format
    out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
    return out
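The page-table indirection used by both the kernel and this reference (a logical KV block index is split into a block_table row plus a tile inside the physical page) can be isolated in a small sketch. The helper below is hypothetical, written against the cache layouts documented above; page_block_size and block_N mirror the script's argparse defaults.

import torch

page_block_size, block_N = 256, 64
block_ratio = page_block_size // block_N  # block_N-sized tiles per physical page

def gather_k_tile(key_cache, block_table, batch_idx, logical_block_idx, kv_head):
    # key_cache: [num_pages, page_block_size, heads_kv, dim]; block_table: [batch, max_num_blocks_per_seq]
    block_table_idx = logical_block_idx // block_ratio   # which page of this sequence
    block_tile_idx = logical_block_idx % block_ratio     # which block_N tile inside that page
    physical_block_idx = int(block_table[batch_idx, block_table_idx])
    start = block_tile_idx * block_N
    return key_cache[physical_block_idx, start : start + block_N, kv_head, :]  # [block_N, dim]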
@@ -373,17 +343,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_


def ref_program_fa(query, kcache, vcache, cache_seqlens, block_table):
    # latency reference
    # from flash_attn_interface import flash_attn_with_kvcache  # fa3
    from flash_attn import flash_attn_with_kvcache  # fa2

    query = query.unsqueeze(1)
    output = flash_attn_with_kvcache(query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table)
    output = output.squeeze(1)
    return output


def main(args):
    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = (
        args.batch,
        args.heads,
        args.heads_kv,
        args.max_cache_seqlen,
        args.dim,
        args.dim_v,
    )
    sparse_ratio = args.sparse_ratio
    block_N = args.block_N
    page_block_size = args.page_block_size
@@ -395,35 +371,30 @@ def main(args):
    dtype = torch.float16

    # Generate random inputs
    Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
    cache_seqlens = torch.randint(max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda")
    print("cache_seqlens: ", cache_seqlens)

    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")

    # Create paged KV cache
    K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device="cuda")
    V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda")

    # Create block table and block indices for dense case (all blocks selected)
    max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size))
    print("max_num_blocks_per_seq: ", max_num_blocks_per_seq)
    block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda")
    block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), dtype=torch.int32, device="cuda")

    # Fill block table and block indices and cache
    # Create a pool of available physical blocks
    total_blocks_needed = sum(int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch))
    available_blocks = list(range(total_blocks_needed))

    import random
    random.seed(42)  # For reproducibility
    random.shuffle(available_blocks)
@@ -458,10 +429,8 @@ def main(args):
            actual_block_size = end_token - start_token

            # Copy K and V data to the paged cache
            K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, start_token:end_token, :, :]
            V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, start_token:end_token, :, :]

    # Fill block_indices for sparse attention
    # For dense case (verification), we select all blocks in reverse order
@@ -496,10 +465,9 @@ def main(args):
                remaining_blocks = [b for b in all_blocks if b not in selected_blocks]
                if remaining_blocks:
                    import random
                    random.seed(42)  # For reproducibility
                    additional_blocks = random.sample(remaining_blocks, min(num_selected - recent_blocks, len(remaining_blocks)))
                    selected_blocks.extend(additional_blocks)

            # Sort selected blocks in reverse order (most recent first)
@@ -512,25 +480,20 @@ def main(args):
                    block_indices[seq_idx, head_idx, i] = -1

    # Initialize sparse attention module
    sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_blocks)
    output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table)

    import flash_attn  # noqa: F401
    output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table, page_block_size, block_N)
    output_ref_fa = ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table)

    # Check correctness
    if sparse_ratio == 0.0:
        max_diff = torch.max(torch.abs(output_sparse - output_ref_fa)).item()
        mean_diff = torch.mean(torch.abs(output_sparse - output_ref_fa)).item()
        assert torch.allclose(output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!"
    else:
        max_diff = torch.max(torch.abs(output_sparse - output_ref_torch)).item()
        mean_diff = torch.mean(torch.abs(output_sparse - output_ref_torch)).item()
@@ -574,16 +537,15 @@ def main(args):

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=8, help="batch size")
    parser.add_argument("--heads", type=int, default=32, help="heads")
    parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
    parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
    parser.add_argument("--dim", type=int, default=128, help="dim")
    parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
    parser.add_argument("--sparse_ratio", type=float, default=0.0, help="sparse ratio")
    parser.add_argument("--block_N", type=int, default=64, help="block_N")
    parser.add_argument("--page_block_size", type=int, default=256, help="block size of pages")
    parser.add_argument("--num_pages", type=int, default=1024, help="total number of pages")
    args = parser.parse_args()
    main(args)
@@ -10,17 +10,18 @@ from heuristic import num_splits_heuristic


def flashattn(batch, heads, heads_kv, dim, dim_v):
    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = heads // heads_kv

    @tilelang.jit(
        out_idx=[-1],
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
        },
    )
    def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, max_selected_blocks):
        shape_q = [batch, heads, dim]
        shape_k = [batch, max_cache_seqlen, heads_kv, dim]
        shape_v = [batch, max_cache_seqlen, heads_kv, dim_v]
@@ -31,17 +32,16 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
        @T.macro
        def flash_attn_split(
            Q: T.Tensor(shape_q, dtype),
            K: T.Tensor(shape_k, dtype),
            V: T.Tensor(shape_v, dtype),
            block_indices: T.Tensor(shape_indices, "int32"),
            cache_seqlens: T.Tensor([batch], "int32"),
            # actual_num_blocks: T.Tensor([batch], "int32"),
            glse: T.Tensor([batch, heads, num_split], accum_dtype),
            Output_partial: T.Tensor(part_shape, accum_dtype),
        ):
            with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_H, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim_v], dtype)
@@ -62,7 +62,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                sid = bz
                cur_kv_head = hid // (kv_group_num // valid_block_H)

                T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))
@@ -70,7 +70,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                num_blocks = max_selected_blocks
                blocks_per_split = T.floordiv(num_blocks, num_split)
                remaining_blocks = T.floormod(num_blocks, num_split)
                loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)
                start = blocks_per_split * sid + T.min(sid, remaining_blocks)
                has_valid_block = False
...@@ -78,26 +78,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): ...@@ -78,26 +78,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
i_s = block_indices[bid, cur_kv_head, start + k] i_s = block_indices[bid, cur_kv_head, start + k]
if i_s >= 0: if i_s >= 0:
has_valid_block = True has_valid_block = True
T.copy(K[bid, i_s * block_N:(i_s + 1) * block_N, cur_kv_head, :], K_shared) T.copy(K[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], K_shared)
T.clear(acc_s) T.clear(acc_s)
T.gemm( T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
if k == 0:  # block_indices is assumed to be sorted in descending order; otherwise, remove this if condition if k == 0:  # block_indices is assumed to be sorted in descending order; otherwise, remove this if condition
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
acc_s[i, acc_s[i, j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j])
j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid],
-T.infinity(accum_dtype), acc_s[i, j])
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1) T.reduce_sum(acc_s, scores_sum, dim=1)
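The rescaling above is the standard online-softmax update, written with exp2 and a scale that folds 1/sqrt(dim) together with log2(e). A self-contained torch sketch (dim_v collapsed to 1 for clarity, names are illustrative) that checks the update against a plain softmax:

import torch

def online_attention_1d(score_blocks, value_blocks, scale):
    # score_blocks[i]: [rows, block_N] raw Q@K^T scores for block i
    # value_blocks[i]: [block_N] values per key position (dim_v = 1 for clarity)
    rows = score_blocks[0].shape[0]
    row_max = torch.full((rows,), float("-inf"))
    logsum = torch.zeros(rows)
    acc_o = torch.zeros(rows)
    for s, v in zip(score_blocks, value_blocks):
        new_max = torch.maximum(row_max, s.max(dim=1).values)
        rescale = torch.exp2(row_max * scale - new_max * scale)   # scores_scale above
        p = torch.exp2(s * scale - new_max[:, None] * scale)      # acc_s after exp2
        logsum = logsum * rescale + p.sum(dim=1)
        acc_o = acc_o * rescale + p @ v
        row_max = new_max
    return acc_o / logsum

scale = (1.0 / 128) ** 0.5 * 1.44269504           # 1/sqrt(dim) folded with log2(e)
scores = [torch.randn(4, 32) for _ in range(3)]
values = [torch.randn(32) for _ in range(3)]
expected = torch.softmax(torch.cat(scores, dim=1) / 128 ** 0.5, dim=1) @ torch.cat(values)
assert torch.allclose(online_attention_1d(scores, values, scale), expected, atol=1e-5)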
...@@ -106,7 +98,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): ...@@ -106,7 +98,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
T.copy(acc_s, acc_s_cast) T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim_v): for i, j in T.Parallel(block_H, dim_v):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
T.copy(V[bid, i_s * block_N:(i_s + 1) * block_N, cur_kv_head, :], V_shared) T.copy(V[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_valid_block: if has_valid_block:
for i, j in T.Parallel(block_H, dim_v): for i, j in T.Parallel(block_H, dim_v):
...@@ -125,9 +117,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): ...@@ -125,9 +117,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@T.macro @T.macro
def combine( def combine(
glse: T.Tensor([batch, heads, num_split], accum_dtype), glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype), Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype), Output: T.Tensor(shape_o, dtype),
): ):
with T.Kernel(heads, batch, threads=128) as (by, bz): with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim_v], accum_dtype) po_local = T.alloc_fragment([dim_v], accum_dtype)
...@@ -138,17 +130,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): ...@@ -138,17 +130,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
scale_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype)
max_split = T.alloc_local([1], "int32") max_split = T.alloc_local([1], "int32")
T.annotate_layout({ T.annotate_layout(
lse_logsum_local: {
T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}) }
)
T.clear(lse_logsum_local) T.clear(lse_logsum_local)
T.clear(o_accum_local) T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype) lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split): for k in T.serial(num_split):
lse_local_split[0] = glse[bz, by, k] lse_local_split[0] = glse[bz, by, k]
if (lse_local_split[0] != 0): if lse_local_split[0] != 0:
max_split[0] = k max_split[0] = k
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
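A simplified torch sketch of the split reduction that combine performs: each split contributes a partially normalized output together with its log-sum-exp, and the final row is a convex combination weighted by exp(lse_k - lse_max). This sketch uses exp/logsumexp in place of the kernel's exp2-scaled bookkeeping; names are illustrative.

import torch

def combine_splits(o_partial, lse):
    # o_partial: [num_split, dim_v] partially-normalized outputs, lse: [num_split]
    lse_max = lse.max()
    w = torch.exp(lse - lse_max)                   # un-normalized split weights
    return (o_partial * (w / w.sum())[:, None]).sum(dim=0)

# Check against an unsplit softmax-weighted sum over 3 chunks of 32 keys.
scores, values = torch.randn(96), torch.randn(96, 8)
expected = torch.softmax(scores, dim=0) @ values
o_partial = torch.stack([torch.softmax(c, dim=0) @ v
                         for c, v in zip(scores.split(32), values.split(32))])
lse = torch.stack([torch.logsumexp(c, dim=0) for c in scores.split(32)])
assert torch.allclose(combine_splits(o_partial, lse), expected, atol=1e-5)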
...@@ -170,15 +163,15 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): ...@@ -170,15 +163,15 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@T.prim_func @T.prim_func
def main( def main(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype), K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype), V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, "int32"), block_indices: T.Tensor(shape_indices, "int32"),
cache_seqlens: T.Tensor([batch], "int32"), cache_seqlens: T.Tensor([batch], "int32"),
# actual_num_blocks: T.Tensor([batch], "int32"), # actual_num_blocks: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype), glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype), Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype), Output: T.Tensor(shape_o, dtype),
): ):
# flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial) # flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial)
flash_attn_split(Q, K, V, block_indices, cache_seqlens, glse, Output_partial) flash_attn_split(Q, K, V, block_indices, cache_seqlens, glse, Output_partial)
...@@ -190,7 +183,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): ...@@ -190,7 +183,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
class SparseFlashAttn(torch.nn.Module): class SparseFlashAttn(torch.nn.Module):
def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
super(SparseFlashAttn, self).__init__() super(SparseFlashAttn, self).__init__()
self.batch = batch self.batch = batch
...@@ -209,7 +201,8 @@ class SparseFlashAttn(torch.nn.Module): ...@@ -209,7 +201,8 @@ class SparseFlashAttn(torch.nn.Module):
num_stages=2, num_stages=2,
threads=128, threads=128,
max_cache_seqlen=T.dynamic("max_cache_seqlen"), max_cache_seqlen=T.dynamic("max_cache_seqlen"),
max_selected_blocks=T.dynamic("max_selected_blocks")) max_selected_blocks=T.dynamic("max_selected_blocks"),
)
props = torch.cuda.get_device_properties(torch.device("cuda:0")) props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count self.num_sm = props.multi_processor_count
...@@ -232,25 +225,17 @@ class SparseFlashAttn(torch.nn.Module): ...@@ -232,25 +225,17 @@ class SparseFlashAttn(torch.nn.Module):
num_sm = self.num_sm num_sm = self.num_sm
num_split = num_splits_heuristic( num_split = num_splits_heuristic(
total_mblocks, total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
num_sm, )
num_n_blocks,
num_m_blocks, glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
size_one_kv_head, output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
is_causal_or_local=True,
max_splits=128)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
output = self.kernel(query, key, value, block_indices, cache_seqlens, glse, output_partial) output = self.kernel(query, key, value, block_indices, cache_seqlens, glse, output_partial)
return output return output
def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seqlens, def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, block_size):
max_cache_seqlen, block_size):
""" """
Args: Args:
query: [batch, heads, dim] query: [batch, heads, dim]
...@@ -272,31 +257,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql ...@@ -272,31 +257,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
block_H = 64 block_H = 64
actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32) actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)
actual_num_blocks = actual_num_blocks[:, actual_num_blocks = actual_num_blocks[
0] #[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks :, 0
] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
# get num_split # get num_split
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks #(kv_seqlen + block_size - 1 ) // block_size num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size
# num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks
size_one_kv_head = max_selected_blocks * block_size * ( size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2
dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks total_mblocks = batch * heads_kv * num_m_blocks
num_sm = 132 num_sm = 132
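As a worked example of the quantities computed above and passed to num_splits_heuristic, using the script's defaults (batch=8, heads=32, heads_kv=8, dim=dim_v=128, block_size=32, block_H=64) and an assumed max_selected_blocks=52 purely for illustration:

batch, heads, heads_kv = 8, 32, 8
dim, dim_v, block_size, block_H = 128, 128, 32, 64
max_selected_blocks = 52                                                  # assumed for illustration

num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H           # 4 query heads per kv head fit one 64-row tile -> 1
num_n_blocks = max_selected_blocks                                        # KV tiles a single split may visit
size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2   # bytes of K+V per kv head (fp16 = 2 bytes) -> 851_968
total_mblocks = batch * heads_kv * num_m_blocks                           # independent (batch, kv-head) tiles -> 64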
num_split = num_splits_heuristic( num_split = num_splits_heuristic(
total_mblocks, total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
num_sm, )
num_n_blocks,
num_m_blocks, glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
size_one_kv_head, Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
is_causal_or_local=True,
max_splits=128)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
Output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size, block_N=block_size,
block_H=block_H, block_H=block_H,
...@@ -304,29 +282,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql ...@@ -304,29 +282,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
num_stages=2, num_stages=2,
threads=128, threads=128,
max_cache_seqlen=T.dynamic("max_cache_seqlen"), max_cache_seqlen=T.dynamic("max_cache_seqlen"),
max_selected_blocks=T.dynamic("max_selected_blocks")) max_selected_blocks=T.dynamic("max_selected_blocks"),
)
output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial) output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial)
return output return output
def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
block_size):
batch, heads, dim = query.shape batch, heads, dim = query.shape
heads_kv = key.shape[2] heads_kv = key.shape[2]
num_head_groups = query.shape[1] // key.shape[2] num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5 scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
query = rearrange( query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum( scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv]
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores) sparse_mask = torch.zeros_like(scores)
# Assign mask values based on block_indices # Assign mask values based on block_indices
...@@ -335,28 +308,26 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache ...@@ -335,28 +308,26 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache
valid_indices = block_indices[b, h] # Extract indices for this batch and head valid_indices = block_indices[b, h] # Extract indices for this batch and head
for idx in valid_indices: for idx in valid_indices:
if idx >= 0: if idx >= 0:
sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float('-inf')) scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1) cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :] pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float('-inf')) scores = scores.masked_fill(pad_mask, float("-inf"))
attention = F.softmax( attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value, out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim]
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out return out
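The rearrange/einsum reference above relies on query heads being grouped contiguously per kv head, i.e. query head q attends through kv head q // num_head_groups. A small dense sanity check of that convention (no sparsity or padding masks, illustrative sizes):

import torch
import torch.nn.functional as F
from einops import rearrange, einsum

batch, heads, heads_kv, seq, dim = 2, 8, 2, 16, 4
g = heads // heads_kv
q = torch.randn(batch, heads, dim)
k = torch.randn(batch, seq, heads_kv, dim)
v = torch.randn(batch, seq, heads_kv, dim)

# einsum path, as in ref_program_torch (masks omitted)
kk, vv = rearrange(k, "b n h d -> b h n d"), rearrange(v, "b n h d -> b h n d")
qq = rearrange(q, "b (h g) d -> b g h d", g=g)
attn = F.softmax(einsum(qq, kk, "b g h d, b h s d -> b g h s") / dim**0.5, dim=-1)
out = rearrange(einsum(attn, vv, "b g h s, b h s d -> b g h d"), "b g h d -> b (h g) d")

# naive per-head loop: query head hq reads kv head hq // g
ref = torch.empty_like(out)
for hq in range(heads):
    s = torch.einsum("bd,bsd->bs", q[:, hq], k[:, :, hq // g]) / dim**0.5
    ref[:, hq] = torch.einsum("bs,bsd->bd", F.softmax(s, dim=-1), v[:, :, hq // g])
assert torch.allclose(out, ref, atol=1e-5)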
def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
block_size):
# latency reference # latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3 # from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache #fa2 from flash_attn import flash_attn_with_kvcache # fa2
query = query.unsqueeze(1) query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1) output = output.squeeze(1)
...@@ -368,23 +339,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3): ...@@ -368,23 +339,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
print(name + " all_close={}".format(all_close)) print(name + " all_close={}".format(all_close))
if not all_close: if not all_close:
diff = (expect - actual).abs() diff = (expect - actual).abs()
print("all_close={}, max={}, min={}, mean={}".format(all_close, print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item()))
diff.max().item(),
diff.min().item(),
diff.mean().item()))
max_indices = torch.nonzero(diff == diff.max().item()) max_indices = torch.nonzero(diff == diff.max().item())
first_index = tuple(max_indices[0].tolist()) first_index = tuple(max_indices[0].tolist())
print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}") print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")
def main(batch=8, def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
heads=32,
heads_kv=8,
max_cache_seqlen=8192,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32):
batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
sparse_ratio = sparse_ratio sparse_ratio = sparse_ratio
block_size = block_size block_size = block_size
...@@ -392,10 +353,10 @@ def main(batch=8, ...@@ -392,10 +353,10 @@ def main(batch=8,
print("max_selected_blocks: ", max_selected_blocks) print("max_selected_blocks: ", max_selected_blocks)
dtype = torch.float16 dtype = torch.float16
Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
# # Ensure at least one element equals cache_seqlen # # Ensure at least one element equals cache_seqlen
# random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index # random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index
...@@ -406,10 +367,7 @@ def main(batch=8, ...@@ -406,10 +367,7 @@ def main(batch=8,
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks) print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_indices with -1 (for padding blocks) # Initialize block_indices with -1 (for padding blocks)
block_indices = torch.full((batch, heads_kv, max_selected_blocks), block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda")
-1,
dtype=torch.int32,
device='cuda')
# max_num_blocks = int((max_cache_seqlen + block_size - 1)/ block_size) # max_num_blocks = int((max_cache_seqlen + block_size - 1)/ block_size)
# block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda') # block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda')
...@@ -418,10 +376,9 @@ def main(batch=8, ...@@ -418,10 +376,9 @@ def main(batch=8,
max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch
if max_valid_block > 0: # Ensure there's at least one valid block if max_valid_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv): for h in range(heads_kv):
valid_indices = torch.randperm( valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks]
max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks]
# valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks] # valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks]
block_indices[b, h, :len(valid_indices)] = valid_indices block_indices[b, h, : len(valid_indices)] = valid_indices
# Sort indices within each batch-group for consistency # Sort indices within each batch-group for consistency
block_indices, _ = block_indices.sort(dim=-1, descending=True) block_indices, _ = block_indices.sort(dim=-1, descending=True)
...@@ -434,8 +391,7 @@ def main(batch=8, ...@@ -434,8 +391,7 @@ def main(batch=8,
print("max_num_blocks: ", max_num_blocks) print("max_num_blocks: ", max_num_blocks)
# parity reference # parity reference
ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
block_size)
sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size)
out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
...@@ -445,13 +401,11 @@ def main(batch=8, ...@@ -445,13 +401,11 @@ def main(batch=8,
## latency reference ## latency reference
for _ in range(10): for _ in range(10):
ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
max_num_blocks, block_size)
torch.cuda.synchronize() torch.cuda.synchronize()
start = time.time() start = time.time()
for _ in range(100): for _ in range(100):
ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
max_num_blocks, block_size)
torch.cuda.synchronize() torch.cuda.synchronize()
print("dense time: ", (time.time() - start) / 100 * 1000) print("dense time: ", (time.time() - start) / 100 * 1000)
...@@ -469,15 +423,13 @@ def main(batch=8, ...@@ -469,15 +423,13 @@ def main(batch=8,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size') parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument('--heads', type=int, default=32, help='heads') parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
parser.add_argument( parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
'--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
parser.add_argument('--dim_v', type=int, default=128, help='dim_v') parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') parser.add_argument("--block_size", type=int, default=32, help="block_size")
parser.add_argument('--block_size', type=int, default=32, help='block_size')
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
args.sparse_ratio, args.block_size)
...@@ -12,15 +12,17 @@ from heuristic import num_splits_heuristic ...@@ -12,15 +12,17 @@ from heuristic import num_splits_heuristic
def flashattn(batch, heads, heads_kv, dim, dim_v): def flashattn(batch, heads, heads_kv, dim, dim_v):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
kv_group_num = heads // heads_kv kv_group_num = heads // heads_kv
@tilelang.jit( @tilelang.jit(
out_idx=[-1], pass_configs={ out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks): def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks):
shape_q = [batch, heads, dim] shape_q = [batch, heads, dim]
shape_k = [batch, max_cache_seqlen, heads_kv, dim] shape_k = [batch, max_cache_seqlen, heads_kv, dim]
...@@ -32,16 +34,15 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): ...@@ -32,16 +34,15 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@T.macro @T.macro
def flash_attn_split( def flash_attn_split(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype), K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype), V: T.Tensor(shape_v, dtype),
block_mask: T.Tensor(shape_mask, "bool"), block_mask: T.Tensor(shape_mask, "bool"),
cache_seqlens: T.Tensor([batch], "int32"), cache_seqlens: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype), glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype), Output_partial: T.Tensor(part_shape, accum_dtype),
): ):
with T.Kernel( with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype) Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype) V_shared = T.alloc_shared([block_N, dim_v], dtype)
...@@ -62,39 +63,31 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): ...@@ -62,39 +63,31 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
sid = bz sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H) cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
blocks_per_split = T.floordiv(num_blocks, num_split) blocks_per_split = T.floordiv(num_blocks, num_split)
remaining_blocks = T.floormod(num_blocks, num_split) remaining_blocks = T.floormod(num_blocks, num_split)
loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)) loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)
start = blocks_per_split * sid + T.min(sid, remaining_blocks) start = blocks_per_split * sid + T.min(sid, remaining_blocks)
has_valid_block = False has_valid_block = False
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
if block_mask[bid, hid, start + k]: if block_mask[bid, hid, start + k]:
has_valid_block = True has_valid_block = True
T.copy( T.copy(K[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], K_shared)
K[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :],
K_shared)
T.clear(acc_s) T.clear(acc_s)
T.gemm( T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else((start + k) * block_N + j acc_s[i, j] = T.if_then_else(
>= cache_seqlens[bx], (start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j]
-T.infinity(accum_dtype), acc_s[i, j]) )
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1) T.reduce_sum(acc_s, scores_sum, dim=1)
...@@ -103,9 +96,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): ...@@ -103,9 +96,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
T.copy(acc_s, acc_s_cast) T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim_v): for i, j in T.Parallel(block_H, dim_v):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
T.copy( T.copy(V[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], V_shared)
V[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :],
V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_valid_block: if has_valid_block:
for i, j in T.Parallel(block_H, dim_v): for i, j in T.Parallel(block_H, dim_v):
...@@ -123,9 +114,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): ...@@ -123,9 +114,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@T.macro @T.macro
def combine( def combine(
glse: T.Tensor([batch, heads, num_split], accum_dtype), glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype), Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype), Output: T.Tensor(shape_o, dtype),
): ):
with T.Kernel(heads, batch, threads=128) as (by, bz): with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim_v], accum_dtype) po_local = T.alloc_fragment([dim_v], accum_dtype)
...@@ -135,10 +126,11 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): ...@@ -135,10 +126,11 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
lse_max_local = T.alloc_local([1], accum_dtype) lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({ T.annotate_layout(
lse_logsum_local: {
T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}) }
)
T.clear(lse_logsum_local) T.clear(lse_logsum_local)
T.clear(o_accum_local) T.clear(o_accum_local)
...@@ -161,14 +153,14 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): ...@@ -161,14 +153,14 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@T.prim_func @T.prim_func
def main( def main(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype), K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype), V: T.Tensor(shape_v, dtype),
block_mask: T.Tensor(shape_mask, "bool"), block_mask: T.Tensor(shape_mask, "bool"),
cache_seqlens: T.Tensor([batch], "int32"), cache_seqlens: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype), glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype), Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype), Output: T.Tensor(shape_o, dtype),
): ):
flash_attn_split(Q, K, V, block_mask, cache_seqlens, glse, Output_partial) flash_attn_split(Q, K, V, block_mask, cache_seqlens, glse, Output_partial)
combine(glse, Output_partial, Output) combine(glse, Output_partial, Output)
...@@ -179,7 +171,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): ...@@ -179,7 +171,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
class SparseFlashAttn(torch.nn.Module): class SparseFlashAttn(torch.nn.Module):
def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
super(SparseFlashAttn, self).__init__() super(SparseFlashAttn, self).__init__()
self.batch = batch self.batch = batch
...@@ -198,7 +189,8 @@ class SparseFlashAttn(torch.nn.Module): ...@@ -198,7 +189,8 @@ class SparseFlashAttn(torch.nn.Module):
num_stages=2, num_stages=2,
threads=128, threads=128,
max_cache_seqlen=T.dynamic("max_cache_seqlen"), max_cache_seqlen=T.dynamic("max_cache_seqlen"),
num_blocks=T.dynamic("num_blocks")) num_blocks=T.dynamic("num_blocks"),
)
props = torch.cuda.get_device_properties(torch.device("cuda:0")) props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count self.num_sm = props.multi_processor_count
...@@ -217,24 +209,16 @@ class SparseFlashAttn(torch.nn.Module): ...@@ -217,24 +209,16 @@ class SparseFlashAttn(torch.nn.Module):
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * ( size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2
dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks total_mblocks = batch * heads_kv * num_m_blocks
# num_sm = 132 # num_sm = 132
num_sm = self.num_sm num_sm = self.num_sm
num_split = num_splits_heuristic( num_split = num_splits_heuristic(
total_mblocks, total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
num_sm, )
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
# print("num_split: ", num_split) # print("num_split: ", num_split)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
Output_partial = torch.empty((batch, heads, num_split, dim_v), Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
dtype=torch.float32,
device='cuda')
output = self.kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) output = self.kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)
return output return output
...@@ -259,26 +243,21 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, ...@@ -259,26 +243,21 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
block_H = 64 block_H = 64
actual_num_blocks = torch.sum(block_mask, dim=-1).to(torch.int32) actual_num_blocks = torch.sum(block_mask, dim=-1).to(torch.int32)
actual_num_blocks = actual_num_blocks[:, actual_num_blocks = actual_num_blocks[
0] #[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks :, 0
] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
max_selected_blocks = actual_num_blocks.max().item() max_selected_blocks = actual_num_blocks.max().item()
# get num_split # get num_split
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks #(kv_seqlen + block_size - 1 ) // block_size num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size
# num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks
size_one_kv_head = max_selected_blocks * block_size * ( size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2
dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks total_mblocks = batch * heads_kv * num_m_blocks
num_sm = 132 num_sm = 132
num_split = num_splits_heuristic( num_split = num_splits_heuristic(
total_mblocks, total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
num_sm, )
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size, block_N=block_size,
...@@ -287,11 +266,10 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, ...@@ -287,11 +266,10 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
num_stages=2, num_stages=2,
threads=128, threads=128,
max_cache_seqlen=T.dynamic("max_cache_seqlen"), max_cache_seqlen=T.dynamic("max_cache_seqlen"),
num_blocks=T.dynamic("num_blocks")) num_blocks=T.dynamic("num_blocks"),
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') )
Output_partial = torch.empty((batch, heads, num_split, dim_v), glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
dtype=torch.float32, Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
device='cuda')
# print(kernel.get_kernel_source()) # print(kernel.get_kernel_source())
output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)
...@@ -299,24 +277,18 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, ...@@ -299,24 +277,18 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
return output return output
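The boolean block_mask layout consumed here and the -1-padded, descending-sorted block_indices layout used by the previous variant carry the same information. A small illustrative converter (not part of either kernel) between the two:

import torch

def mask_to_indices(block_mask: torch.Tensor, max_selected_blocks: int) -> torch.Tensor:
    # block_mask: [batch, heads_kv, num_blocks] bool -> [batch, heads_kv, max_selected_blocks]
    # int32, padded with -1 and sorted in descending order, as the indices kernel expects.
    batch, heads_kv, _ = block_mask.shape
    out = torch.full((batch, heads_kv, max_selected_blocks), -1,
                     dtype=torch.int32, device=block_mask.device)
    for b in range(batch):
        for h in range(heads_kv):
            sel = torch.nonzero(block_mask[b, h]).flatten().to(torch.int32)
            out[b, h, : sel.numel()] = sel
    return out.sort(dim=-1, descending=True).values

mask = torch.zeros(2, 2, 8, dtype=torch.bool)
mask[0, 0, torch.tensor([1, 4])] = True
mask[1, 1, torch.tensor([0, 2, 5])] = True
print(mask_to_indices(mask, max_selected_blocks=3))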
def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
block_size):
batch, heads, dim = query.shape batch, heads, dim = query.shape
heads_kv = key.shape[2] heads_kv = key.shape[2]
num_head_groups = query.shape[1] // key.shape[2] num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5 scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
query = rearrange( query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum( scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv]
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores) sparse_mask = torch.zeros_like(scores)
# Assign mask values # Assign mask values
...@@ -324,29 +296,27 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se ...@@ -324,29 +296,27 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se
for h in range(heads_kv): for h in range(heads_kv):
for idx in range(num_blocks): for idx in range(num_blocks):
if block_mask[b, h, idx]: if block_mask[b, h, idx]:
sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float('-inf')) scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1) cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :] pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float('-inf')) scores = scores.masked_fill(pad_mask, float("-inf"))
attention = F.softmax( attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value, out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim]
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out return out
def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
block_size):
# latency reference # latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3 # from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache #fa2 from flash_attn import flash_attn_with_kvcache # fa2
query = query.unsqueeze(1) query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1) output = output.squeeze(1)
...@@ -360,23 +330,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3): ...@@ -360,23 +330,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
# print(expect[3, 28]) # print(expect[3, 28])
# print(actual[3, 28]) # print(actual[3, 28])
diff = (expect - actual).abs() diff = (expect - actual).abs()
print("all_close={}, max={}, min={}, mean={}".format(all_close, print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item()))
diff.max().item(),
diff.min().item(),
diff.mean().item()))
max_indices = torch.nonzero(diff == diff.max().item()) max_indices = torch.nonzero(diff == diff.max().item())
first_index = tuple(max_indices[0].tolist()) first_index = tuple(max_indices[0].tolist())
print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}") print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")
def main(batch=8, def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
heads=32,
heads_kv=8,
max_cache_seqlen=8192,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32):
batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
sparse_ratio = sparse_ratio sparse_ratio = sparse_ratio
block_size = block_size block_size = block_size
...@@ -384,14 +344,13 @@ def main(batch=8, ...@@ -384,14 +344,13 @@ def main(batch=8,
print("max_selected_blocks: ", max_selected_blocks) print("max_selected_blocks: ", max_selected_blocks)
dtype = torch.float16 dtype = torch.float16
Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
# Ensure at least one element equals cache_seqlen # Ensure at least one element equals cache_seqlen
random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index
cache_seqlens[ cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
print("cache_seqlens: ", cache_seqlens) print("cache_seqlens: ", cache_seqlens)
...@@ -403,7 +362,7 @@ def main(batch=8, ...@@ -403,7 +362,7 @@ def main(batch=8,
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks) print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_mask with false (for padding blocks) # Initialize block_mask with false (for padding blocks)
block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda') block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda")
# Assign valid indices while ensuring no duplicates within each batch-group # Assign valid indices while ensuring no duplicates within each batch-group
for b in range(batch): for b in range(batch):
...@@ -411,13 +370,12 @@ def main(batch=8, ...@@ -411,13 +370,12 @@ def main(batch=8,
valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch
if valid_num_block > 0: # Ensure there's at least one valid block if valid_num_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv): for h in range(heads_kv):
perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block] perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block]
block_mask[b, h, perm] = True block_mask[b, h, perm] = True
# print("block_mask: ", block_mask) # print("block_mask: ", block_mask)
# parity reference # parity reference
ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
block_size)
# out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size) # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size)
model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size)
out = model(Q, K, V, block_mask, cache_seqlens) out = model(Q, K, V, block_mask, cache_seqlens)
...@@ -427,13 +385,11 @@ def main(batch=8, ...@@ -427,13 +385,11 @@ def main(batch=8,
## latency reference ## latency reference
for _ in range(10): for _ in range(10):
ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
block_size)
torch.cuda.synchronize() torch.cuda.synchronize()
start = time.time() start = time.time()
for _ in range(100): for _ in range(100):
ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
block_size)
torch.cuda.synchronize() torch.cuda.synchronize()
print("dense time: ", (time.time() - start) / 100 * 1000) print("dense time: ", (time.time() - start) / 100 * 1000)
...@@ -452,15 +408,13 @@ def main(batch=8, ...@@ -452,15 +408,13 @@ def main(batch=8,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size') parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument('--heads', type=int, default=32, help='heads') parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
parser.add_argument( parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
'--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
parser.add_argument('--dim_v', type=int, default=128, help='dim_v') parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') parser.add_argument("--block_size", type=int, default=32, help="block_size")
parser.add_argument('--block_size', type=int, default=32, help='block_size')
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
args.sparse_ratio, args.block_size)
...@@ -12,12 +12,8 @@ from heuristic import num_splits_heuristic ...@@ -12,12 +12,8 @@ from heuristic import num_splits_heuristic
@triton.autotune( @triton.autotune(
configs=[ configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
triton.Config({}, num_warps=num_warps, num_stages=num_stages) key=["BLOCK_H", "BLOCK_N", "BLOCK_D"],
for num_warps in [1, 2, 4]\
for num_stages in [1, 2, 3, 4, 7]
],
key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'],
) )
@triton.jit @triton.jit
def _split_kernel( def _split_kernel(
...@@ -79,16 +75,11 @@ def _split_kernel( ...@@ -79,16 +75,11 @@ def _split_kernel(
loop_range = blocks_per_split loop_range = blocks_per_split
q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h
k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[ k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d
None, :] * stride_k_s + offs_d[:, None] * stride_k_d v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d
v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:,
None] * stride_v_s + offs_d[
None, :] * stride_v_d
mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h
q = tl.load( q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size)
q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d,
mask=offs_h[:, None] < gqa_group_size)
start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks) start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks)
for i in range(loop_range): for i in range(loop_range):
block_idx = tl.load(mask_ptr + (start + i) * stride_mask_s) block_idx = tl.load(mask_ptr + (start + i) * stride_mask_s)
...@@ -119,23 +110,18 @@ def _split_kernel( ...@@ -119,23 +110,18 @@ def _split_kernel(
acc = acc * l_recip acc = acc * l_recip
acc = acc.to(o_partial_ptr.dtype.element_ty) acc = acc.to(o_partial_ptr.dtype.element_ty)
lse_partial_ptr += batch_idx * stride_lse_b + ( lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size) tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size)
o_partial_ptr += batch_idx * stride_o_b + ( o_partial_ptr += (
head_idx_q + batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d )
tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size) tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size)
@triton.autotune( @triton.autotune(
configs=[ configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
triton.Config({}, num_warps=num_warps, num_stages=num_stages) key=["BLOCK_D"],
for num_warps in [1, 2, 4]\
for num_stages in [1, 2, 3, 4, 7]
],
key=['BLOCK_D'],
) )
@triton.jit @triton.jit
def _merge_kernel( def _merge_kernel(
...@@ -163,18 +149,15 @@ def _merge_kernel( ...@@ -163,18 +149,15 @@ def _merge_kernel(
offs_d = tl.arange(0, BLOCK_D) offs_d = tl.arange(0, BLOCK_D)
lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h
lse = tl.load( lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf"))
lse_offsets + offs_splits * lse_partial_stride_split,
mask=offs_splits < num_splits,
other=float("-inf"))
lse_max = tl.max(lse) lse_max = tl.max(lse)
o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h
o_partial = tl.load( o_partial = tl.load(
o_offsets + offs_splits[:, None] * o_partial_stride_split + o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d,
offs_d[None, :] * o_partial_stride_d, mask=offs_splits[:, None] < num_splits,
mask=offs_splits[:, None] < num_splits) )
sumexp_normalized_splitk = tl.exp(lse - lse_max) sumexp_normalized_splitk = tl.exp(lse - lse_max)
sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0)
numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0) numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0)
...@@ -209,19 +192,13 @@ def block_sparse_flash_decode_gqa_indice_triton( ...@@ -209,19 +192,13 @@ def block_sparse_flash_decode_gqa_indice_triton(
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * ( size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2
dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks total_mblocks = batch * heads_kv * num_m_blocks
num_sm = 64 num_sm = 64
# num_sm = self.num_sm # num_sm = self.num_sm
num_splits = num_splits_heuristic( num_splits = num_splits_heuristic(
total_mblocks, total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
num_sm, )
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks) # print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
...@@ -295,24 +272,18 @@ def block_sparse_flash_decode_gqa_indice_triton( ...@@ -295,24 +272,18 @@ def block_sparse_flash_decode_gqa_indice_triton(
return output return output
def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
block_size):
batch, heads, dim = query.shape batch, heads, dim = query.shape
heads_kv = key.shape[2] heads_kv = key.shape[2]
dim_v = value.shape[-1] dim_v = value.shape[-1]
num_head_groups = query.shape[1] // key.shape[2] num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5 scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
query = rearrange( query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum( scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv]
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores) sparse_mask = torch.zeros_like(scores)
# Assign mask values based on block_indices # Assign mask values based on block_indices
...@@ -321,42 +292,33 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache ...@@ -321,42 +292,33 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache
valid_indices = block_indices[b, h] # Extract indices for this batch and head valid_indices = block_indices[b, h] # Extract indices for this batch and head
for idx in valid_indices: for idx in valid_indices:
if idx >= 0: if idx >= 0:
sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float('-inf')) scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1) cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :] pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float('-inf')) scores = scores.masked_fill(pad_mask, float("-inf"))
attention = F.softmax( attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value, out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim]
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out return out
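For readers unfamiliar with the grouped layout used in the reference above: folding the query heads as "b (h g) d -> b g h d" lets a single einsum attend each group of query heads against its shared KV head. A small shape check under the same einops conventions; all sizes here are illustrative, not taken from the example defaults.

    import torch
    from einops import rearrange, einsum

    batch, heads, heads_kv, dim, seqlen_kv = 2, 8, 2, 16, 64
    groups = heads // heads_kv                               # query heads per KV head
    q = torch.randn(batch, heads, dim)
    k = torch.randn(batch, heads_kv, seqlen_kv, dim)         # already [b, h_kv, s, d]
    q_grouped = rearrange(q, "b (h g) d -> b g h d", g=groups)
    scores = einsum(q_grouped, k, "b g h d, b h s d -> b g h s")
    assert scores.shape == (batch, groups, heads_kv, seqlen_kv)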
def ref_program_fa(query, key, value, cache_seqlens): def ref_program_fa(query, key, value, cache_seqlens):
# latency reference # latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3 # from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache #fa2 from flash_attn import flash_attn_with_kvcache # fa2
query = query.unsqueeze(1) query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1) output = output.squeeze(1)
return output return output
def main(batch=64, def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
heads=32,
heads_kv=8,
max_cache_seqlen=8192,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32):
batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
sparse_ratio = sparse_ratio sparse_ratio = sparse_ratio
block_size = block_size block_size = block_size
...@@ -369,34 +331,29 @@ def main(batch=64, ...@@ -369,34 +331,29 @@ def main(batch=64,
dtype = torch.float16 dtype = torch.float16
block_H = 64 block_H = 64
Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
# Ensure at least one element equals cache_seqlen # Ensure at least one element equals cache_seqlen
random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index
cache_seqlens[ cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
print("cache_seqlens: ", cache_seqlens) print("cache_seqlens: ", cache_seqlens)
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks) print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_indices with -1 (for padding blocks) # Initialize block_indices with -1 (for padding blocks)
block_indices = torch.full((batch, heads_kv, max_selected_blocks), block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda")
-1,
dtype=torch.int32,
device='cuda')
# Assign valid indices while ensuring no duplicates within each batch-group # Assign valid indices while ensuring no duplicates within each batch-group
for b in range(batch): for b in range(batch):
max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch
if max_valid_block > 0: # Ensure there's at least one valid block if max_valid_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv): for h in range(heads_kv):
valid_indices = torch.randperm( valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks]
max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks] block_indices[b, h, : len(valid_indices)] = valid_indices
block_indices[b, h, :len(valid_indices)] = valid_indices
# Sort indices within each batch-group for consistency # Sort indices within each batch-group for consistency
block_indices, _ = block_indices.sort(dim=-1, descending=True) block_indices, _ = block_indices.sort(dim=-1, descending=True)
...@@ -408,8 +365,7 @@ def main(batch=64, ...@@ -408,8 +365,7 @@ def main(batch=64,
max_num_blocks = torch.max(max_valid_num_blocks).item() max_num_blocks = torch.max(max_valid_num_blocks).item()
print("max_num_blocks: ", max_num_blocks) print("max_num_blocks: ", max_num_blocks)
ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
block_size)
triton_out = block_sparse_flash_decode_gqa_indice_triton( triton_out = block_sparse_flash_decode_gqa_indice_triton(
Q, Q,
...@@ -423,8 +379,7 @@ def main(batch=64, ...@@ -423,8 +379,7 @@ def main(batch=64,
) )
print("max difference: ", torch.max(torch.abs(ref - triton_out))) print("max difference: ", torch.max(torch.abs(ref - triton_out)))
assert torch.allclose( assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
print("Passed the ref test!") print("Passed the ref test!")
# Measure performance # Measure performance
...@@ -466,15 +421,13 @@ def main(batch=64, ...@@ -466,15 +421,13 @@ def main(batch=64,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=64, help='batch size') parser.add_argument("--batch", type=int, default=64, help="batch size")
parser.add_argument('--heads', type=int, default=32, help='heads') parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
parser.add_argument( parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
'--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
parser.add_argument('--dim_v', type=int, default=128, help='dim_v') parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') parser.add_argument("--block_size", type=int, default=32, help="block_size")
parser.add_argument('--block_size', type=int, default=32, help='block_size')
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
args.sparse_ratio, args.block_size)
...@@ -11,12 +11,8 @@ from heuristic import num_splits_heuristic ...@@ -11,12 +11,8 @@ from heuristic import num_splits_heuristic
@triton.autotune( @triton.autotune(
configs=[ configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
triton.Config({}, num_warps=num_warps, num_stages=num_stages) key=["BLOCK_H", "BLOCK_N", "BLOCK_D"],
for num_warps in [1, 2, 4]\
for num_stages in [1, 2, 3, 4, 7]
],
key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'],
) )
@triton.jit @triton.jit
def _split_kernel( def _split_kernel(
...@@ -77,16 +73,11 @@ def _split_kernel( ...@@ -77,16 +73,11 @@ def _split_kernel(
loop_range = blocks_per_split loop_range = blocks_per_split
q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h
k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[ k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d
None, :] * stride_k_s + offs_d[:, None] * stride_k_d v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d
v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:,
None] * stride_v_s + offs_d[
None, :] * stride_v_d
mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h
q = tl.load( q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size)
q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d,
mask=offs_h[:, None] < gqa_group_size)
start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks) start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks)
for block_idx in range(loop_range): for block_idx in range(loop_range):
start_n = (start + block_idx) * BLOCK_N start_n = (start + block_idx) * BLOCK_N
...@@ -117,23 +108,18 @@ def _split_kernel( ...@@ -117,23 +108,18 @@ def _split_kernel(
acc = acc * l_recip acc = acc * l_recip
acc = acc.to(o_partial_ptr.dtype.element_ty) acc = acc.to(o_partial_ptr.dtype.element_ty)
lse_partial_ptr += batch_idx * stride_lse_b + ( lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size) tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size)
o_partial_ptr += batch_idx * stride_o_b + ( o_partial_ptr += (
head_idx_q + batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d )
tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size) tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size)
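The split kernel's inner loop (elided between the hunks shown here) accumulates attention over its slice of KV blocks before normalizing by l_recip and storing the per-split output and statistic. As a hedged reference only, here is a PyTorch sketch of the standard online-softmax recurrence that flash-decoding split kernels follow; the actual kernel's bookkeeping (masking, and exactly what it stores as the per-split statistic) may differ.

    import torch

    def online_softmax_split(q, k_blocks, v_blocks, scale):
        # q: [d]; k_blocks, v_blocks: lists of [block_n, d] / [block_n, d_v] tiles for this split
        m_i = torch.tensor(float("-inf"))
        l_i = torch.tensor(0.0)
        acc = torch.zeros(v_blocks[0].shape[-1])
        for k_blk, v_blk in zip(k_blocks, v_blocks):
            s = (k_blk @ q) * scale                  # scores for this KV block
            m_new = torch.maximum(m_i, s.max())
            alpha = torch.exp(m_i - m_new)           # rescale what was accumulated so far
            p = torch.exp(s - m_new)
            l_i = l_i * alpha + p.sum()
            acc = acc * alpha + p @ v_blk
            m_i = m_new
        return acc / l_i, m_i + torch.log(l_i)       # per-split output and log-sum-exp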
@triton.autotune( @triton.autotune(
configs=[ configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
triton.Config({}, num_warps=num_warps, num_stages=num_stages) key=["BLOCK_D"],
for num_warps in [1, 2, 4]\
for num_stages in [1, 2, 3, 4, 7]
],
key=['BLOCK_D'],
) )
@triton.jit @triton.jit
def _merge_kernel( def _merge_kernel(
...@@ -161,18 +147,15 @@ def _merge_kernel( ...@@ -161,18 +147,15 @@ def _merge_kernel(
offs_d = tl.arange(0, BLOCK_D) offs_d = tl.arange(0, BLOCK_D)
lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h
lse = tl.load( lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf"))
lse_offsets + offs_splits * lse_partial_stride_split,
mask=offs_splits < num_splits,
other=float("-inf"))
lse_max = tl.max(lse) lse_max = tl.max(lse)
o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h
o_partial = tl.load( o_partial = tl.load(
o_offsets + offs_splits[:, None] * o_partial_stride_split + o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d,
offs_d[None, :] * o_partial_stride_d, mask=offs_splits[:, None] < num_splits,
mask=offs_splits[:, None] < num_splits) )
sumexp_normalized_splitk = tl.exp(lse - lse_max) sumexp_normalized_splitk = tl.exp(lse - lse_max)
sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0)
numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0) numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0)
...@@ -207,19 +190,13 @@ def block_sparse_flash_decode_gqa_mask_triton( ...@@ -207,19 +190,13 @@ def block_sparse_flash_decode_gqa_mask_triton(
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * ( size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2
dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks total_mblocks = batch * heads_kv * num_m_blocks
num_sm = 64 num_sm = 64
# num_sm = self.num_sm # num_sm = self.num_sm
num_splits = num_splits_heuristic( num_splits = num_splits_heuristic(
total_mblocks, total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
num_sm, )
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks) # print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
...@@ -292,24 +269,18 @@ def block_sparse_flash_decode_gqa_mask_triton( ...@@ -292,24 +269,18 @@ def block_sparse_flash_decode_gqa_mask_triton(
return output return output
def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
block_size):
batch, heads, dim = query.shape batch, heads, dim = query.shape
heads_kv = key.shape[2] heads_kv = key.shape[2]
num_head_groups = query.shape[1] // key.shape[2] num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5 scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
query = rearrange( query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum( scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv]
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores) sparse_mask = torch.zeros_like(scores)
# Assign mask values # Assign mask values
...@@ -317,43 +288,34 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se ...@@ -317,43 +288,34 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se
for h in range(heads_kv): for h in range(heads_kv):
for idx in range(num_blocks): for idx in range(num_blocks):
if block_mask[b, h, idx]: if block_mask[b, h, idx]:
sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float('-inf')) scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1) cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :] pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float('-inf')) scores = scores.masked_fill(pad_mask, float("-inf"))
attention = F.softmax( attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value, out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim]
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out return out
def ref_program_fa(query, key, value, cache_seqlens): def ref_program_fa(query, key, value, cache_seqlens):
# latency reference # latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3 # from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache #fa2 from flash_attn import flash_attn_with_kvcache # fa2
query = query.unsqueeze(1) query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1) output = output.squeeze(1)
return output return output
def main(batch=64, def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
heads=32,
heads_kv=8,
max_cache_seqlen=8192,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32):
batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
block_size = block_size block_size = block_size
sparse_ratio = sparse_ratio sparse_ratio = sparse_ratio
...@@ -363,14 +325,13 @@ def main(batch=64, ...@@ -363,14 +325,13 @@ def main(batch=64,
dtype = torch.float16 dtype = torch.float16
Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
# Ensure at least one element equals cache_seqlen # Ensure at least one element equals cache_seqlen
random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index
cache_seqlens[ cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
num_blocks = (max_cache_seqlen + block_size - 1) // block_size num_blocks = (max_cache_seqlen + block_size - 1) // block_size
...@@ -379,7 +340,7 @@ def main(batch=64, ...@@ -379,7 +340,7 @@ def main(batch=64,
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks) print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_mask with false (for padding blocks) # Initialize block_mask with false (for padding blocks)
block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda') block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda")
# Assign valid indices while ensuring no duplicates within each batch-group # Assign valid indices while ensuring no duplicates within each batch-group
for b in range(batch): for b in range(batch):
...@@ -387,11 +348,10 @@ def main(batch=64, ...@@ -387,11 +348,10 @@ def main(batch=64,
valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch
if valid_num_block > 0: # Ensure there's at least one valid block if valid_num_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv): for h in range(heads_kv):
perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block] perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block]
block_mask[b, h, perm] = True block_mask[b, h, perm] = True
ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
block_size)
triton_out = block_sparse_flash_decode_gqa_mask_triton( triton_out = block_sparse_flash_decode_gqa_mask_triton(
Q, Q,
...@@ -404,8 +364,7 @@ def main(batch=64, ...@@ -404,8 +364,7 @@ def main(batch=64,
) )
# print("max difference: ", torch.max(torch.abs(ref - triton_out))) # print("max difference: ", torch.max(torch.abs(ref - triton_out)))
assert torch.allclose( assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
print("Passed the ref test!") print("Passed the ref test!")
# Measure performance # Measure performance
...@@ -448,15 +407,13 @@ def main(batch=64, ...@@ -448,15 +407,13 @@ def main(batch=64,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=64, help='batch size') parser.add_argument("--batch", type=int, default=64, help="batch size")
parser.add_argument('--heads', type=int, default=32, help='heads') parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
parser.add_argument( parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
'--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
parser.add_argument('--dim_v', type=int, default=128, help='dim_v') parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') parser.add_argument("--block_size", type=int, default=32, help="block_size")
parser.add_argument('--block_size', type=int, default=32, help='block_size')
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
args.sparse_ratio, args.block_size)
import math import math
def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local, max_splits):
is_causal_or_local, max_splits):
""" """
Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency. Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency.
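As exercised by the two decode drivers earlier in this diff, the heuristic takes the tiling geometry plus the per-head KV footprint in bytes. An illustrative call with the default decode shapes follows; num_sm=64 mirrors the hard-coded placeholder above, and max_selected_blocks=13 is just an example value, not something fixed by the code.

    from heuristic import num_splits_heuristic

    heads, heads_kv, block_H = 32, 8, 64
    batch, block_size, dim, dim_v = 64, 32, 128, 128
    max_selected_blocks = 13                                   # example sparse selection count
    num_m_blocks = (heads // heads_kv + block_H - 1) // block_H
    num_n_blocks = max_selected_blocks
    size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  # fp16 bytes per KV head
    total_mblocks = batch * heads_kv * num_m_blocks
    num_splits = num_splits_heuristic(
        total_mblocks, 64, num_n_blocks, num_m_blocks, size_one_kv_head,
        is_causal_or_local=True, max_splits=128)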
...@@ -25,26 +25,14 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask(): ...@@ -25,26 +25,14 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
def test_example_triton_sparse_gqa_decode_varlen_indice(): def test_example_triton_sparse_gqa_decode_varlen_indice():
example_triton_sparse_gqa_decode_varlen_indice.main( example_triton_sparse_gqa_decode_varlen_indice.main(
batch=8, batch=8, heads=8, heads_kv=4, max_cache_seqlen=2048, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32
heads=8, )
heads_kv=4,
max_cache_seqlen=2048,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32)
def test_example_triton_sparse_gqa_decode_varlen_mask(): def test_example_triton_sparse_gqa_decode_varlen_mask():
example_triton_sparse_gqa_decode_varlen_mask.main( example_triton_sparse_gqa_decode_varlen_mask.main(
batch=16, batch=16, heads=16, heads_kv=8, max_cache_seqlen=1024, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32
heads=16, )
heads_kv=8,
max_cache_seqlen=1024,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32)
if __name__ == "__main__": if __name__ == "__main__":
...@@ -19,8 +19,7 @@ parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M") ...@@ -19,8 +19,7 @@ parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N") parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)") parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)")
parser.add_argument( parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune")
"--use_autotune", action="store_true", default=False, help="Whether to use autotune")
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
M, N, K = args.m, args.n, args.k M, N, K = args.m, args.n, args.k
...@@ -41,17 +40,19 @@ def get_configs(): ...@@ -41,17 +40,19 @@ def get_configs():
thread_num = [128, 256] thread_num = [128, 256]
enable_rasterization = [True, False] enable_rasterization = [True, False]
_configs = list( _configs = list(itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization))
itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization))
return [{ return [
"block_M": c[0], {
"block_N": c[1], "block_M": c[0],
"block_K": c[2], "block_N": c[1],
"num_stages": c[3], "block_K": c[2],
"thread_num": c[4], "num_stages": c[3],
"enable_rasteration": c[5], "thread_num": c[4],
} for c in _configs] "enable_rasteration": c[5],
}
for c in _configs
]
def ref_program(A, B, BlockMask, block_M, block_N, block_K): def ref_program(A, B, BlockMask, block_M, block_N, block_K):
...@@ -61,12 +62,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K): ...@@ -61,12 +62,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device) accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
for k in range(K // block_K): for k in range(K // block_K):
if BlockMask[i, j, k]: if BlockMask[i, j, k]:
accu += ( accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[
A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to( k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N
torch.float32) @ B[k * block_K:(k + 1) * block_K, ].to(torch.float32)
j * block_N:(j + 1) * block_N].to(torch.float32)) ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16)
ref_c[i * block_M:(i + 1) * block_M,
j * block_N:(j + 1) * block_N] = accu.to(torch.float16)
return ref_c return ref_c
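main() later builds BlockMask with the requested sparsity, but that construction is collapsed in this view. A minimal way to do it, under the assumption that each (M-block, N-block, K-block) tile is kept independently with probability 1 - sparsity and that the matrix dimensions divide evenly by the block sizes; make_block_mask is a hypothetical helper name.

    import torch

    def make_block_mask(M, N, K, block_M, block_N, block_K, sparsity, device="cuda"):
        # True means "compute this K-tile"; the expected fraction of True entries is (1 - sparsity)
        shape = (M // block_M, N // block_N, K // block_K)
        return torch.rand(shape, device=device) > sparsity

    # BlockMask = make_block_mask(M, N, K, block_M, block_N, block_K, args.sparsity)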
...@@ -89,28 +88,21 @@ def supply_program(params: List[KernelParam]): ...@@ -89,28 +88,21 @@ def supply_program(params: List[KernelParam]):
return input_tensors return input_tensors
@tilelang.autotune(configs=get_configs(),) @tilelang.autotune(
configs=get_configs(),
)
@tilelang.jit(out_idx=[-1]) @tilelang.jit(out_idx=[-1])
def blocksparse_matmul(M, def blocksparse_matmul(
N, M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"
K, ):
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float"):
block_mask_shape = (M // block_M, N // block_N, K // block_K) block_mask_shape = (M // block_M, N // block_N, K // block_K)
@T.prim_func @T.prim_func
def block_sparse_matmul( def block_sparse_matmul(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype), B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"), BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
...@@ -134,7 +126,6 @@ def blocksparse_matmul(M, ...@@ -134,7 +126,6 @@ def blocksparse_matmul(M,
def main(): def main():
# Initialize input matrices A and B on the GPU with half precision # Initialize input matrices A and B on the GPU with half precision
a = torch.randn(M, K).cuda().half() a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half() b = torch.randn(K, N).cuda().half()
...@@ -147,8 +138,7 @@ def main(): ...@@ -147,8 +138,7 @@ def main():
best_config = kernel.config best_config = kernel.config
best_latency = kernel.latency best_latency = kernel.latency
block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config[ block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config["block_K"]
"block_K"]
print(f"Best Config: {best_config}") print(f"Best Config: {best_config}")
print(f"Sparsity Ratio: {sparsity}") print(f"Sparsity Ratio: {sparsity}")
...@@ -163,7 +153,8 @@ def main(): ...@@ -163,7 +153,8 @@ def main():
block_K=DEFAULT_BLOCK_K, block_K=DEFAULT_BLOCK_K,
num_stages=DEFAULT_NUM_STAGES, num_stages=DEFAULT_NUM_STAGES,
thread_num=DEFAULT_THREAD_NUM, thread_num=DEFAULT_THREAD_NUM,
enable_rasteration=DEFAULT_ENABLE_RASTERIZATION) enable_rasteration=DEFAULT_ENABLE_RASTERIZATION,
)
block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})") print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")
# Create block mask with desired sparsity # Create block mask with desired sparsity
...@@ -16,11 +16,13 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m): ...@@ -16,11 +16,13 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
fp8_max = 448.0 fp8_max = 448.0
@T.prim_func @T.prim_func
def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor( def group_per_split_token_cast(
(BG,), "int32"), X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"), X_amax: T.Tensor( X: T.Tensor((M, N), dtype),
(BG, M_max, T.ceildiv(N, group_size)), accum_dtype)): batch_sizes: T.Tensor((BG,), "int32"),
with T.Kernel( X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"),
T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz): X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)), accum_dtype),
):
with T.Kernel(T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz):
row = bx row = bx
row_g_id = by row_g_id = by
bg = bz bg = bz
...@@ -31,36 +33,32 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m): ...@@ -31,36 +33,32 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3") y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
row_offset = T.alloc_fragment((1,), "int32") row_offset = T.alloc_fragment((1,), "int32")
T.annotate_layout({ T.annotate_layout(
y_local: {
T.Fragment( y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
y_local.shape, }
forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), )
})
row_offset[0] = 0 row_offset[0] = 0
for i in T.serial(bg): for i in T.serial(bg):
row_offset[0] += batch_sizes[i] row_offset[0] += batch_sizes[i]
T.copy( T.copy(
X[row_offset[0] + row * blk_m:row_offset[0] + (row + 1) * blk_m, X[row_offset[0] + row * blk_m : row_offset[0] + (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size],
row_g_id * group_size:(row_g_id + 1) * group_size], y_local) y_local,
)
T.reduce_absmax(y_local, y_amax_local, dim=1) T.reduce_absmax(y_local, y_amax_local, dim=1)
for i in T.Parallel(blk_m): for i in T.Parallel(blk_m):
y_amax_local[i] = T.max(y_amax_local[i], 1e-4) y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_amax_local[i] / fp8_max, 0)
y_amax_local[i] / fp8_max, 0)
for i, j in T.Parallel(blk_m, group_size): for i, j in T.Parallel(blk_m, group_size):
y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max) y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max)
T.copy(y_q_local, y_q_local_fp8) T.copy(y_q_local, y_q_local_fp8)
for i, j in T.Parallel(blk_m, group_size): for i, j in T.Parallel(blk_m, group_size):
y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_q_local[i, j], 0)
y_q_local[i, j], 0)
for i in T.Parallel(blk_m): for i in T.Parallel(blk_m):
X_amax[bg, row * blk_m + i, row_g_id] = y_s_local[i] X_amax[bg, row * blk_m + i, row_g_id] = y_s_local[i]
T.copy( T.copy(y_q_local_fp8, X_fp8[bg, row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size])
y_q_local_fp8, X_fp8[bg, row * blk_m:(row + 1) * blk_m,
row_g_id * group_size:(row_g_id + 1) * group_size])
return group_per_split_token_cast return group_per_split_token_cast
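The quantization performed per (row, 128-column group) above reduces to the usual amax-based FP8 E4M3 scaling. A plain-PyTorch sketch of the same math for a single group, for reference only; quantize_group_fp8 is an illustrative name and the padded-row zeroing is omitted.

    import torch

    def quantize_group_fp8(x_group: torch.Tensor):
        # x_group: [blk_m, 128] slice of X; the E4M3 finite range is +/- 448
        fp8_max = 448.0
        amax = x_group.abs().amax(dim=1).clamp_min(1e-4)
        scale = amax / fp8_max                               # what the kernel stores into X_amax
        q = (x_group / scale[:, None]).clamp(-fp8_max, fp8_max)
        return q.to(torch.float8_e4m3fn), scale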
...@@ -127,8 +125,7 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: ...@@ -127,8 +125,7 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
return x.squeeze(0) if remove_dim else x return x.squeeze(0) if remove_dim else x
# Normal layout requires transposing # Normal layout requires transposing
aligned_x = torch.transpose( aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
aligned_x[:, :m, :] = x aligned_x[:, :m, :] = x
aligned_x = aligned_x[:, :m, :] aligned_x = aligned_x[:, :m, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x return aligned_x.squeeze(0) if remove_dim else aligned_x
...@@ -146,15 +143,17 @@ def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens ...@@ -146,15 +143,17 @@ def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous() x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous()
return x_fp8, (x_amax / 448.0).view(m, -1) return x_fp8, (x_amax / 448.0).view(m, -1)
def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
Tuple[torch.Tensor, torch.Tensor]: def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# assert x.shape[0] == batch_sizes.sum() # assert x.shape[0] == batch_sizes.sum()
M_max = ceil_div(batch_sizes.max(), 128) * 128 M_max = ceil_div(batch_sizes.max(), 128) * 128
split_x = torch.split(x, batch_sizes.tolist(), dim=0) split_x = torch.split(x, batch_sizes.tolist(), dim=0)
padded_x = [torch.nn.functional.pad(t, (0, 0, 0, M_max - t.shape[0])) for t in split_x] padded_x = [torch.nn.functional.pad(t, (0, 0, 0, M_max - t.shape[0])) for t in split_x]
num_groups, m, n = batch_sizes.shape[0], M_max, x.shape[1] num_groups, m, n = batch_sizes.shape[0], M_max, x.shape[1]
x_fp8 = (torch.empty((num_groups, m, n), device='cuda', dtype=torch.float8_e4m3fn), x_fp8 = (
torch.empty((num_groups, m, n // 128), device='cuda', dtype=torch.float)) torch.empty((num_groups, m, n), device="cuda", dtype=torch.float8_e4m3fn),
torch.empty((num_groups, m, n // 128), device="cuda", dtype=torch.float),
)
for i in range(num_groups): for i in range(num_groups):
x_fp8[0][i], x_fp8[1][i] = ref_per_token_cast_to_fp8(padded_x[i]) x_fp8[0][i], x_fp8[1][i] = ref_per_token_cast_to_fp8(padded_x[i])
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
...@@ -13,8 +13,9 @@ def per_token_cast_to_fp8(M, N, blk_m): ...@@ -13,8 +13,9 @@ def per_token_cast_to_fp8(M, N, blk_m):
fp8_max = 448.0 fp8_max = 448.0
@T.prim_func @T.prim_func
def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"), def per_token_cast(
X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)): X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)
):
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by): with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by):
row = bx row = bx
row_g_id = by row_g_id = by
...@@ -24,16 +25,13 @@ def per_token_cast_to_fp8(M, N, blk_m): ...@@ -24,16 +25,13 @@ def per_token_cast_to_fp8(M, N, blk_m):
y_q_local = T.alloc_fragment((blk_m, group_size), dtype) y_q_local = T.alloc_fragment((blk_m, group_size), dtype)
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3") y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
T.annotate_layout({ T.annotate_layout(
y_local: {
T.Fragment( y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
y_local.shape, }
forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), )
})
T.copy( T.copy(X[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], y_local)
X[row * blk_m:(row + 1) * blk_m, row_g_id * group_size:(row_g_id + 1) * group_size],
y_local)
T.reduce_absmax(y_local, y_amax_local, dim=1) T.reduce_absmax(y_local, y_amax_local, dim=1)
for i in T.Parallel(blk_m): for i in T.Parallel(blk_m):
y_amax_local[i] = T.max(y_amax_local[i], 1e-4) y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
...@@ -43,9 +41,7 @@ def per_token_cast_to_fp8(M, N, blk_m): ...@@ -43,9 +41,7 @@ def per_token_cast_to_fp8(M, N, blk_m):
T.copy(y_q_local, y_q_local_fp8) T.copy(y_q_local, y_q_local_fp8)
for i in T.Parallel(blk_m): for i in T.Parallel(blk_m):
X_amax[row * blk_m + i, row_g_id] = y_s_local[i] X_amax[row * blk_m + i, row_g_id] = y_s_local[i]
T.copy( T.copy(y_q_local_fp8, X_fp8[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size])
y_q_local_fp8, X_fp8[row * blk_m:(row + 1) * blk_m,
row_g_id * group_size:(row_g_id + 1) * group_size])
return per_token_cast return per_token_cast
...@@ -105,8 +101,7 @@ def main(M=8192, N=8192, blk_m=8): ...@@ -105,8 +101,7 @@ def main(M=8192, N=8192, blk_m=8):
from example_triton_cast_to_fp8 import per_token_group_quant_fp8 from example_triton_cast_to_fp8 import per_token_group_quant_fp8
def run_triton(): def run_triton():
x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8( x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False)
x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False)
return x_fp8_triton_, x_amax_triton_ return x_fp8_triton_, x_amax_triton_
x_fp8_triton, x_amax_triton = run_triton() x_fp8_triton, x_amax_triton = run_triton()
...@@ -128,9 +128,7 @@ def per_token_group_quant_fp8( ...@@ -128,9 +128,7 @@ def per_token_group_quant_fp8(
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization. scaling factor for quantization.
""" """
assert (x.shape[-1] % assert x.shape[-1] % group_size == 0, f"the last dimension of `x` {x.shape[-1]} must be divisible by `group_size` {group_size}"
group_size == 0), (f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}")
assert x.stride(-1) == 1, "`x` groups must be contiguous" assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype) finfo = torch.finfo(dtype)
...@@ -4,8 +4,7 @@ import example_per_token_cast_to_fp8 ...@@ -4,8 +4,7 @@ import example_per_token_cast_to_fp8
def test_example_group_per_split_token_cast_to_fp8(): def test_example_group_per_split_token_cast_to_fp8():
example_group_per_split_token_cast_to_fp8.main( example_group_per_split_token_cast_to_fp8.main(M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])
M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])
def test_example_per_token_cast_to_fp8(): def test_example_per_token_cast_to_fp8():
...@@ -4,12 +4,11 @@ import tilelang.language as T ...@@ -4,12 +4,11 @@ import tilelang.language as T
# @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"]) # @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype), B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
# Initialize Kernel Context # Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
...@@ -36,8 +35,7 @@ block_K = 32 ...@@ -36,8 +35,7 @@ block_K = 32
func = matmul(M, N, K, block_M, block_N, block_K) func = matmul(M, N, K, block_M, block_N, block_K)
jit_kernel = tilelang.compile( jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr")
func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr")
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"]) # or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"]) # or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"])
...@@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): ...@@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
"warnings", "warnings",
"error", "error",
} }
if (sum( if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0:
len(terminalreporter.stats.get(k, []))
for k in known_types.difference({"skipped", "deselected"})) == 0):
terminalreporter.write_sep( terminalreporter.write_sep(
"!", "!",
(f"Error: No tests were collected. " (f"Error: No tests were collected. {dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
f"{dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
) )
pytest.exit("No tests were collected.", returncode=5) pytest.exit("No tests were collected.", returncode=5)
...@@ -14,7 +14,6 @@ def check_hopper(): ...@@ -14,7 +14,6 @@ def check_hopper():
def ref_program(stride, padding, dilation): def ref_program(stride, padding, dilation):
def main(A, B): def main(A, B):
A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W
B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W
...@@ -26,22 +25,7 @@ def ref_program(stride, padding, dilation): ...@@ -26,22 +25,7 @@ def ref_program(stride, padding, dilation):
@tilelang.jit(out_idx=[2]) @tilelang.jit(out_idx=[2])
def convolution(N, def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
threads,
dtype="float16",
accum_dtype="float"):
KH, KW = K, K KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
...@@ -51,13 +35,11 @@ def convolution(N, ...@@ -51,13 +35,11 @@ def convolution(N,
@T.prim_func @T.prim_func
def main( def main(
data: T.Tensor((N, H, W, C), dtype), data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype), kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype), out: T.Tensor((N, OH, OW, F), dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype) data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...@@ -66,11 +48,13 @@ def convolution(N, ...@@ -66,11 +48,13 @@ def convolution(N,
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout({ T.annotate_layout(
out_shared: tilelang.layout.make_swizzled_layout(out_shared), {
data_shared: tilelang.layout.make_swizzled_layout(data_shared), out_shared: tilelang.layout.make_swizzled_layout(out_shared),
kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), data_shared: tilelang.layout.make_swizzled_layout(data_shared),
}) kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
}
)
T.clear(out_local) T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
...@@ -82,10 +66,8 @@ def convolution(N, ...@@ -82,10 +66,8 @@ def convolution(N,
m = by * block_M + i m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P access_w = m % OW * S + k // C % KW * D - P
in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
(access_w < W)) data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
data_shared[i, j] = T.if_then_else(
in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local) T.gemm(data_shared, kernel_shared, out_local)
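The gather into data_shared above is an implicit-GEMM (im2col) view of the NHWC input: GEMM row m enumerates (n, oh, ow) and reduction index k enumerates (kh, kw, c). A scalar Python sketch of the same index arithmetic, handy for sanity-checking the bounds logic on the host; im2col_element is a hypothetical helper, not part of the example.

    def im2col_element(data, m, k, H, W, C, OH, OW, KW, S, D, P):
        # data: [N, H, W, C] array-like; returns the value the kernel loads for GEMM row m, reduction index k
        n = m // (OH * OW)
        access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
        access_w = m % OW * S + k // C % KW * D - P
        if 0 <= access_h < H and 0 <= access_w < W:
            return data[n, access_h, access_w, k % C]
        return 0  # zero padding outside the input window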
...@@ -97,15 +79,15 @@ def convolution(N, ...@@ -97,15 +79,15 @@ def convolution(N,
def main(argv=None): def main(argv=None):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--n', type=int, default=128, help='n') parser.add_argument("--n", type=int, default=128, help="n")
parser.add_argument('--c', type=int, default=128, help='c') parser.add_argument("--c", type=int, default=128, help="c")
parser.add_argument('--h', type=int, default=64, help='h') parser.add_argument("--h", type=int, default=64, help="h")
parser.add_argument('--w', type=int, default=64, help='w') parser.add_argument("--w", type=int, default=64, help="w")
parser.add_argument('--f', type=int, default=128, help='f') parser.add_argument("--f", type=int, default=128, help="f")
parser.add_argument('--k', type=int, default=3, help='k') parser.add_argument("--k", type=int, default=3, help="k")
parser.add_argument('--s', type=int, default=1, help='s') parser.add_argument("--s", type=int, default=1, help="s")
parser.add_argument('--d', type=int, default=1, help='d') parser.add_argument("--d", type=int, default=1, help="d")
parser.add_argument('--p', type=int, default=1, help='p') parser.add_argument("--p", type=int, default=1, help="p")
args = parser.parse_args(argv) args = parser.parse_args(argv)
N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p
...@@ -14,7 +14,6 @@ def check_hopper(): ...@@ -14,7 +14,6 @@ def check_hopper():
def ref_program(stride, padding, dilation): def ref_program(stride, padding, dilation):
def main(A, B): def main(A, B):
A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W
B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W
...@@ -40,7 +39,8 @@ def get_configs(): ...@@ -40,7 +39,8 @@ def get_configs():
num_stages, num_stages,
thread_num, thread_num,
enable_rasterization, enable_rasterization,
)) )
)
configs = [ configs = [
{ {
...@@ -50,7 +50,8 @@ def get_configs(): ...@@ -50,7 +50,8 @@ def get_configs():
"num_stages": c[3], "num_stages": c[3],
"thread_num": c[4], "thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat "enable_rasteration": c[5], # keep param name for backward-compat
} for c in _configs }
for c in _configs
] ]
return configs return configs
...@@ -64,53 +65,18 @@ def get_heuristic_config() -> dict: ...@@ -64,53 +65,18 @@ def get_heuristic_config() -> dict:
sm_version = sm_major * 10 + sm_minor sm_version = sm_major * 10 + sm_minor
print(f"CUDA device capability: {sm_version}") print(f"CUDA device capability: {sm_version}")
if sm_version in {80}: if sm_version in {80}:
return { return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True}
"block_M": 128,
"block_N": 256,
"block_K": 32,
"num_stages": 2,
"thread_num": 128,
"enable_rasteration": True
}
elif sm_version in {90}: elif sm_version in {90}:
return { return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True}
"block_M": 128,
"block_N": 256,
"block_K": 64,
"num_stages": 3,
"thread_num": 256,
"enable_rasteration": True
}
else: else:
return { return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True}
"block_M": 128,
"block_N": 256,
"block_K": 32,
"num_stages": 0,
"thread_num": 128,
"enable_rasteration": True
}
@tilelang.autotune(configs=get_configs()) @tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[2]) @tilelang.jit(out_idx=[2])
def convolution(N, def convolution(
C, N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"
H, ):
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float"):
KH, KW = K, K KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
...@@ -120,13 +86,11 @@ def convolution(N, ...@@ -120,13 +86,11 @@ def convolution(N,
@T.prim_func @T.prim_func
def main( def main(
data: T.Tensor((N, H, W, C), dtype), data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype), kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype), out: T.Tensor((N, OH, OW, F), dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=thread_num) as (bx, by):
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=thread_num) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype) data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...@@ -136,9 +100,11 @@ def convolution(N, ...@@ -136,9 +100,11 @@ def convolution(N,
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
if is_hopper: if is_hopper:
T.annotate_layout({ T.annotate_layout(
out_shared: tilelang.layout.make_swizzled_layout(out_shared), {
}) out_shared: tilelang.layout.make_swizzled_layout(out_shared),
}
)
T.clear(out_local) T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
...@@ -150,10 +116,8 @@ def convolution(N, ...@@ -150,10 +116,8 @@ def convolution(N,
m = by * block_M + i m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P access_w = m % OW * S + k // C % KW * D - P
in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
(access_w < W)) data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
data_shared[i, j] = T.if_then_else(
in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local) T.gemm(data_shared, kernel_shared, out_local)
...@@ -166,17 +130,19 @@ def convolution(N, ...@@ -166,17 +130,19 @@ def convolution(N,
return main return main
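ref_program(S, P, D), used in main() below, is not shown in this hunk; given the NHWC data and (KH, KW, C, F) kernel layouts above, a plausible reference would wrap torch.nn.functional.conv2d with the matching permutes. A hedged sketch, not necessarily the repository's actual ref_program:

import torch.nn.functional as F

def ref_program(S, P, D):
    def ref(data, kernel):
        x = data.permute(0, 3, 1, 2)        # (N, H, W, C) -> (N, C, H, W)
        w = kernel.permute(3, 2, 0, 1)      # (KH, KW, C, F) -> (F, C, KH, KW)
        y = F.conv2d(x, w, stride=S, padding=P, dilation=D)
        return y.permute(0, 2, 3, 1)        # back to (N, OH, OW, F)
    return ref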
def main(n: int = 128, def main(
c: int = 128, n: int = 128,
h: int = 64, c: int = 128,
w: int = 64, h: int = 64,
f: int = 128, w: int = 64,
k: int = 3, f: int = 128,
s: int = 1, k: int = 3,
d: int = 1, s: int = 1,
p: int = 1, d: int = 1,
use_autotune: bool = False, p: int = 1,
with_roller: bool = True): use_autotune: bool = False,
with_roller: bool = True,
):
N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p
ref_prog = ref_program(S, P, D) ref_prog = ref_program(S, P, D)
...@@ -196,25 +162,16 @@ def main(n: int = 128, ...@@ -196,25 +162,16 @@ def main(n: int = 128,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument('--n', type=int, default=128, help='n') parser.add_argument("--n", type=int, default=128, help="n")
parser.add_argument('--c', type=int, default=128, help='c') parser.add_argument("--c", type=int, default=128, help="c")
parser.add_argument('--h', type=int, default=64, help='h') parser.add_argument("--h", type=int, default=64, help="h")
parser.add_argument('--w', type=int, default=64, help='w') parser.add_argument("--w", type=int, default=64, help="w")
parser.add_argument('--f', type=int, default=128, help='f') parser.add_argument("--f", type=int, default=128, help="f")
parser.add_argument('--k', type=int, default=3, help='k') parser.add_argument("--k", type=int, default=3, help="k")
parser.add_argument('--s', type=int, default=1, help='s') parser.add_argument("--s", type=int, default=1, help="s")
parser.add_argument('--d', type=int, default=1, help='d') parser.add_argument("--d", type=int, default=1, help="d")
parser.add_argument('--p', type=int, default=1, help='p') parser.add_argument("--p", type=int, default=1, help="p")
parser.add_argument( parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs")
"--use_autotune", parser.add_argument("--with_roller", action="store_true", default=True, help="Whether to enable BitBLAS roller for search space")
action="store_true",
default=False,
help="Whether to use autotune for matmul configs")
parser.add_argument(
"--with_roller",
action="store_true",
default=True,
help="Whether to enable BitBLAS roller for search space")
args = parser.parse_args() args = parser.parse_args()
main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, args.with_roller)
args.with_roller)
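For reference, the benchmark can be driven either through the flags above or programmatically (the script filename in the comment is an assumption):

# python example_convolution.py --n 128 --f 256 --k 3 --s 2 --use_autotune
# or, equivalently, from Python:
main(n=128, c=128, h=64, w=64, f=256, k=3, s=2, d=1, p=1, use_autotune=True, with_roller=True)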
...@@ -41,14 +41,13 @@ def tl_gemm( ...@@ -41,14 +41,13 @@ def tl_gemm(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
scales_a: T.Tensor(Scales_A_shape, "float32"), scales_a: T.Tensor(Scales_A_shape, "float32"),
scales_b: T.Tensor(Scales_B_shape, "float32"), scales_b: T.Tensor(Scales_B_shape, "float32"),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_shared = T.alloc_shared(C_shared_shape, out_dtype) C_shared = T.alloc_shared(C_shared_shape, out_dtype)
...@@ -93,21 +92,18 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: ...@@ -93,21 +92,18 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
m, n = x.shape m, n = x.shape
x_view = x.view(m, -1, 128) x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view( return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
m, n), (x_amax / 448.0).view(m, -1)
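per_token_cast_to_fp8 quantizes each row in groups of 128 and returns per-group scales (amax / 448), so multiplying back recovers the input up to fp8 rounding. A quick round-trip check, assuming a PyTorch build with float8_e4m3fn support:

import torch

x = torch.randn(4, 512, dtype=torch.float32)
x_fp8, scales = per_token_cast_to_fp8(x)   # x_fp8: (4, 512) fp8, scales: (4, 4) fp32
x_deq = (x_fp8.float().view(4, -1, 128) * scales.unsqueeze(2)).view(4, 512)
assert torch.allclose(x, x_deq, rtol=0.1, atol=0.05)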
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 assert x.dim() == 2
m, n = x.shape m, n = x.shape
x_padded = torch.zeros( x_padded = torch.zeros(ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device)
ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device)
x_padded[:m, :n] = x x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view( return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
x_view.size(0), x_view.size(2))
def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype): def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
...@@ -127,13 +123,14 @@ def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype): ...@@ -127,13 +123,14 @@ def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
c_acc.zero_() c_acc.zero_()
for k in range(ceildiv(K, 128)): for k in range(ceildiv(K, 128)):
c = torch._scaled_mm( c = torch._scaled_mm(
A_fp8[i * 128:(i + 1) * 128, k * 128:(k + 1) * 128], A_fp8[i * 128 : (i + 1) * 128, k * 128 : (k + 1) * 128],
B_fp8[j * 128:(j + 1) * 128, k * 128:(k + 1) * 128].T, B_fp8[j * 128 : (j + 1) * 128, k * 128 : (k + 1) * 128].T,
scale_a=A_scales[i, k].view(128, 1).contiguous(), scale_a=A_scales[i, k].view(128, 1).contiguous(),
scale_b=B_scales[j, k].view(1, 128).contiguous(), scale_b=B_scales[j, k].view(1, 128).contiguous(),
out_dtype=torch.bfloat16) out_dtype=torch.bfloat16,
)
c_acc += c.to(torch.float32) c_acc += c.to(torch.float32)
C[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = c_acc.to(out_dtype) C[i * 128 : (i + 1) * 128, j * 128 : (j + 1) * 128] = c_acc.to(out_dtype)
return C return C