"examples/model_compress/vscode:/vscode.git/clone" did not exist on "b122c63df2d03da2c4f6a6947ed0afdacad713a3"
Unverified commit 29051439, authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
@@ -3,8 +3,7 @@ from typing import Dict, List, Tuple

 TokensText = Tuple[List[int], str]

-def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText],
-                        name_0: str, name_1: str):
+def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText], name_0: str, name_1: str):
     """
     Compare the two sequences generated by different models,
     which should be equal.

@@ -15,19 +14,14 @@ def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[Tok
         output_ids_0, output_str_0 = outputs_0
         output_ids_1, output_str_1 = outputs_1

-        assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
-                                              f"\n{name_0}:\t{output_str_0!r}"
-                                              f"\n{name_1}:\t{output_str_1!r}")
-        assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:"
-                                              f"\n{name_0}:\t{output_str_0!r}"
-                                              f"\n{name_1}:\t{output_str_1!r}")
+        assert output_str_0 == output_str_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
+        assert output_ids_0 == output_ids_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"

 TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]

-def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
-                         outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str):
+def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str):
     """
     Compare the logprobs of two sequences generated by different models,
     which should be similar but not necessarily equal.

@@ -41,16 +35,11 @@ def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
         # Loop through generated tokens.
         for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):

            # If generated tokens don't match, then
            if output_id_0 != output_id_1:
                # Each predicted token must be in top N logprobs of the other
-                assert output_id_0 in logprobs_1[idx], (f"Test{prompt_idx}:"
-                                                        f"\n{name_0}:\t{output_str_0!r}"
-                                                        f"\n{name_1}:\t{output_str_1!r}")
-                assert output_id_1 in logprobs_0[idx], (f"Test{prompt_idx}:"
-                                                        f"\n{name_0}:\t{output_str_0!r}"
-                                                        f"\n{name_1}:\t{output_str_1!r}")
+                assert output_id_0 in logprobs_1[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
+                assert output_id_1 in logprobs_0[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"

                # Break out since sequences will now diverge.
                break
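For context, a minimal usage sketch of the two comparison helpers above. The model names and token/logprob values are invented for illustration, and the sketch assumes the helpers are importable from the reformatted module:

# Hypothetical driver for the helpers above; values are invented for illustration.
outputs_a = [([1, 2, 3], "a cat sat")]
outputs_b = [([1, 2, 3], "a cat sat")]
check_outputs_equal(outputs_a, outputs_b, name_0="model_a", name_1="model_b")

# check_logprobs_close also carries per-token top-N logprob dicts, so a token
# mismatch is tolerated as long as each token appears in the other run's top-N.
logprobs_a = [([1, 2], "a cat", [{1: -0.1, 4: -2.0}, {2: -0.2, 4: -1.5}])]
logprobs_b = [([1, 4], "a dog", [{1: -0.2, 4: -1.9}, {2: -0.4, 5: -1.3}])]
check_logprobs_close(logprobs_a, logprobs_b, name_0="model_a", name_1="model_b")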
@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
-                            False,
-                            dtype=torch.bool,
-                            device=x.device)
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True
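As a quick aside, the top-k block-mask construction in this hunk can be exercised on its own; a small CPU-only sketch with toy sizes chosen purely for illustration:

import torch

# Toy-size rerun of the top-k block-mask construction above: keep the top-2
# key blocks per query block and mark them True in a dense boolean mask.
bsz, num_head, downsample_len, topk = 1, 1, 4, 2
x = torch.randn(bsz, num_head, downsample_len, downsample_len)
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool)
dense_mask.scatter_(-1, sparse_index, True)
print(dense_mask[0, 0])  # each row has exactly `topk` True entries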
@@ -56,7 +53,6 @@ def _fwd_kernel_inner(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-
     mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
     # print

@@ -73,8 +69,7 @@ def _fwd_kernel_inner(
         # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
         if LAST_K_BLOCK:
-            qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0,
-                           float('-inf'))
+            qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))

         m_ij = tl.maximum(m_i, tl.max(qk, 1))
         qk -= m_ij[:, None]
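The running quantities m_i / m_ij above follow the usual online-softmax rescaling. A small PyTorch sketch of the same update, with one query row, two key chunks, and scalar values invented for illustration:

import torch

# Toy PyTorch sketch of the online-softmax update: keep a running max (m_i),
# running normalizer (l_i), and rescale the accumulator whenever the max grows.
torch.manual_seed(0)
scores = torch.randn(8)   # qk logits for one query row
values = torch.randn(8)   # one scalar "value" per key, for simplicity
m_i = torch.tensor(float("-inf"))
l_i = torch.tensor(0.0)
acc = torch.tensor(0.0)
for qk, v in zip(scores.split(4), values.split(4)):
    m_ij = torch.maximum(m_i, qk.max())
    rescale = torch.exp(m_i - m_ij)   # 0.0 on the first chunk (m_i = -inf)
    p = torch.exp(qk - m_ij)
    l_i = l_i * rescale + p.sum()
    acc = acc * rescale + (p * v).sum()
    m_i = m_ij
ref = (torch.softmax(scores, dim=0) * values).sum()
print(torch.allclose(acc / l_i, ref))  # True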
@@ -154,7 +149,7 @@ def _fwd_kernel(
     v_ptrs = V + off_v
     mask_ptrs = block_mask_ptr + start_m * stride_bmm

-    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

@@ -192,24 +187,12 @@ def _fwd_kernel(
     acc = acc * l_recip
     acc = acc.to(Out.dtype.element_ty)

-    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[
-        None, :] * stride_od
+    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
     out_ptrs = Out + off_o
     tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)

-def _forward(ctx,
-             q,
-             k,
-             v,
-             block_sparse_mask,
-             sm_scale,
-             BLOCK_M=64,
-             BLOCK_N=64,
-             num_warps=None,
-             num_stages=1,
-             out=None):
+def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
     assert q.shape[-1] == k.shape[-1] == v.shape[-1]
     assert k.shape[2] == v.shape[2]
     o = out if out is not None else torch.empty_like(q).contiguous()
@@ -254,7 +237,6 @@ def _forward(ctx,
 class _sparse_attention(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
         # shape constraints
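A plausible call site for this autograd wrapper, using the standard torch.autograd.Function entry point; tensor names mirror the tests below and the exact shapes are assumptions:

# torch.autograd.Function subclasses are invoked through .apply(); here
# block_sparse_dense is the boolean block mask and sm_scale the softmax scale.
triton_output = _sparse_attention.apply(q, k, v, block_sparse_dense, sm_scale)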
@@ -278,9 +260,9 @@ def test_topk_sparse_attention():
     torch.manual_seed(0)

     # Create inputs
-    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
-    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
-    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
+    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
+    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
+    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
     sm_scale = 1.0 / (D_HEAD**0.5)

     # Create sparse mask (downsampled to block level)

@@ -288,9 +270,7 @@ def test_topk_sparse_attention():
     downsample_len = math.ceil(SEQ_LEN / downsample_factor)
     print("downsample_len", downsample_len)
-    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
-                       device='cuda',
-                       dtype=torch.bfloat16)
+    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
     x_ds[:, :, :, 0] = 100
     print("x_ds.shape", x_ds.shape)
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

@@ -302,22 +282,21 @@ def test_topk_sparse_attention():
     # Compute reference
     # Expand block mask to full attention matrix
-    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
+    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
     full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
     full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal

     # PyTorch reference implementation
-    attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
-    attn = attn.masked_fill(~full_mask, float('-inf'))
+    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
+    attn = attn.masked_fill(~full_mask, float("-inf"))
     attn = F.softmax(attn, dim=-1)
-    ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
+    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)

     # print("ref_output", ref_output)
     # print("triton_output", triton_output)

     # Verify accuracy
-    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
-        "Triton output doesn't match reference"
+    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference"

     print("Pass topk sparse attention test with qlen == klen")
@@ -329,9 +308,9 @@ def test_topk_sparse_attention_qlt_kl():
     torch.manual_seed(0)

     # Create inputs.
-    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
-    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
-    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
+    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
+    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
+    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)

     # softmax scale
     sm_scale = 1.0 / (D_HEAD**0.5)

@@ -339,8 +318,7 @@ def test_topk_sparse_attention_qlt_kl():
     print("downsample_factor", downsample_factor)
     downsample_len = math.ceil(K_LEN / downsample_factor)  # number of blocks along one dimension
     print("downsample_len", downsample_len)
-    x_ds = torch.randn(
-        BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16)
+    x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16)
     # Force the first column to be high so that the first block is always selected.
     x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

@@ -351,26 +329,25 @@ def test_topk_sparse_attention_qlt_kl():
     past_len = K_LEN - Q_LEN

-    attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
+    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale

-    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool()
+    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool()
     full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]
     effective_mask = full_mask_full[..., past_len:K_LEN, :]  # shape: (B, H, Q_LEN, K_LEN)

     i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)  # shape: (Q_LEN, 1)
     j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)  # shape: (1, K_LEN)
-    causal_mask = (j_global <= i_global)  # shape: (Q_LEN, K_LEN)
+    causal_mask = j_global <= i_global  # shape: (Q_LEN, K_LEN)

     final_mask = effective_mask & causal_mask  # shape: (B, H, Q_LEN, K_LEN)

-    attn = attn.masked_fill(~final_mask, float('-inf'))
+    attn = attn.masked_fill(~final_mask, float("-inf"))
     attn = F.softmax(attn, dim=-1)
-    ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
+    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)

     # Verify accuracy.
-    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
-        "Triton output doesn't match reference when qlen < klen"
+    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen"

     print("Pass topk sparse attention test with qlen < klen")
...
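The past_len arithmetic in the qlen < klen test maps local query rows to global key positions before applying causality. A small numeric check of that mapping, with Q_LEN = 2 and K_LEN = 4 chosen for illustration:

import torch

# With Q_LEN = 2 and K_LEN = 4, the two queries sit at global positions 2 and 3,
# so they may attend to keys 0..2 and 0..3 respectively.
Q_LEN, K_LEN = 2, 4
past_len = K_LEN - Q_LEN
i_global = torch.arange(past_len, K_LEN).unsqueeze(1)  # (Q_LEN, 1)
j_global = torch.arange(K_LEN).unsqueeze(0)            # (1, K_LEN)
print((j_global <= i_global).int())
# tensor([[1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)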
@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
-                            False,
-                            dtype=torch.bool,
-                            device=x.device)
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True
@@ -30,15 +27,17 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F

 @tilelang.jit(
-    out_idx=[4], pass_configs={
+    out_idx=[4],
+    pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    })
+    },
+)
 def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
     block_M = 64
     block_N = 64
     num_stages = 1
     threads = 128
-    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     shape = [batch, heads, seq_len, dim]
     block_mask_shape = [batch, heads, downsample_len, downsample_len]

@@ -47,7 +46,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
     block_mask_dtype = "bool"

     def kernel_func(block_M, block_N, num_stages, threads):
-
         @T.macro
         def MMA0(
             K: T.Tensor(shape, dtype),

@@ -59,11 +57,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
             by: T.int32,
             bz: T.int32,
         ):
-            T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
+            T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
             if is_causal:
                 for i, j in T.Parallel(block_M, block_N):
-                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-                                                 -T.infinity(acc_s.dtype))
+                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
             else:
                 T.clear(acc_s)
             T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

@@ -78,7 +75,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
             by: T.int32,
             bz: T.int32,
         ):
-            T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
+            T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
             T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

         @T.macro

@@ -127,8 +124,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
             BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
             Output: T.Tensor(shape, dtype),
         ):
-            with T.Kernel(
-                    T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
+            with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
                 Q_shared = T.alloc_shared([block_M, dim], dtype)
                 K_shared = T.alloc_shared([block_N, dim], dtype)
                 V_shared = T.alloc_shared([block_N, dim], dtype)

@@ -143,7 +139,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                 logsum = T.alloc_fragment([block_M], accum_dtype)
                 block_mask = T.alloc_local([downsample_len], block_mask_dtype)

-                T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
+                T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
                 T.fill(acc_o, 0)
                 T.fill(logsum, 0)
                 T.fill(scores_max, -T.infinity(accum_dtype))

@@ -152,20 +148,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                     block_mask[vj] = BlockSparseMask[bz, by, bx, vj]

                 loop_range = (
-                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
-                        (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
+                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
+                )

                 for k in T.Pipelined(loop_range, num_stages=num_stages):
                     if block_mask[k] != 0:
                         MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
-                        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
-                                scores_sum, logsum)
+                        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                         Rescale(acc_o, scores_scale)
                         MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
                 for i, j in T.Parallel(block_M, dim):
                     acc_o[i, j] /= logsum[i]
                 T.copy(acc_o, O_shared)
-                T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
+                T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])

     return blocksparse_flashattn
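For orientation, a plausible way to instantiate and call the jitted kernel above, mirroring the test that follows; with out_idx=[4] the Output buffer is allocated and returned by the call, though the exact calling convention should be treated as an assumption:

# Hypothetical call sequence; BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len,
# q, k, v and block_mask are assumed to be set up as in the test below.
kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
tilelang_output = kernel(q, k, v, block_mask)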
@@ -180,18 +175,16 @@ def test_topk_sparse_attention():
     torch.manual_seed(0)

     # Create inputs
-    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
+    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
     sm_scale = 1.0 / (D_HEAD**0.5)

     # Create sparse mask (downsampled to block level)
     downsample_factor = BLOCK
     downsample_len = math.ceil(SEQ_LEN / downsample_factor)
-    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
-                       device='cuda',
-                       dtype=torch.bfloat16)
+    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
     x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

@@ -202,15 +195,15 @@ def test_topk_sparse_attention():
     # Compute reference
     # Expand block mask to full attention matrix
-    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
+    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
     full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
     full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal

     # PyTorch reference implementation
-    attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
-    attn = attn.masked_fill(~full_mask, float('-inf'))
+    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
+    attn = attn.masked_fill(~full_mask, float("-inf"))
     attn = F.softmax(attn, dim=-1)
-    ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
+    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)

     print("ref_output", ref_output)
     print("tilelang_output", tilelang_output)
...
 import math

-def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head,
-                         is_causal_or_local, max_splits):
+def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local, max_splits):
     """
     Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency.
...
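To make the docstring concrete, here is a purely illustrative toy of the occupancy trade-off it describes. This is not the heuristic implemented in the file, only a sketch of the idea that more splits raise last-wave occupancy at some memory cost:

import math

# Toy illustration only (NOT the real num_splits_heuristic): pick the smallest
# split count whose last wave of thread blocks keeps most SMs busy.
def toy_num_splits(total_mblocks, num_SMs, max_splits, target_eff=0.85):
    for num_splits in range(1, max_splits + 1):
        blocks = total_mblocks * num_splits
        eff = (blocks / num_SMs) / math.ceil(blocks / num_SMs)  # last-wave occupancy
        if eff >= target_eff:
            return num_splits
    return max_splits

print(toy_num_splits(total_mblocks=8, num_SMs=132, max_splits=16))  # 15 on these toy numbers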
@@ -25,26 +25,14 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():

 def test_example_triton_sparse_gqa_decode_varlen_indice():
     example_triton_sparse_gqa_decode_varlen_indice.main(
-        batch=8,
-        heads=8,
-        heads_kv=4,
-        max_cache_seqlen=2048,
-        dim=128,
-        dim_v=128,
-        sparse_ratio=0.8,
-        block_size=32)
+        batch=8, heads=8, heads_kv=4, max_cache_seqlen=2048, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32
+    )

 def test_example_triton_sparse_gqa_decode_varlen_mask():
     example_triton_sparse_gqa_decode_varlen_mask.main(
-        batch=16,
-        heads=16,
-        heads_kv=8,
-        max_cache_seqlen=1024,
-        dim=128,
-        dim_v=128,
-        sparse_ratio=0.8,
-        block_size=32)
+        batch=16, heads=16, heads_kv=8, max_cache_seqlen=1024, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32
+    )

 if __name__ == "__main__":
...
@@ -128,9 +128,7 @@ def per_token_group_quant_fp8(
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
         scaling factor for quantization.
     """
-    assert (x.shape[-1] %
-            group_size == 0), (f"the last dimension of `x` {x.shape[-1]} must be divisible "
-                               f"by `group_size` {group_size}")
+    assert x.shape[-1] % group_size == 0, f"the last dimension of `x` {x.shape[-1]} must be divisible by `group_size` {group_size}"
    assert x.stride(-1) == 1, "`x` groups must be contiguous"

    finfo = torch.finfo(dtype)
...
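For readers unfamiliar with the operation, a rough eager-mode sketch of what per-token-group FP8 quantization computes (a per-group max-abs scale against finfo.max). This is an assumption-level reference written for illustration, not the kernel behind per_token_group_quant_fp8, and float8_e4m3fn support depends on the PyTorch build:

import torch

# Illustrative reference only: scale each group of `group_size` contiguous
# elements by its max-abs value relative to the FP8 format's finfo.max.
def per_token_group_quant_ref(x: torch.Tensor, group_size: int, dtype=torch.float8_e4m3fn):
    assert x.shape[-1] % group_size == 0
    finfo = torch.finfo(dtype)
    g = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    scale = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / finfo.max
    q = (g / scale).clamp(finfo.min, finfo.max).to(dtype)
    return q.reshape_as(x), scale.squeeze(-1)

q8, scales = per_token_group_quant_ref(torch.randn(4, 256), group_size=128)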
@@ -4,8 +4,7 @@ import example_per_token_cast_to_fp8

 def test_example_group_per_split_token_cast_to_fp8():
-    example_group_per_split_token_cast_to_fp8.main(
-        M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])
+    example_group_per_split_token_cast_to_fp8.main(M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])

 def test_example_per_token_cast_to_fp8():
...
@@ -4,7 +4,6 @@ import tilelang.language as T

 # @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
-
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),

@@ -36,8 +35,7 @@ block_K = 32
 func = matmul(M, N, K, block_M, block_N, block_K)

-jit_kernel = tilelang.compile(
-    func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr")
+jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr")
 # or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
 # or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"])
...
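A plausible way to invoke the compiled kernel from the snippet above (out_idx=[2] means the C buffer is allocated and returned); tensor shapes follow the matmul definition, and the exact call convention should be treated as an assumption:

import torch

# Hypothetical usage of jit_kernel from the snippet above; M, N, K come from
# the surrounding script.
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = jit_kernel(a, b)  # C is returned because of out_idx=[2]
torch.testing.assert_close(c, (a @ b).to(c.dtype), rtol=1e-2, atol=1e-2)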