Unverified Commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
......@@ -3,8 +3,7 @@ from typing import Dict, List, Tuple
TokensText = Tuple[List[int], str]
def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText],
name_0: str, name_1: str):
def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText], name_0: str, name_1: str):
"""
Compare the two sequences generated by different models,
which should be equal.
......@@ -15,19 +14,14 @@ def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[Tok
output_ids_0, output_str_0 = outputs_0
output_ids_1, output_str_1 = outputs_1
assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
assert output_str_0 == output_str_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
assert output_ids_0 == output_ids_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
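For reference, a minimal usage sketch of this helper (the data and backend names below are illustrative, not taken from the commit): each element is a (token_ids, text) pair produced by two backends for the same prompt.
outputs_hf = [([1, 2, 3], "a b c")]  # hypothetical greedy outputs from a backend named "hf"
outputs_tl = [([1, 2, 3], "a b c")]  # hypothetical greedy outputs from a backend named "tilelang"
check_outputs_equal(outputs_hf, outputs_tl, name_0="hf", name_1="tilelang")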
TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]
def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str):
def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str):
"""
Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.
......@@ -41,16 +35,11 @@ def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
# Loop through generated tokens.
for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
# If generated tokens don't match, then
if output_id_0 != output_id_1:
# Each predicted token must be in top N logprobs of the other
assert output_id_0 in logprobs_1[idx], (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
assert output_id_1 in logprobs_0[idx], (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
assert output_id_0 in logprobs_1[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
assert output_id_1 in logprobs_0[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
# Break out since sequences will now diverge.
break
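A hedged sketch of the tolerated-divergence rule above, with made-up logprob values: when the generated token ids first differ, each model's token must still appear in the other model's top-N logprobs for that step, after which the loop breaks because the sequences are allowed to diverge.
output_ids_0, output_ids_1 = [5], [7]    # first generated token differs
logprobs_0 = [{5: -0.11, 7: -2.30}]      # model 0's top-N logprobs at step 0 (illustrative)
logprobs_1 = [{7: -0.09, 5: -1.85}]      # model 1's top-N logprobs at step 0 (illustrative)
assert output_ids_0[0] in logprobs_1[0]  # 5 is in model 1's top-N
assert output_ids_1[0] in logprobs_0[0]  # 7 is in model 0's top-N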
......@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
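A standalone sketch of the top-k-to-block-mask step above, with toy shapes (bsz = num_head = 1 and four block rows, chosen only for illustration):
import torch

x = torch.randn(1, 1, 4, 4)                # per-block importance scores
idx = torch.topk(x, k=2, dim=-1).indices   # keep the 2 highest-scoring key blocks per query block
mask = torch.zeros(1, 1, 4, 4, dtype=torch.bool)
mask.scatter_(-1, idx, True)               # True marks (query block, key block) pairs the kernel may visit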
......@@ -56,7 +53,6 @@ def _fwd_kernel_inner(
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
# print
......@@ -73,8 +69,7 @@ def _fwd_kernel_inner(
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
if LAST_K_BLOCK:
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0,
float('-inf'))
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
......@@ -154,7 +149,7 @@ def _fwd_kernel(
v_ptrs = V + off_v
mask_ptrs = block_mask_ptr + start_m * stride_bmm
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
......@@ -192,24 +187,12 @@ def _fwd_kernel(
acc = acc * l_recip
acc = acc.to(Out.dtype.element_ty)
off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[
None, :] * stride_od
off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)
def _forward(ctx,
q,
k,
v,
block_sparse_mask,
sm_scale,
BLOCK_M=64,
BLOCK_N=64,
num_warps=None,
num_stages=1,
out=None):
def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
assert k.shape[2] == v.shape[2]
o = out if out is not None else torch.empty_like(q).contiguous()
......@@ -254,7 +237,6 @@ def _forward(ctx,
class _sparse_attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
# shape constraints
......@@ -278,9 +260,9 @@ def test_topk_sparse_attention():
torch.manual_seed(0)
# Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level)
......@@ -288,9 +270,7 @@ def test_topk_sparse_attention():
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
print("downsample_len", downsample_len)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
device='cuda',
dtype=torch.bfloat16)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
print("x_ds.shape", x_ds.shape)
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
......@@ -302,22 +282,21 @@ def test_topk_sparse_attention():
# Compute reference
# Expand block mask to full attention matrix
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal
# PyTorch reference implementation
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float('-inf'))
attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float("-inf"))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
# print("ref_output", ref_output)
# print("triton_output", triton_output)
# Verify accuracy
assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
"Triton output doesn't match reference"
assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference"
print("Pass topk sparse attention test with qlen == klen")
......@@ -329,9 +308,9 @@ def test_topk_sparse_attention_qlt_kl():
torch.manual_seed(0)
# Create inputs.
q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
# softmax scale
sm_scale = 1.0 / (D_HEAD**0.5)
......@@ -339,8 +318,7 @@ def test_topk_sparse_attention_qlt_kl():
print("downsample_factor", downsample_factor)
downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension
print("downsample_len", downsample_len)
x_ds = torch.randn(
BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16)
x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16)
# Force the first column to be high so that the first block is always selected.
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
......@@ -351,26 +329,25 @@ def test_topk_sparse_attention_qlt_kl():
past_len = K_LEN - Q_LEN
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool()
full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool()
full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]
effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN)
i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1)
j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN)
causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN)
causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN)
final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN)
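# Illustrative check of the indexing above (values assumed, not from the test config):
# with K_LEN = 256 and Q_LEN = 128, past_len = 128, so query row i corresponds to
# global position 128 + i and causal_mask allows only keys j <= 128 + i.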
attn = attn.masked_fill(~final_mask, float('-inf'))
attn = attn.masked_fill(~final_mask, float("-inf"))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
# Verify accuracy.
assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
"Triton output doesn't match reference when qlen < klen"
assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen"
print("Pass topk sparse attention test with qlen < klen")
......
......@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
......@@ -30,15 +27,17 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
@tilelang.jit(
out_idx=[4], pass_configs={
out_idx=[4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
},
)
def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
block_M = 64
block_N = 64
num_stages = 1
threads = 128
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
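    # Note on the log2(e) factor (assumed here to pair with the T.exp2 calls in the
    # Softmax macro below): exp(x) == exp2(x * log2(e)), so pre-multiplying the softmax
    # scale by 1.44269504 lets the kernel use the cheaper exp2 instruction.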
shape = [batch, heads, seq_len, dim]
block_mask_shape = [batch, heads, downsample_len, downsample_len]
......@@ -47,7 +46,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_mask_dtype = "bool"
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(shape, dtype),
......@@ -59,11 +57,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......@@ -78,18 +75,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -113,22 +110,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def blocksparse_flashattn(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
Output: T.Tensor(shape, dtype),
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
......@@ -143,7 +139,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
logsum = T.alloc_fragment([block_M], accum_dtype)
block_mask = T.alloc_local([downsample_len], block_mask_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -152,20 +148,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_mask[vj] = BlockSparseMask[bz, by, bx, vj]
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
)
for k in T.Pipelined(loop_range, num_stages=num_stages):
if block_mask[k] != 0:
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return blocksparse_flashattn
......@@ -180,18 +175,16 @@ def test_topk_sparse_attention():
torch.manual_seed(0)
# Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level)
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
device='cuda',
dtype=torch.bfloat16)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
......@@ -202,15 +195,15 @@ def test_topk_sparse_attention():
# Compute reference
# Expand block mask to full attention matrix
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal
# PyTorch reference implementation
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float('-inf'))
attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float("-inf"))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
print("ref_output", ref_output)
print("tilelang_output", tilelang_output)
......
......@@ -13,17 +13,20 @@ from heuristic import num_splits_heuristic
def flashattn(batch, heads, heads_kv, dim, dim_v):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // heads_kv
@tilelang.jit(
out_idx=[-1], pass_configs={
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def kernel_func(block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages,
max_num_blocks_per_seq, max_selected_blocks):
},
)
def kernel_func(
block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages, max_num_blocks_per_seq, max_selected_blocks
):
shape_q = [batch, heads, dim]
shape_k = [num_pages, page_block_size, heads_kv, dim]
shape_v = [num_pages, page_block_size, heads_kv, dim_v]
......@@ -37,17 +40,16 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@T.macro
def flash_attn_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
block_table: T.Tensor(shape_block_table, "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
block_table: T.Tensor(shape_block_table, "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
):
with T.Kernel(
batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
......@@ -67,7 +69,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -75,7 +77,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
num_blocks = max_selected_blocks
blocks_per_split = T.floordiv(num_blocks, num_split)
remaining_blocks = T.floormod(num_blocks, num_split)
loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0))
loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)
start = blocks_per_split * sid + T.min(sid, remaining_blocks)
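                # Worked example of this split partitioning (numbers are illustrative): with
                # num_blocks = 10 and num_split = 4, blocks_per_split = 2 and remaining_blocks = 2,
                # so splits 0..3 process 3, 3, 2, 2 blocks starting at offsets 0, 3, 6, 8.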
has_valid_block = False
for k in T.Pipelined(loop_range, num_stages=num_stages):
......@@ -85,29 +87,20 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
block_table_idx = T.floordiv(logical_block_idx, block_ratio)
block_tile_idx = T.floormod(logical_block_idx, block_ratio)
physical_block_idx = block_table[bid, block_table_idx]
T.copy(
K[physical_block_idx,
block_tile_idx * block_N:(block_tile_idx + 1) * block_N,
cur_kv_head, :], K_shared)
T.copy(K[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared)
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(
logical_block_idx * block_N + j >= cache_seqlens[bid],
-T.infinity(accum_dtype), acc_s[i, j])
logical_block_idx * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j]
)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
scores_max[i] * scale)
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
......@@ -116,10 +109,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim_v):
acc_o[i, j] *= scores_scale[i]
T.copy(
V[physical_block_idx,
block_tile_idx * block_N:(block_tile_idx + 1) * block_N,
cur_kv_head, :], V_shared)
T.copy(V[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_valid_block:
for i, j in T.Parallel(block_H, dim_v):
......@@ -138,9 +128,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim_v], accum_dtype)
......@@ -151,17 +141,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
scale_local = T.alloc_local([1], accum_dtype)
max_split = T.alloc_local([1], "int32")
T.annotate_layout({
lse_logsum_local:
T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
})
T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_local_split[0] = glse[bz, by, k]
if (lse_local_split[0] != 0):
if lse_local_split[0] != 0:
max_split[0] = k
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
......@@ -183,18 +174,17 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@T.prim_func
def main(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
block_table: T.Tensor(shape_block_table, "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
block_table: T.Tensor(shape_block_table, "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
flash_attn_split(Q, K, V, block_indices, cache_seqlens, block_table, glse,
Output_partial)
flash_attn_split(Q, K, V, block_indices, cache_seqlens, block_table, glse, Output_partial)
combine(glse, Output_partial, Output)
return main
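The split/combine structure above can be summarized with a small torch sketch (hedged: names and shapes are illustrative, and the real kernel works in log2 space with invalid-split handling): each split produces a partial output normalized by its own softmax sum plus its log-sum-exp, and the combine step weights the partials by a softmax over the per-split log-sum-exps.
import torch

lse = torch.randn(4)                  # per-split log-sum-exp for one (batch, head)
partial = torch.randn(4, 128)         # per-split partial outputs, dim_v = 128 assumed
weights = torch.softmax(lse, dim=0)   # exp(lse_s - logsumexp(lse)) for each split s
combined = (weights[:, None] * partial).sum(dim=0)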
......@@ -203,7 +193,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
class SparseFlashAttn(torch.nn.Module):
def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_pages):
super(SparseFlashAttn, self).__init__()
self.batch = batch
......@@ -249,18 +238,11 @@ class SparseFlashAttn(torch.nn.Module):
num_sm = self.num_sm
num_split = num_splits_heuristic(
total_mblocks,
num_sm,
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
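        # glse holds one log-sum-exp per (batch, head, split) and output_partial holds the
        # matching per-split partial outputs; the combine macro called inside self.kernel
        # reduces both into the final attention output.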
output = self.kernel(
query,
......@@ -275,14 +257,13 @@ class SparseFlashAttn(torch.nn.Module):
return output
def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens,
block_table, page_block_size, block_size):
def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens, block_table, page_block_size, block_size):
"""
Paged version of sparse attention reference implementation.
Args:
query: [batch, heads, dim]
key_cache: [num_pages, page_block_size, heads_kv, dim]
key_cache: [num_pages, page_block_size, heads_kv, dim]
value_cache: [num_pages, page_block_size, heads_kv, dim_v]
block_indices: [batch, heads_kv, max_selected_blocks] - logical block indices
cache_seqlens: [batch] - actual sequence lengths
......@@ -298,12 +279,8 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
# Reconstruct the full key and value tensors from paged cache
max_cache_seqlen = max(cache_seqlens).item()
key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim),
dtype=key_cache.dtype,
device=key_cache.device)
value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v),
dtype=value_cache.dtype,
device=value_cache.device)
key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim), dtype=key_cache.dtype, device=key_cache.device)
value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v), dtype=value_cache.dtype, device=value_cache.device)
# Reconstruct full tensors from paged cache using block_table
for b in range(batch):
......@@ -319,20 +296,14 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
actual_block_size = end_token - start_token
# Copy from paged cache to full tensors
key_full[b, :, start_token:end_token, :] = key_cache[
physical_block_idx, :actual_block_size, :, :].transpose(0, 1)
value_full[b, :, start_token:end_token, :] = value_cache[
physical_block_idx, :actual_block_size, :, :].transpose(0, 1)
key_full[b, :, start_token:end_token, :] = key_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1)
value_full[b, :, start_token:end_token, :] = value_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1)
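            # Illustrative mapping (page_block_size = 256, the script's default, is assumed):
            # block_table[b, i] gives the physical page holding logical block i of sequence b,
            # so logical block 2 of sequence b covers tokens 512..767 and is read from page
            # block_table[b, 2] of key_cache / value_cache.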
# Reshape query for grouped attention
query = rearrange(
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
# Compute attention scores
scores = einsum(
query, key_full,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores = einsum(query, key_full, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv]
# Create sparse mask based on block_indices
sparse_mask = torch.zeros_like(scores)
......@@ -348,24 +319,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
sparse_mask[b, :, h, start_pos:end_pos] = 1
# Apply sparse mask
scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
# Apply causal mask based on actual sequence lengths
range_len = torch.arange(scores.shape[-1], device=scores.device).unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float('-inf'))
scores = scores.masked_fill(pad_mask, float("-inf"))
# Compute attention weights
attention = F.softmax(scores / scale, dim=-1)
# Apply attention to values
out = einsum(attention, value_full,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim]
out = einsum(attention, value_full, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim]
# Reshape output back to original format
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
......@@ -373,17 +343,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
def ref_program_fa(query, kcache, vcache, cache_seqlens, block_table):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache #fa2
from flash_attn import flash_attn_with_kvcache # fa2
query = query.unsqueeze(1)
output = flash_attn_with_kvcache(
query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table)
output = flash_attn_with_kvcache(query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table)
output = output.squeeze(1)
return output
def main(args):
batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v
batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = (
args.batch,
args.heads,
args.heads_kv,
args.max_cache_seqlen,
args.dim,
args.dim_v,
)
sparse_ratio = args.sparse_ratio
block_N = args.block_N
page_block_size = args.page_block_size
......@@ -395,35 +371,30 @@ def main(args):
dtype = torch.float16
# Generate random inputs
Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
cache_seqlens = torch.randint(
max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device='cuda')
Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
cache_seqlens = torch.randint(max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda")
print("cache_seqlens: ", cache_seqlens)
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
# Create paged KV cache
K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device='cuda')
V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v),
dtype=dtype,
device='cuda')
K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device="cuda")
V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda")
# Create block table and block indices for dense case (all blocks selected)
max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size))
print("max_num_blocks_per_seq: ", max_num_blocks_per_seq)
block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device='cuda')
block_indices = torch.zeros((batch, heads_kv, max_selected_blocks),
dtype=torch.int32,
device='cuda')
block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda")
block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), dtype=torch.int32, device="cuda")
# Fill block table and block indices and cache
# Create a pool of available physical blocks
total_blocks_needed = sum(
int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch))
total_blocks_needed = sum(int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch))
available_blocks = list(range(total_blocks_needed))
import random
random.seed(42) # For reproducibility
random.shuffle(available_blocks)
......@@ -458,10 +429,8 @@ def main(args):
actual_block_size = end_token - start_token
# Copy K and V data to the paged cache
K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx,
start_token:end_token, :, :]
V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx,
start_token:end_token, :, :]
K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, start_token:end_token, :, :]
V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, start_token:end_token, :, :]
# Fill block_indices for sparse attention
# For dense case (verification), we select all blocks in reverse order
......@@ -496,10 +465,9 @@ def main(args):
remaining_blocks = [b for b in all_blocks if b not in selected_blocks]
if remaining_blocks:
import random
random.seed(42) # For reproducibility
additional_blocks = random.sample(
remaining_blocks,
min(num_selected - recent_blocks, len(remaining_blocks)))
additional_blocks = random.sample(remaining_blocks, min(num_selected - recent_blocks, len(remaining_blocks)))
selected_blocks.extend(additional_blocks)
# Sort selected blocks in reverse order (most recent first)
......@@ -512,25 +480,20 @@ def main(args):
block_indices[seq_idx, head_idx, i] = -1
# Initialize sparse attention module
sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N,
num_blocks)
output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens,
block_table)
sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_blocks)
output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table)
import flash_attn # noqa: F401
output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens,
block_table, page_block_size, block_N)
output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table, page_block_size, block_N)
output_ref_fa = ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table)
# Check correctness
if sparse_ratio == 0.0:
max_diff = torch.max(torch.abs(output_sparse - output_ref_fa)).item()
mean_diff = torch.mean(torch.abs(output_sparse - output_ref_fa)).item()
assert torch.allclose(
output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!"
assert torch.allclose(output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!"
else:
max_diff = torch.max(torch.abs(output_sparse - output_ref_torch)).item()
mean_diff = torch.mean(torch.abs(output_sparse - output_ref_torch)).item()
......@@ -574,16 +537,15 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
parser.add_argument(
'--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
parser.add_argument('--sparse_ratio', type=float, default=0.0, help='sparse ratio')
parser.add_argument('--block_N', type=int, default=64, help='block_N')
parser.add_argument('--page_block_size', type=int, default=256, help='block size of pages')
parser.add_argument('--num_pages', type=int, default=1024, help='total number of pages')
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
parser.add_argument("--sparse_ratio", type=float, default=0.0, help="sparse ratio")
parser.add_argument("--block_N", type=int, default=64, help="block_N")
parser.add_argument("--page_block_size", type=int, default=256, help="block size of pages")
parser.add_argument("--num_pages", type=int, default=1024, help="total number of pages")
args = parser.parse_args()
main(args)
......@@ -10,17 +10,18 @@ from heuristic import num_splits_heuristic
def flashattn(batch, heads, heads_kv, dim, dim_v):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // heads_kv
@tilelang.jit(
out_idx=[-1], pass_configs={
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen,
max_selected_blocks):
},
)
def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, max_selected_blocks):
shape_q = [batch, heads, dim]
shape_k = [batch, max_cache_seqlen, heads_kv, dim]
shape_v = [batch, max_cache_seqlen, heads_kv, dim_v]
......@@ -31,17 +32,16 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@T.macro
def flash_attn_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
# actual_num_blocks: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
# actual_num_blocks: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
):
with T.Kernel(
batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
......@@ -62,7 +62,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -70,7 +70,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
num_blocks = max_selected_blocks
blocks_per_split = T.floordiv(num_blocks, num_split)
remaining_blocks = T.floormod(num_blocks, num_split)
loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0))
loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)
start = blocks_per_split * sid + T.min(sid, remaining_blocks)
has_valid_block = False
......@@ -78,26 +78,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
i_s = block_indices[bid, cur_kv_head, start + k]
if i_s >= 0:
has_valid_block = True
T.copy(K[bid, i_s * block_N:(i_s + 1) * block_N, cur_kv_head, :], K_shared)
T.copy(K[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], K_shared)
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition
for i, j in T.Parallel(block_H, block_N):
acc_s[i,
j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid],
-T.infinity(accum_dtype), acc_s[i, j])
acc_s[i, j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j])
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
scores_max[i] * scale)
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
......@@ -106,7 +98,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim_v):
acc_o[i, j] *= scores_scale[i]
T.copy(V[bid, i_s * block_N:(i_s + 1) * block_N, cur_kv_head, :], V_shared)
T.copy(V[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_valid_block:
for i, j in T.Parallel(block_H, dim_v):
......@@ -125,9 +117,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim_v], accum_dtype)
......@@ -138,17 +130,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
scale_local = T.alloc_local([1], accum_dtype)
max_split = T.alloc_local([1], "int32")
T.annotate_layout({
lse_logsum_local:
T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
})
T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_local_split[0] = glse[bz, by, k]
if (lse_local_split[0] != 0):
if lse_local_split[0] != 0:
max_split[0] = k
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
......@@ -170,15 +163,15 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@T.prim_func
def main(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
# actual_num_blocks: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
# actual_num_blocks: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
# flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial)
flash_attn_split(Q, K, V, block_indices, cache_seqlens, glse, Output_partial)
......@@ -190,7 +183,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
class SparseFlashAttn(torch.nn.Module):
def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
super(SparseFlashAttn, self).__init__()
self.batch = batch
......@@ -209,7 +201,8 @@ class SparseFlashAttn(torch.nn.Module):
num_stages=2,
threads=128,
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
max_selected_blocks=T.dynamic("max_selected_blocks"))
max_selected_blocks=T.dynamic("max_selected_blocks"),
)
props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count
......@@ -232,25 +225,17 @@ class SparseFlashAttn(torch.nn.Module):
num_sm = self.num_sm
num_split = num_splits_heuristic(
total_mblocks,
num_sm,
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
output = self.kernel(query, key, value, block_indices, cache_seqlens, glse, output_partial)
return output
def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seqlens,
max_cache_seqlen, block_size):
def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, block_size):
"""
Args:
query: [batch, heads, dim]
......@@ -272,31 +257,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
block_H = 64
actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)
actual_num_blocks = actual_num_blocks[:,
0] #[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
actual_num_blocks = actual_num_blocks[
:, 0
] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
# get num_split
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks #(kv_seqlen + block_size - 1 ) // block_size
num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size
# num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks
size_one_kv_head = max_selected_blocks * block_size * (
dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2
size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks
num_sm = 132
num_split = num_splits_heuristic(
total_mblocks,
num_sm,
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
Output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=block_H,
......@@ -304,29 +282,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
num_stages=2,
threads=128,
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
max_selected_blocks=T.dynamic("max_selected_blocks"))
max_selected_blocks=T.dynamic("max_selected_blocks"),
)
output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial)
return output
def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks,
block_size):
def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
batch, heads, dim = query.shape
heads_kv = key.shape[2]
num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim]
key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
query = rearrange(
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores)
# Assign mask values based on block_indices
......@@ -335,28 +308,26 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache
valid_indices = block_indices[b, h] # Extract indices for this batch and head
for idx in valid_indices:
if idx >= 0:
sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0)
range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float('-inf'))
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores = scores.masked_fill(pad_mask, float("-inf"))
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks,
block_size):
def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache #fa2
from flash_attn import flash_attn_with_kvcache # fa2
query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1)
......@@ -368,23 +339,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
print(name + " all_close={}".format(all_close))
if not all_close:
diff = (expect - actual).abs()
print("all_close={}, max={}, min={}, mean={}".format(all_close,
diff.max().item(),
diff.min().item(),
diff.mean().item()))
print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item()))
max_indices = torch.nonzero(diff == diff.max().item())
first_index = tuple(max_indices[0].tolist())
print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")
def main(batch=8,
heads=32,
heads_kv=8,
max_cache_seqlen=8192,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32):
def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
sparse_ratio = sparse_ratio
block_size = block_size
......@@ -392,10 +353,10 @@ def main(batch=8,
print("max_selected_blocks: ", max_selected_blocks)
dtype = torch.float16
Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda')
Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
# # Ensure at least one element equals cache_seqlen
# random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index
......@@ -406,10 +367,7 @@ def main(batch=8,
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_indices with -1 (for padding blocks)
block_indices = torch.full((batch, heads_kv, max_selected_blocks),
-1,
dtype=torch.int32,
device='cuda')
block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda")
# max_num_blocks = int((max_cache_seqlen + block_size - 1)/ block_size)
# block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda')
......@@ -418,10 +376,9 @@ def main(batch=8,
max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch
if max_valid_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv):
valid_indices = torch.randperm(
max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks]
valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks]
# valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks]
block_indices[b, h, :len(valid_indices)] = valid_indices
block_indices[b, h, : len(valid_indices)] = valid_indices
# Sort indices within each batch-group for consistency
block_indices, _ = block_indices.sort(dim=-1, descending=True)
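    # Example of the resulting layout for one (batch, head), values illustrative:
    # block_indices[b, h] == [7, 4, 2, -1, -1] means blocks 7, 4 and 2 were selected,
    # sorted in descending order with -1 padding at the tail, matching the kernel's
    # assumption that block_indices is sorted in reverse order.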
......@@ -434,8 +391,7 @@ def main(batch=8,
print("max_num_blocks: ", max_num_blocks)
# parity reference
ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks,
block_size)
ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size)
out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
......@@ -445,13 +401,11 @@ def main(batch=8,
## latency reference
for _ in range(10):
ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen,
max_num_blocks, block_size)
ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen,
max_num_blocks, block_size)
ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
torch.cuda.synchronize()
print("dense time: ", (time.time() - start) / 100 * 1000)
......@@ -469,15 +423,13 @@ def main(batch=8,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
parser.add_argument(
'--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
parser.add_argument('--block_size', type=int, default=32, help='block_size')
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
parser.add_argument("--block_size", type=int, default=32, help="block_size")
args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
args.sparse_ratio, args.block_size)
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
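Note: the tilelang example above selects KV blocks through an int32 block_indices tensor padded with -1, while the mask-based variants below use a boolean block_mask over all KV blocks. The following is a minimal sketch, with illustrative helper names and plain CPU PyTorch, of how the two representations map onto each other; it is not part of the example itself.

import torch

def indices_to_mask(block_indices: torch.Tensor, num_blocks: int) -> torch.Tensor:
    # block_indices: [batch, heads_kv, max_selected_blocks], padded with -1
    blocks = torch.arange(num_blocks, device=block_indices.device)
    # a -1 padding entry never equals a real block id, so it simply drops out
    return (block_indices.unsqueeze(-1) == blocks).any(dim=-2)

def mask_to_indices(block_mask: torch.Tensor, max_selected_blocks: int) -> torch.Tensor:
    # block_mask: [batch, heads_kv, num_blocks] of bool
    batch, heads_kv, _ = block_mask.shape
    out = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32)
    for b in range(batch):
        for h in range(heads_kv):
            idx = torch.nonzero(block_mask[b, h]).flatten()[:max_selected_blocks]
            out[b, h, : len(idx)] = idx.to(torch.int32)
    return out

torch.manual_seed(0)
mask = torch.rand(2, 4, 16) < 0.3  # batch=2, heads_kv=4, 16 KV blocks
idx = mask_to_indices(mask, max_selected_blocks=16)
assert torch.equal(indices_to_mask(idx, num_blocks=16), mask)

The -1 entries play the same padding role as in the example, which initializes block_indices with -1 for unused slots; the index form bounds per-row work by max_selected_blocks, while the mask form is simpler to build but always spans every KV block.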
......@@ -12,15 +12,17 @@ from heuristic import num_splits_heuristic
def flashattn(batch, heads, heads_kv, dim, dim_v):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // heads_kv
@tilelang.jit(
out_idx=[-1], pass_configs={
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
},
)
def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks):
shape_q = [batch, heads, dim]
shape_k = [batch, max_cache_seqlen, heads_kv, dim]
......@@ -32,16 +34,15 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@T.macro
def flash_attn_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_mask: T.Tensor(shape_mask, "bool"),
cache_seqlens: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_mask: T.Tensor(shape_mask, "bool"),
cache_seqlens: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
):
with T.Kernel(
batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
......@@ -62,39 +63,31 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
blocks_per_split = T.floordiv(num_blocks, num_split)
remaining_blocks = T.floormod(num_blocks, num_split)
loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0))
loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)
start = blocks_per_split * sid + T.min(sid, remaining_blocks)
has_valid_block = False
for k in T.Pipelined(loop_range, num_stages=num_stages):
if block_mask[bid, hid, start + k]:
has_valid_block = True
T.copy(
K[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :],
K_shared)
T.copy(K[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], K_shared)
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else((start + k) * block_N + j
>= cache_seqlens[bx],
-T.infinity(accum_dtype), acc_s[i, j])
acc_s[i, j] = T.if_then_else(
(start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j]
)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
scores_max[i] * scale)
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
......@@ -103,9 +96,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim_v):
acc_o[i, j] *= scores_scale[i]
T.copy(
V[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :],
V_shared)
T.copy(V[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_valid_block:
for i, j in T.Parallel(block_H, dim_v):
......@@ -123,9 +114,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim_v], accum_dtype)
......@@ -135,10 +126,11 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({
lse_logsum_local:
T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
})
T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
......@@ -161,14 +153,14 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@T.prim_func
def main(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_mask: T.Tensor(shape_mask, "bool"),
cache_seqlens: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_mask: T.Tensor(shape_mask, "bool"),
cache_seqlens: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
flash_attn_split(Q, K, V, block_mask, cache_seqlens, glse, Output_partial)
combine(glse, Output_partial, Output)
......@@ -179,7 +171,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
class SparseFlashAttn(torch.nn.Module):
def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
super(SparseFlashAttn, self).__init__()
self.batch = batch
......@@ -198,7 +189,8 @@ class SparseFlashAttn(torch.nn.Module):
num_stages=2,
threads=128,
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
num_blocks=T.dynamic("num_blocks"))
num_blocks=T.dynamic("num_blocks"),
)
props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count
......@@ -217,24 +209,16 @@ class SparseFlashAttn(torch.nn.Module):
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * (
dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2
size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks
# num_sm = 132
num_sm = self.num_sm
num_split = num_splits_heuristic(
total_mblocks,
num_sm,
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
)
# print("num_split: ", num_split)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
Output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
output = self.kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)
return output
......@@ -259,26 +243,21 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
block_H = 64
actual_num_blocks = torch.sum(block_mask, dim=-1).to(torch.int32)
actual_num_blocks = actual_num_blocks[:,
0] #[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
actual_num_blocks = actual_num_blocks[
:, 0
] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
max_selected_blocks = actual_num_blocks.max().item()
# get num_split
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks #(kv_seqlen + block_size - 1 ) // block_size
num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size
# num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks
size_one_kv_head = max_selected_blocks * block_size * (
dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2
size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks
num_sm = 132
num_split = num_splits_heuristic(
total_mblocks,
num_sm,
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
)
kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
......@@ -287,11 +266,10 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
num_stages=2,
threads=128,
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
num_blocks=T.dynamic("num_blocks"))
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
Output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
num_blocks=T.dynamic("num_blocks"),
)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
# print(kernel.get_kernel_source())
output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)
......@@ -299,24 +277,18 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
return output
def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks,
block_size):
def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
batch, heads, dim = query.shape
heads_kv = key.shape[2]
num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim]
key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
query = rearrange(
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores)
# Assign mask values
......@@ -324,29 +296,27 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se
for h in range(heads_kv):
for idx in range(num_blocks):
if block_mask[b, h, idx]:
sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0)
range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float('-inf'))
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores = scores.masked_fill(pad_mask, float("-inf"))
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks,
block_size):
def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache #fa2
from flash_attn import flash_attn_with_kvcache # fa2
query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1)
......@@ -360,23 +330,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
# print(expect[3, 28])
# print(actual[3, 28])
diff = (expect - actual).abs()
print("all_close={}, max={}, min={}, mean={}".format(all_close,
diff.max().item(),
diff.min().item(),
diff.mean().item()))
print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item()))
max_indices = torch.nonzero(diff == diff.max().item())
first_index = tuple(max_indices[0].tolist())
print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")
def main(batch=8,
heads=32,
heads_kv=8,
max_cache_seqlen=8192,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32):
def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
sparse_ratio = sparse_ratio
block_size = block_size
......@@ -384,14 +344,13 @@ def main(batch=8,
print("max_selected_blocks: ", max_selected_blocks)
dtype = torch.float16
Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda')
Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
# Ensure at least one element equals cache_seqlen
random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index
cache_seqlens[
random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index
cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
print("cache_seqlens: ", cache_seqlens)
......@@ -403,7 +362,7 @@ def main(batch=8,
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_mask with false (for padding blocks)
block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda')
block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda")
# Assign valid indices while ensuring no duplicates within each batch-group
for b in range(batch):
......@@ -411,13 +370,12 @@ def main(batch=8,
valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch
if valid_num_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv):
perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block]
perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block]
block_mask[b, h, perm] = True
# print("block_mask: ", block_mask)
# parity reference
ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks,
block_size)
ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
# out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size)
model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size)
out = model(Q, K, V, block_mask, cache_seqlens)
......@@ -427,13 +385,11 @@ def main(batch=8,
## latency reference
for _ in range(10):
ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks,
block_size)
ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks,
block_size)
ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
torch.cuda.synchronize()
print("dense time: ", (time.time() - start) / 100 * 1000)
......@@ -452,15 +408,13 @@ def main(batch=8,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
parser.add_argument(
'--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
parser.add_argument('--block_size', type=int, default=32, help='block_size')
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
parser.add_argument("--block_size", type=int, default=32, help="block_size")
args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
args.sparse_ratio, args.block_size)
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
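Note: the flash_attn_split / combine pair above implements split-KV decoding, where each split writes a partial output plus a log-sum-exp (Output_partial, glse) and the combine kernel re-weights the partials into the exact softmax result. Below is a minimal single-query sketch of that merge math, ignoring GQA grouping, block sparsity and padding; all names are illustrative, not the example's API.

import torch

def attention_ref(q, k, v):
    # q: [d], k/v: [s, d]; full softmax attention for a single decode query
    s = (k @ q) / q.shape[0] ** 0.5
    return torch.softmax(s, dim=0) @ v

def split_attention(q, k, v, num_splits):
    # Each split attends only to its own slice of the KV cache and returns a
    # locally normalized output plus the log-sum-exp of its scores.
    d = q.shape[0]
    outs, lses = [], []
    for ks, vs in zip(k.chunk(num_splits, dim=0), v.chunk(num_splits, dim=0)):
        s = (ks @ q) / d ** 0.5
        m = s.max()
        e = torch.exp(s - m)
        outs.append((e @ vs) / e.sum())      # softmax-normalized within the split
        lses.append(m + torch.log(e.sum()))  # log-sum-exp of this split's scores
    return torch.stack(outs), torch.stack(lses)

def combine(outs, lses):
    # Re-weight each split by its share of the global softmax mass.
    w = torch.softmax(lses, dim=0)           # exp(lse_i - logsumexp(lse)) over splits
    return (w[:, None] * outs).sum(dim=0)

torch.manual_seed(0)
q, k, v = torch.randn(64), torch.randn(256, 64), torch.randn(256, 64)
outs, lses = split_attention(q, k, v, num_splits=4)
assert torch.allclose(combine(outs, lses), attention_ref(q, k, v), atol=1e-5)

Because each split's weight equals its fraction of the total exp-score mass, the weighted sum of locally normalized outputs reproduces the full softmax exactly, which is why the kernels only need to exchange one scalar (the LSE) per split and head.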
......@@ -12,12 +12,8 @@ from heuristic import num_splits_heuristic
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [1, 2, 4]\
for num_stages in [1, 2, 3, 4, 7]
],
key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'],
configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
key=["BLOCK_H", "BLOCK_N", "BLOCK_D"],
)
@triton.jit
def _split_kernel(
......@@ -79,16 +75,11 @@ def _split_kernel(
loop_range = blocks_per_split
q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h
k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[
None, :] * stride_k_s + offs_d[:, None] * stride_k_d
v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:,
None] * stride_v_s + offs_d[
None, :] * stride_v_d
k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d
v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d
mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h
q = tl.load(
q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d,
mask=offs_h[:, None] < gqa_group_size)
q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size)
start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks)
for i in range(loop_range):
block_idx = tl.load(mask_ptr + (start + i) * stride_mask_s)
......@@ -119,23 +110,18 @@ def _split_kernel(
acc = acc * l_recip
acc = acc.to(o_partial_ptr.dtype.element_ty)
lse_partial_ptr += batch_idx * stride_lse_b + (
head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size)
o_partial_ptr += batch_idx * stride_o_b + (
head_idx_q +
offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
o_partial_ptr += (
batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
)
tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size)
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [1, 2, 4]\
for num_stages in [1, 2, 3, 4, 7]
],
key=['BLOCK_D'],
configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
key=["BLOCK_D"],
)
@triton.jit
def _merge_kernel(
......@@ -163,18 +149,15 @@ def _merge_kernel(
offs_d = tl.arange(0, BLOCK_D)
lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h
lse = tl.load(
lse_offsets + offs_splits * lse_partial_stride_split,
mask=offs_splits < num_splits,
other=float("-inf"))
lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf"))
lse_max = tl.max(lse)
o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h
o_partial = tl.load(
o_offsets + offs_splits[:, None] * o_partial_stride_split +
offs_d[None, :] * o_partial_stride_d,
mask=offs_splits[:, None] < num_splits)
o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d,
mask=offs_splits[:, None] < num_splits,
)
sumexp_normalized_splitk = tl.exp(lse - lse_max)
sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0)
numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0)
......@@ -209,19 +192,13 @@ def block_sparse_flash_decode_gqa_indice_triton(
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * (
dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2
size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks
num_sm = 64
# num_sm = self.num_sm
num_splits = num_splits_heuristic(
total_mblocks,
num_sm,
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
......@@ -295,24 +272,18 @@ def block_sparse_flash_decode_gqa_indice_triton(
return output
def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks,
block_size):
def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
batch, heads, dim = query.shape
heads_kv = key.shape[2]
dim_v = value.shape[-1]
num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim]
key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
query = rearrange(
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores)
# Assign mask values based on block_indices
......@@ -321,42 +292,33 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache
valid_indices = block_indices[b, h] # Extract indices for this batch and head
for idx in valid_indices:
if idx >= 0:
sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0)
range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float('-inf'))
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores = scores.masked_fill(pad_mask, float("-inf"))
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
def ref_program_fa(query, key, value, cache_seqlens):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache #fa2
from flash_attn import flash_attn_with_kvcache # fa2
query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1)
return output
def main(batch=64,
heads=32,
heads_kv=8,
max_cache_seqlen=8192,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32):
def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
sparse_ratio = sparse_ratio
block_size = block_size
......@@ -369,34 +331,29 @@ def main(batch=64,
dtype = torch.float16
block_H = 64
Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda')
Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
# Ensure at least one element equals cache_seqlen
random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index
cache_seqlens[
random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index
cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
print("cache_seqlens: ", cache_seqlens)
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_indices with -1 (for padding blocks)
block_indices = torch.full((batch, heads_kv, max_selected_blocks),
-1,
dtype=torch.int32,
device='cuda')
block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda")
# Assign valid indices while ensuring no duplicates within each batch-group
for b in range(batch):
max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch
if max_valid_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv):
valid_indices = torch.randperm(
max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks]
block_indices[b, h, :len(valid_indices)] = valid_indices
valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks]
block_indices[b, h, : len(valid_indices)] = valid_indices
# Sort indices within each batch-group for consistency
block_indices, _ = block_indices.sort(dim=-1, descending=True)
......@@ -408,8 +365,7 @@ def main(batch=64,
max_num_blocks = torch.max(max_valid_num_blocks).item()
print("max_num_blocks: ", max_num_blocks)
ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks,
block_size)
ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
triton_out = block_sparse_flash_decode_gqa_indice_triton(
Q,
......@@ -423,8 +379,7 @@ def main(batch=64,
)
print("max difference: ", torch.max(torch.abs(ref - triton_out)))
assert torch.allclose(
ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
print("Passed the ref test!")
# Measure performance
......@@ -466,15 +421,13 @@ def main(batch=64,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=64, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
parser.add_argument(
'--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
parser.add_argument('--block_size', type=int, default=32, help='block_size')
parser.add_argument("--batch", type=int, default=64, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
parser.add_argument("--block_size", type=int, default=32, help="block_size")
args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
args.sparse_ratio, args.block_size)
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
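Note: inside each split, the Triton kernel above (m_i, l_i, acc) and the tilelang kernel (scores_max, scores_scale, logsum) use the same online-softmax update: keep a running max, rescale the accumulator whenever the max grows, and normalize once at the end. A small dense sketch of that update, without the exp2 / log2(e) trick or any sparsity, with illustrative names:

import torch

def online_softmax_attention(q, k, v, block=32):
    # Stream over KV blocks keeping a running max (m_i), a running normalizer
    # (l_i) and an unnormalized accumulator (acc); rescale both when m_i grows.
    d = q.shape[0]
    m_i = torch.tensor(float("-inf"))
    l_i = torch.tensor(0.0)
    acc = torch.zeros(v.shape[1])
    for ks, vs in zip(k.split(block, dim=0), v.split(block, dim=0)):
        s = (ks @ q) / d ** 0.5
        m_new = torch.maximum(m_i, s.max())
        rescale = torch.exp(m_i - m_new)     # 0.0 on the first block, since exp(-inf) = 0
        p = torch.exp(s - m_new)
        l_i = l_i * rescale + p.sum()
        acc = acc * rescale + p @ vs
        m_i = m_new
    return acc / l_i

torch.manual_seed(0)
q, k, v = torch.randn(64), torch.randn(512, 64), torch.randn(512, 128)
ref = torch.softmax((k @ q) / 64 ** 0.5, dim=0) @ v
assert torch.allclose(online_softmax_attention(q, k, v), ref, atol=1e-5)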
......@@ -11,12 +11,8 @@ from heuristic import num_splits_heuristic
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [1, 2, 4]\
for num_stages in [1, 2, 3, 4, 7]
],
key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'],
configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
key=["BLOCK_H", "BLOCK_N", "BLOCK_D"],
)
@triton.jit
def _split_kernel(
......@@ -77,16 +73,11 @@ def _split_kernel(
loop_range = blocks_per_split
q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h
k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[
None, :] * stride_k_s + offs_d[:, None] * stride_k_d
v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:,
None] * stride_v_s + offs_d[
None, :] * stride_v_d
k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d
v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d
mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h
q = tl.load(
q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d,
mask=offs_h[:, None] < gqa_group_size)
q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size)
start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks)
for block_idx in range(loop_range):
start_n = (start + block_idx) * BLOCK_N
......@@ -117,23 +108,18 @@ def _split_kernel(
acc = acc * l_recip
acc = acc.to(o_partial_ptr.dtype.element_ty)
lse_partial_ptr += batch_idx * stride_lse_b + (
head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size)
o_partial_ptr += batch_idx * stride_o_b + (
head_idx_q +
offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
o_partial_ptr += (
batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
)
tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size)
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [1, 2, 4]\
for num_stages in [1, 2, 3, 4, 7]
],
key=['BLOCK_D'],
configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
key=["BLOCK_D"],
)
@triton.jit
def _merge_kernel(
......@@ -161,18 +147,15 @@ def _merge_kernel(
offs_d = tl.arange(0, BLOCK_D)
lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h
lse = tl.load(
lse_offsets + offs_splits * lse_partial_stride_split,
mask=offs_splits < num_splits,
other=float("-inf"))
lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf"))
lse_max = tl.max(lse)
o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h
o_partial = tl.load(
o_offsets + offs_splits[:, None] * o_partial_stride_split +
offs_d[None, :] * o_partial_stride_d,
mask=offs_splits[:, None] < num_splits)
o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d,
mask=offs_splits[:, None] < num_splits,
)
sumexp_normalized_splitk = tl.exp(lse - lse_max)
sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0)
numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0)
......@@ -207,19 +190,13 @@ def block_sparse_flash_decode_gqa_mask_triton(
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * (
dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2
size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks
num_sm = 64
# num_sm = self.num_sm
num_splits = num_splits_heuristic(
total_mblocks,
num_sm,
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
......@@ -292,24 +269,18 @@ def block_sparse_flash_decode_gqa_mask_triton(
return output
def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks,
block_size):
def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
batch, heads, dim = query.shape
heads_kv = key.shape[2]
num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim]
key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
query = rearrange(
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores)
# Assign mask values
......@@ -317,43 +288,34 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se
for h in range(heads_kv):
for idx in range(num_blocks):
if block_mask[b, h, idx]:
sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0)
range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float('-inf'))
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores = scores.masked_fill(pad_mask, float("-inf"))
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
def ref_program_fa(query, key, value, cache_seqlens):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache #fa2
from flash_attn import flash_attn_with_kvcache # fa2
query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1)
return output
def main(batch=64,
heads=32,
heads_kv=8,
max_cache_seqlen=8192,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32):
def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
block_size = block_size
sparse_ratio = sparse_ratio
......@@ -363,14 +325,13 @@ def main(batch=64,
dtype = torch.float16
Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda')
Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
# Ensure at least one element equals cache_seqlen
random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index
cache_seqlens[
random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index
cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
num_blocks = (max_cache_seqlen + block_size - 1) // block_size
......@@ -379,7 +340,7 @@ def main(batch=64,
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_mask with false (for padding blocks)
block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda')
block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda")
# Assign valid indices while ensuring no duplicates within each batch-group
for b in range(batch):
......@@ -387,11 +348,10 @@ def main(batch=64,
valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch
if valid_num_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv):
perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block]
perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block]
block_mask[b, h, perm] = True
ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks,
block_size)
ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
triton_out = block_sparse_flash_decode_gqa_mask_triton(
Q,
......@@ -404,8 +364,7 @@ def main(batch=64,
)
# print("max difference: ", torch.max(torch.abs(ref - triton_out)))
assert torch.allclose(
ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
print("Passed the ref test!")
# Measure performance
......@@ -448,15 +407,13 @@ def main(batch=64,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=64, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
parser.add_argument(
'--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
parser.add_argument('--block_size', type=int, default=32, help='block_size')
parser.add_argument("--batch", type=int, default=64, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
parser.add_argument("--block_size", type=int, default=32, help="block_size")
args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
args.sparse_ratio, args.block_size)
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
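Note: all four decode examples partition the selected KV blocks across splits with the same blocks_per_split / remaining_blocks / start arithmetic, where the first num_blocks % num_splits splits take one extra block. A short sketch (helper name is illustrative) checking that this partition covers every block exactly once:

def split_ranges(num_blocks, num_splits):
    # Mirrors the start / loop_range arithmetic used in the split kernels.
    per = num_blocks // num_splits
    rem = num_blocks % num_splits
    ranges = []
    for s in range(num_splits):
        start = per * s + min(s, rem)
        count = per + (1 if s < rem else 0)
        ranges.append(range(start, start + count))
    return ranges

# Every block is covered exactly once, regardless of divisibility.
for n, k in [(10, 3), (7, 7), (5, 8), (64, 6)]:
    covered = [b for r in split_ranges(n, k) for b in r]
    assert covered == list(range(n)), (n, k)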
import math
def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head,
is_causal_or_local, max_splits):
def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local, max_splits):
"""
Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency.
......
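Note: the body of num_splits_heuristic is collapsed in this diff, so the following is only a toy illustration of the trade-off its docstring describes, taking more splits until the last GPU wave is reasonably full. The 0.85 target, the cap at the number of KV blocks, and everything else here are assumptions, not the function's actual logic.

import math

def toy_num_splits(total_mblocks, num_sms, num_n_blocks, max_splits=128):
    # Try increasing split counts and stop at the first one that keeps most SMs
    # busy in the final wave; never split finer than there are KV blocks.
    for s in range(1, min(max_splits, num_n_blocks) + 1):
        waves = math.ceil(total_mblocks * s / num_sms)
        efficiency = (total_mblocks * s) / (waves * num_sms)
        if efficiency >= 0.85:
            return s
    return min(max_splits, num_n_blocks)

print(toy_num_splits(total_mblocks=64, num_sms=132, num_n_blocks=256))  # -> 2 on this toy input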
......@@ -25,26 +25,14 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
def test_example_triton_sparse_gqa_decode_varlen_indice():
example_triton_sparse_gqa_decode_varlen_indice.main(
batch=8,
heads=8,
heads_kv=4,
max_cache_seqlen=2048,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32)
batch=8, heads=8, heads_kv=4, max_cache_seqlen=2048, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32
)
def test_example_triton_sparse_gqa_decode_varlen_mask():
example_triton_sparse_gqa_decode_varlen_mask.main(
batch=16,
heads=16,
heads_kv=8,
max_cache_seqlen=1024,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32)
batch=16, heads=16, heads_kv=8, max_cache_seqlen=1024, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32
)
if __name__ == "__main__":
......
......@@ -19,8 +19,7 @@ parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)")
parser.add_argument(
"--use_autotune", action="store_true", default=False, help="Whether to use autotune")
parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune")
args, _ = parser.parse_known_args()
M, N, K = args.m, args.n, args.k
......@@ -41,17 +40,19 @@ def get_configs():
thread_num = [128, 256]
enable_rasterization = [True, False]
_configs = list(
itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization))
_configs = list(itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization))
return [{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5],
} for c in _configs]
return [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5],
}
for c in _configs
]
def ref_program(A, B, BlockMask, block_M, block_N, block_K):
......@@ -61,12 +62,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
for k in range(K // block_K):
if BlockMask[i, j, k]:
accu += (
A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(
torch.float32) @ B[k * block_K:(k + 1) * block_K,
j * block_N:(j + 1) * block_N].to(torch.float32))
ref_c[i * block_M:(i + 1) * block_M,
j * block_N:(j + 1) * block_N] = accu.to(torch.float16)
accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[
k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N
].to(torch.float32)
ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16)
return ref_c
......@@ -89,28 +88,21 @@ def supply_program(params: List[KernelParam]):
return input_tensors
@tilelang.autotune(configs=get_configs(),)
@tilelang.autotune(
configs=get_configs(),
)
@tilelang.jit(out_idx=[-1])
def blocksparse_matmul(M,
N,
K,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float"):
def blocksparse_matmul(
M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"
):
block_mask_shape = (M // block_M, N // block_N, K // block_K)
@T.prim_func
def block_sparse_matmul(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
......@@ -134,7 +126,6 @@ def blocksparse_matmul(M,
def main():
# Initialize input matrices A and B on the GPU with half precision
a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half()
......@@ -147,8 +138,7 @@ def main():
best_config = kernel.config
best_latency = kernel.latency
block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config[
"block_K"]
block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config["block_K"]
print(f"Best Config: {best_config}")
print(f"Sparsity Ratio: {sparsity}")
......@@ -163,7 +153,8 @@ def main():
block_K=DEFAULT_BLOCK_K,
num_stages=DEFAULT_NUM_STAGES,
thread_num=DEFAULT_THREAD_NUM,
enable_rasteration=DEFAULT_ENABLE_RASTERIZATION)
enable_rasteration=DEFAULT_ENABLE_RASTERIZATION,
)
block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")
# Create block mask with desired sparsity
......
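Note: the mask-construction code after the "Create block mask with desired sparsity" comment is collapsed above. One plausible way to build such a BlockMask for a target sparsity, shown only as an illustration (the example's real construction may differ):

import torch

M = N = K = 1024
block_M, block_N, block_K = 128, 128, 32
sparsity = 0.5

# True means "compute this K-tile for this (M, N) output block"; drawing i.i.d.
# Bernoulli(1 - sparsity) gives roughly the requested fraction of skipped tiles.
block_mask = torch.rand(M // block_M, N // block_N, K // block_K) > sparsity
print("realized sparsity:", 1.0 - block_mask.float().mean().item())

A tensor built this way would be passed as the BlockMask argument of block_sparse_matmul and ref_program above.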
......@@ -16,11 +16,13 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
fp8_max = 448.0
@T.prim_func
def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor(
(BG,), "int32"), X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"), X_amax: T.Tensor(
(BG, M_max, T.ceildiv(N, group_size)), accum_dtype)):
with T.Kernel(
T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz):
def group_per_split_token_cast(
X: T.Tensor((M, N), dtype),
batch_sizes: T.Tensor((BG,), "int32"),
X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"),
X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)), accum_dtype),
):
with T.Kernel(T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz):
row = bx
row_g_id = by
bg = bz
......@@ -31,36 +33,32 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
row_offset = T.alloc_fragment((1,), "int32")
T.annotate_layout({
y_local:
T.Fragment(
y_local.shape,
forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
})
T.annotate_layout(
{
y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
}
)
row_offset[0] = 0
for i in T.serial(bg):
row_offset[0] += batch_sizes[i]
T.copy(
X[row_offset[0] + row * blk_m:row_offset[0] + (row + 1) * blk_m,
row_g_id * group_size:(row_g_id + 1) * group_size], y_local)
X[row_offset[0] + row * blk_m : row_offset[0] + (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size],
y_local,
)
T.reduce_absmax(y_local, y_amax_local, dim=1)
for i in T.Parallel(blk_m):
y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg],
y_amax_local[i] / fp8_max, 0)
y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_amax_local[i] / fp8_max, 0)
for i, j in T.Parallel(blk_m, group_size):
y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max)
T.copy(y_q_local, y_q_local_fp8)
for i, j in T.Parallel(blk_m, group_size):
y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg],
y_q_local[i, j], 0)
y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_q_local[i, j], 0)
for i in T.Parallel(blk_m):
X_amax[bg, row * blk_m + i, row_g_id] = y_s_local[i]
T.copy(
y_q_local_fp8, X_fp8[bg, row * blk_m:(row + 1) * blk_m,
row_g_id * group_size:(row_g_id + 1) * group_size])
T.copy(y_q_local_fp8, X_fp8[bg, row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size])
return group_per_split_token_cast
......@@ -127,8 +125,7 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
return x.squeeze(0) if remove_dim else x
# Normal layout requires transposing
aligned_x = torch.transpose(
torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
aligned_x[:, :m, :] = x
aligned_x = aligned_x[:, :m, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x
......@@ -146,15 +143,17 @@ def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous()
return x_fp8, (x_amax / 448.0).view(m, -1)
def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
Tuple[torch.Tensor, torch.Tensor]:
def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# assert x.shape[0] == batch_sizes.sum()
M_max = ceil_div(batch_sizes.max(), 128) * 128
split_x = torch.split(x, batch_sizes.tolist(), dim=0)
padded_x = [torch.nn.functional.pad(t, (0, 0, 0, M_max - t.shape[0])) for t in split_x]
num_groups, m, n = batch_sizes.shape[0], M_max, x.shape[1]
x_fp8 = (torch.empty((num_groups, m, n), device='cuda', dtype=torch.float8_e4m3fn),
torch.empty((num_groups, m, n // 128), device='cuda', dtype=torch.float))
x_fp8 = (
torch.empty((num_groups, m, n), device="cuda", dtype=torch.float8_e4m3fn),
torch.empty((num_groups, m, n // 128), device="cuda", dtype=torch.float),
)
for i in range(num_groups):
x_fp8[0][i], x_fp8[1][i] = ref_per_token_cast_to_fp8(padded_x[i])
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
......
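Note: both FP8 examples share the same per-group scale rule: for every 128-wide group, floor the absolute maximum at 1e-4 and divide by the e4m3 limit of 448. A plain-PyTorch sketch of that scale math follows; the real kernels additionally cast the scaled values to float8_e4m3 and, in the grouped variant, zero rows beyond each batch size.

import torch

def per_group_scales(x: torch.Tensor, group_size: int = 128, fp8_max: float = 448.0):
    # Per (row, group) scale used before casting to float8_e4m3.
    m, n = x.shape
    g = x.view(m, n // group_size, group_size)
    amax = g.abs().amax(dim=-1).clamp(min=1e-4)
    return amax / fp8_max

x = torch.randn(256, 1024)
scale = per_group_scales(x)                                  # shape [256, 8]
q = (x.view(256, -1, 128) / scale.unsqueeze(-1)).clamp(-448.0, 448.0)
# q is what the kernels store as float8_e4m3; q * scale recovers x up to fp8
# rounding (exactly here, since this sketch never actually rounds to 8 bits).
assert torch.allclose(q * scale.unsqueeze(-1), x.view(256, -1, 128), atol=1e-5)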
......@@ -13,8 +13,9 @@ def per_token_cast_to_fp8(M, N, blk_m):
fp8_max = 448.0
@T.prim_func
def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"),
X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)):
def per_token_cast(
X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)
):
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by):
row = bx
row_g_id = by
......@@ -24,16 +25,13 @@ def per_token_cast_to_fp8(M, N, blk_m):
y_q_local = T.alloc_fragment((blk_m, group_size), dtype)
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
T.annotate_layout({
y_local:
T.Fragment(
y_local.shape,
forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
})
T.annotate_layout(
{
y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
}
)
T.copy(
X[row * blk_m:(row + 1) * blk_m, row_g_id * group_size:(row_g_id + 1) * group_size],
y_local)
T.copy(X[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], y_local)
T.reduce_absmax(y_local, y_amax_local, dim=1)
for i in T.Parallel(blk_m):
y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
......@@ -43,9 +41,7 @@ def per_token_cast_to_fp8(M, N, blk_m):
T.copy(y_q_local, y_q_local_fp8)
for i in T.Parallel(blk_m):
X_amax[row * blk_m + i, row_g_id] = y_s_local[i]
T.copy(
y_q_local_fp8, X_fp8[row * blk_m:(row + 1) * blk_m,
row_g_id * group_size:(row_g_id + 1) * group_size])
T.copy(y_q_local_fp8, X_fp8[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size])
return per_token_cast
......@@ -105,8 +101,7 @@ def main(M=8192, N=8192, blk_m=8):
from example_triton_cast_to_fp8 import per_token_group_quant_fp8
def run_triton():
x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(
x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False)
x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False)
return x_fp8_triton_, x_amax_triton_
x_fp8_triton, x_amax_triton = run_triton()
......
......@@ -128,9 +128,7 @@ def per_token_group_quant_fp8(
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
assert (x.shape[-1] %
group_size == 0), (f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}")
assert x.shape[-1] % group_size == 0, f"the last dimension of `x` {x.shape[-1]} must be divisible by `group_size` {group_size}"
assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype)
......
......@@ -4,8 +4,7 @@ import example_per_token_cast_to_fp8
def test_example_group_per_split_token_cast_to_fp8():
example_group_per_split_token_cast_to_fp8.main(
M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])
example_group_per_split_token_cast_to_fp8.main(M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])
def test_example_per_token_cast_to_fp8():
......
......@@ -4,12 +4,11 @@ import tilelang.language as T
# @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
......@@ -36,8 +35,7 @@ block_K = 32
func = matmul(M, N, K, block_M, block_N, block_K)
jit_kernel = tilelang.compile(
func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr")
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr")
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"])
......
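A quick correctness check for the compiled kernel above, assuming a CUDA device and that out_idx=[2] makes the JIT kernel allocate and return C (tensor names below are illustrative):

import torch

a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = jit_kernel(a, b)  # C is allocated and returned by the kernel via out_idx=[2]
ref = (a.float() @ b.float()).to(torch.float16)
torch.testing.assert_close(c, ref, rtol=1e-2, atol=1e-2)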
......@@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
"warnings",
"error",
}
if (sum(
len(terminalreporter.stats.get(k, []))
for k in known_types.difference({"skipped", "deselected"})) == 0):
if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0:
terminalreporter.write_sep(
"!",
(f"Error: No tests were collected. "
f"{dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
(f"Error: No tests were collected. {dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
)
pytest.exit("No tests were collected.", returncode=5)
......@@ -14,7 +14,6 @@ def check_hopper():
def ref_program(stride, padding, dilation):
def main(A, B):
A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W
B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W
......@@ -26,22 +25,7 @@ def ref_program(stride, padding, dilation):
@tilelang.jit(out_idx=[2])
def convolution(N,
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
threads,
dtype="float16",
accum_dtype="float"):
def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
......@@ -51,13 +35,11 @@ def convolution(N,
@T.prim_func
def main(
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -66,11 +48,13 @@ def convolution(N,
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout({
out_shared: tilelang.layout.make_swizzled_layout(out_shared),
data_shared: tilelang.layout.make_swizzled_layout(data_shared),
kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
})
T.annotate_layout(
{
out_shared: tilelang.layout.make_swizzled_layout(out_shared),
data_shared: tilelang.layout.make_swizzled_layout(data_shared),
kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
}
)
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
......@@ -82,10 +66,8 @@ def convolution(N,
m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P
in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
(access_w < W))
data_shared[i, j] = T.if_then_else(
in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local)
......@@ -97,15 +79,15 @@ def convolution(N,
def main(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument('--n', type=int, default=128, help='n')
parser.add_argument('--c', type=int, default=128, help='c')
parser.add_argument('--h', type=int, default=64, help='h')
parser.add_argument('--w', type=int, default=64, help='w')
parser.add_argument('--f', type=int, default=128, help='f')
parser.add_argument('--k', type=int, default=3, help='k')
parser.add_argument('--s', type=int, default=1, help='s')
parser.add_argument('--d', type=int, default=1, help='d')
parser.add_argument('--p', type=int, default=1, help='p')
parser.add_argument("--n", type=int, default=128, help="n")
parser.add_argument("--c", type=int, default=128, help="c")
parser.add_argument("--h", type=int, default=64, help="h")
parser.add_argument("--w", type=int, default=64, help="w")
parser.add_argument("--f", type=int, default=128, help="f")
parser.add_argument("--k", type=int, default=3, help="k")
parser.add_argument("--s", type=int, default=1, help="s")
parser.add_argument("--d", type=int, default=1, help="d")
parser.add_argument("--p", type=int, default=1, help="p")
args = parser.parse_args(argv)
N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p
......
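The implicit-GEMM indexing above recovers input coordinates from the flattened row index m = (n * OH + oh) * OW + ow and reduction index k = (kh * KW + kw) * C + c; a small pure-Python check of that mapping (names are local to this sketch):

def decode(m, k, OH, OW, KW, C, S, D, P):
    # Mirrors the access_h / access_w expressions in the kernel above.
    access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
    access_w = m % OW * S + k // C % KW * D - P
    return access_h, access_w

# With OH=OW=4, KW=3, C=2, S=D=P=1 and (oh, ow, kh, kw) = (2, 3, 1, 2):
# access_h = oh*S + kh*D - P = 2, access_w = ow*S + kw*D - P = 4.
assert decode((0 * 4 + 2) * 4 + 3, (1 * 3 + 2) * 2 + 1, 4, 4, 3, 2, 1, 1, 1) == (2, 4)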
......@@ -14,7 +14,6 @@ def check_hopper():
def ref_program(stride, padding, dilation):
def main(A, B):
A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W
B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W
......@@ -40,7 +39,8 @@ def get_configs():
num_stages,
thread_num,
enable_rasterization,
))
)
)
configs = [
{
......@@ -50,7 +50,8 @@ def get_configs():
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat
} for c in _configs
}
for c in _configs
]
return configs
......@@ -64,53 +65,18 @@ def get_heuristic_config() -> dict:
sm_version = sm_major * 10 + sm_minor
print(f"CUDA device capability: {sm_version}")
if sm_version in {80}:
return {
"block_M": 128,
"block_N": 256,
"block_K": 32,
"num_stages": 2,
"thread_num": 128,
"enable_rasteration": True
}
return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True}
elif sm_version in {90}:
return {
"block_M": 128,
"block_N": 256,
"block_K": 64,
"num_stages": 3,
"thread_num": 256,
"enable_rasteration": True
}
return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True}
else:
return {
"block_M": 128,
"block_N": 256,
"block_K": 32,
"num_stages": 0,
"thread_num": 128,
"enable_rasteration": True
}
return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True}
@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[2])
def convolution(N,
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float"):
def convolution(
N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"
):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
......@@ -120,13 +86,11 @@ def convolution(N,
@T.prim_func
def main(
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=thread_num) as (bx, by):
with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=thread_num) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -136,9 +100,11 @@ def convolution(N,
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
if is_hopper:
T.annotate_layout({
out_shared: tilelang.layout.make_swizzled_layout(out_shared),
})
T.annotate_layout(
{
out_shared: tilelang.layout.make_swizzled_layout(out_shared),
}
)
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
......@@ -150,10 +116,8 @@ def convolution(N,
m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P
in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
(access_w < W))
data_shared[i, j] = T.if_then_else(
in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local)
......@@ -166,17 +130,19 @@ def convolution(N,
return main
def main(n: int = 128,
c: int = 128,
h: int = 64,
w: int = 64,
f: int = 128,
k: int = 3,
s: int = 1,
d: int = 1,
p: int = 1,
use_autotune: bool = False,
with_roller: bool = True):
def main(
n: int = 128,
c: int = 128,
h: int = 64,
w: int = 64,
f: int = 128,
k: int = 3,
s: int = 1,
d: int = 1,
p: int = 1,
use_autotune: bool = False,
with_roller: bool = True,
):
N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p
ref_prog = ref_program(S, P, D)
......@@ -196,25 +162,16 @@ def main(n: int = 128,
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument('--n', type=int, default=128, help='n')
parser.add_argument('--c', type=int, default=128, help='c')
parser.add_argument('--h', type=int, default=64, help='h')
parser.add_argument('--w', type=int, default=64, help='w')
parser.add_argument('--f', type=int, default=128, help='f')
parser.add_argument('--k', type=int, default=3, help='k')
parser.add_argument('--s', type=int, default=1, help='s')
parser.add_argument('--d', type=int, default=1, help='d')
parser.add_argument('--p', type=int, default=1, help='p')
parser.add_argument(
"--use_autotune",
action="store_true",
default=False,
help="Whether to use autotune for matmul configs")
parser.add_argument(
"--with_roller",
action="store_true",
default=True,
help="Whether to enable BitBLAS roller for search space")
parser.add_argument("--n", type=int, default=128, help="n")
parser.add_argument("--c", type=int, default=128, help="c")
parser.add_argument("--h", type=int, default=64, help="h")
parser.add_argument("--w", type=int, default=64, help="w")
parser.add_argument("--f", type=int, default=128, help="f")
parser.add_argument("--k", type=int, default=3, help="k")
parser.add_argument("--s", type=int, default=1, help="s")
parser.add_argument("--d", type=int, default=1, help="d")
parser.add_argument("--p", type=int, default=1, help="p")
parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs")
parser.add_argument("--with_roller", action="store_true", default=True, help="Whether to enable BitBLAS roller for search space")
args = parser.parse_args()
main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune,
args.with_roller)
main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, args.with_roller)
......@@ -41,14 +41,13 @@ def tl_gemm(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
scales_a: T.Tensor(Scales_A_shape, "float32"),
scales_b: T.Tensor(Scales_B_shape, "float32"),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
scales_a: T.Tensor(Scales_A_shape, "float32"),
scales_b: T.Tensor(Scales_B_shape, "float32"),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_shared = T.alloc_shared(C_shared_shape, out_dtype)
......@@ -93,21 +92,18 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
m, n), (x_amax / 448.0).view(m, -1)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device)
x_padded = torch.zeros(ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
x_view.size(0), x_view.size(2))
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
......@@ -127,13 +123,14 @@ def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
c_acc.zero_()
for k in range(ceildiv(K, 128)):
c = torch._scaled_mm(
A_fp8[i * 128:(i + 1) * 128, k * 128:(k + 1) * 128],
B_fp8[j * 128:(j + 1) * 128, k * 128:(k + 1) * 128].T,
A_fp8[i * 128 : (i + 1) * 128, k * 128 : (k + 1) * 128],
B_fp8[j * 128 : (j + 1) * 128, k * 128 : (k + 1) * 128].T,
scale_a=A_scales[i, k].view(128, 1).contiguous(),
scale_b=B_scales[j, k].view(1, 128).contiguous(),
out_dtype=torch.bfloat16)
out_dtype=torch.bfloat16,
)
c_acc += c.to(torch.float32)
C[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = c_acc.to(out_dtype)
C[i * 128 : (i + 1) * 128, j * 128 : (j + 1) * 128] = c_acc.to(out_dtype)
return C
......
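A hedged usage sketch of the two cast helpers above (shapes and device are illustrative): per_token_cast_to_fp8 yields one scale per 1x128 row group, while per_block_cast_to_fp8 yields one scale per 128x128 block, matching the operand layouts consumed by the blockwise-scaled reference GEMM.

import torch

x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
x_tok_fp8, tok_scales = per_token_cast_to_fp8(x)  # tok_scales: (256, 512 // 128)
x_blk_fp8, blk_scales = per_block_cast_to_fp8(x)  # blk_scales: (256 // 128, 512 // 128)
assert x_tok_fp8.dtype == torch.float8_e4m3fn
assert blk_scales.shape == (256 // 128, 512 // 128)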