diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1e5bab47bf37b65c3a6994c17cc61ddb9ea773bf..d1bb4ceeb14ea62031c3397781549e12f7cdc8d4 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -39,19 +39,9 @@ repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.14.7 # sync with requirements-lint.txt
     hooks:
+      - id: ruff-format
       - id: ruff-check
         args: [--fix, --exit-non-zero-on-fix]
-  - repo: https://github.com/google/yapf
-    rev: v0.43.0 # sync with requirements-lint.txt
-    hooks:
-      - id: yapf
-        name: yapf-multiproc-bugfix
-        # yapf is not multiprocess safe, so we run a dummy yapf first.
-        args: [--in-place, docs/conf.py]
-        always_run: true
-        pass_filenames: false
-      - id: yapf
-        args: [--recursive, --in-place]
   - repo: https://github.com/codespell-project/codespell
     rev: v2.4.1 # sync with requirements-lint.txt
     hooks:
@@ -62,4 +52,4 @@ repos:
             ^.+\.(cpp|hpp|cxx|cc|c|h|cu|cuh)$|
             ^.+\.svg$|
             ^.*\brequirements\b.*\.txt$
-          )
\ No newline at end of file
+          )
diff --git a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py
index 6401276ac08c07469d8b25fb60b41580221cf308..3dd82aa5e5218cf37379ed69a2ff93ba1020c199 100644
--- a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py
+++ b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py
@@ -7,10 +7,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
-                            False,
-                            dtype=torch.bool,
-                            device=x.device)
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True
@@ -28,15 +25,15 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F

 def benchmark_topk_sparse_attention():
     from benchmark_configs import configs
+
     torch.manual_seed(0)
     # Config
     for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
-
         # Create inputs
-        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
+        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)

         import flash_attn
diff --git a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py
index 7c9edb59514b0fb55b76811a8102ec893c16e2ca..fff65b44f6aaa2c78468d1c11a8ec9ff08832597 100644
--- a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py
+++ b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py
@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
-                            False,
-                            dtype=torch.bool,
-                            device=x.device)
+    dense_mask = 
torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -39,7 +36,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_N = 64 num_stages = 2 threads = 128 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] block_mask_shape = [batch, heads, downsample_len, downsample_len] @@ -48,7 +45,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_mask_dtype = "bool" def kernel_func(block_M, block_N, num_stages, threads): - @T.macro def MMA0( K: T.Tensor(shape, dtype), @@ -60,11 +56,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -79,18 +74,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) by: T.int32, bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -116,22 +111,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(shape, dtype), ): - with T.Kernel( - T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, 
bz): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -146,7 +140,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) logsum = T.alloc_fragment([block_M], accum_dtype) block_mask = T.alloc_local([downsample_len], block_mask_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -155,20 +149,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_mask[vj] = BlockSparseMask[bz, by, bx, vj] loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): if block_mask[k]: MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main @@ -177,26 +170,23 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) def benchmark_topk_sparse_attention(): from benchmark_configs import configs + torch.manual_seed(0) # Config for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: - # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) - program = blocksparse_flashattn( - BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) kernel = tilelang.compile(program, out_idx=4) def benchmark_fn(): diff --git a/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py index e4828ce5f6a44f9ed86ce4d9a9446d047bee8814..85d754ae3a77f679a9d714eb1c2f83573a47c911 100644 --- 
a/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py @@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -31,39 +28,37 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F def benchmark_topk_sparse_attention(): from benchmark_configs import configs + torch.manual_seed(0) # Config for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: - # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) def benchmark_fn(): # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) return ref_output ref_latency = do_bench( diff --git a/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py index 86ac894bc7197669c77f95bf96a70f173166b051..7ebca93a6a3735ac0fae1cdcd706587d3521e5b6 100644 --- a/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py @@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - 
device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -56,7 +53,6 @@ def _fwd_kernel_inner( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) if mask_val == True: @@ -72,8 +68,7 @@ def _fwd_kernel_inner( # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N if LAST_K_BLOCK: - qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, - float('-inf')) + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf")) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] @@ -153,7 +148,7 @@ def _fwd_kernel( v_ptrs = V + off_v mask_ptrs = block_mask_ptr + start_m * stride_bmm - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) @@ -191,24 +186,12 @@ def _fwd_kernel( acc = acc * l_recip acc = acc.to(Out.dtype.element_ty) - off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[ - None, :] * stride_od + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) -def _forward(ctx, - q, - k, - v, - block_sparse_mask, - sm_scale, - BLOCK_M=64, - BLOCK_N=64, - num_warps=None, - num_stages=1, - out=None): - +def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None): assert q.shape[-1] == k.shape[-1] == v.shape[-1] assert k.shape[2] == v.shape[2] o = out if out is not None else torch.empty_like(q).contiguous() @@ -253,7 +236,6 @@ def _forward(ctx, class _sparse_attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, block_sparse_dense, sm_scale): # shape constraints @@ -271,24 +253,22 @@ block_sparse_triton_fn = _sparse_attention.apply def benchmark_topk_sparse_attention(): from benchmark_configs import configs + torch.manual_seed(0) # Config for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: - # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) diff --git a/benchmark/mamba2/benchmark_mamba_chunk_scan.py b/benchmark/mamba2/benchmark_mamba_chunk_scan.py index 
aff810f660ceeb0716bf07300c3b3801ecbaea9e..a3ed72b1dcb1ae2251ddc580a7dbc215cf30aa25 100644 --- a/benchmark/mamba2/benchmark_mamba_chunk_scan.py +++ b/benchmark/mamba2/benchmark_mamba_chunk_scan.py @@ -51,14 +51,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] decay = torch.exp(dt_segment_sum) scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s") - causal_mask = torch.tril( - torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) + causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) scores_decay = scores_decay.masked_fill(~causal_mask, 0) - out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), - rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)) + out = torch.einsum( + "bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks) + ) state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) - out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange( - C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out + out_prev = ( + torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out + ) out = out + out_prev out = rearrange(out, "b c l h p -> b (c l) h p") if D is not None: @@ -74,7 +75,6 @@ def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D): def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): - @helion.kernel() def helion_mamba2_chunk_scan_kernel( cb: torch.Tensor, @@ -118,8 +118,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): dtype = cb.dtype accum_dtype = torch.float32 - assert (x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype == - dtype) + assert x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype == dtype out = torch.empty_like(x) @@ -127,11 +126,10 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile( [nheads, chunk_size, headdim, batch, nchunks], - block_size=[1, block_m, block_n, 1, 1], + block_size=[1, block_m, block_n, 1, 1], ): acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype) - dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, - tile_m].to(torch.float32) + dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_m].to(torch.float32) scale_m_local = torch.exp2(dA_cumsum_local_m * p) C_local = C[ @@ -152,10 +150,8 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): tile_m, tile_k, ] - dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, - tile_k].to(torch.float32) - cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p - - dA_cumsum_local_k[None, :] * p) + dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32) + cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p - dA_cumsum_local_k[None, :] * p) dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32) cb_local = (cb_local * dt_local[None, :]).to(dtype) pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :] @@ -169,11 +165,9 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): acc_o = hl.dot(cb_local, x_local, acc=acc_o) D_local = D[tile_h.begin].to(torch.float32) - x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, 
tile_h.begin, - tile_n].to(torch.float32) + x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n].to(torch.float32) acc_o += x_residual * D_local - out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, - tile_n] = acc_o.to(dtype=dtype) + out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n] = acc_o.to(dtype=dtype) return out @@ -182,12 +176,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): def get_configs(): - iter_params = dict( - block_M=[64, 128, 256], - block_N=[32, 64], - block_K=[64, 128, 256], - block_Dstate=[128], - num_stages=[1, 2, 3, 4, 5]) + iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5]) return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] @@ -198,19 +187,21 @@ def get_configs(): tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }, ) -def chunk_scan_fwd(batch, - seqlen, - chunk_size, - ngroups, - nheads, - headdim, - dstate, - block_M=64, - block_N=64, - block_K=64, - block_Dstate=128, - num_stages=2, - threads=128): +def chunk_scan_fwd( + batch, + seqlen, + chunk_size, + ngroups, + nheads, + headdim, + dstate, + block_M=64, + block_N=64, + block_K=64, + block_Dstate=128, + num_stages=2, + threads=128, +): dtype = "float16" accum_dtype = "float" nchunks = T.ceildiv(seqlen, chunk_size) @@ -218,20 +209,20 @@ def chunk_scan_fwd(batch, @T.prim_func def main( - cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore - x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore - dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore - dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore - C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore - prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore - D: T.Tensor((nheads), dtype), # type: ignore - Output: T.Tensor((batch, seqlen, nheads, headdim), dtype) # type: ignore + cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore + x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore + dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore + prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore + D: T.Tensor((nheads), dtype), # type: ignore + Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore ): - with T.Kernel( - nheads, - T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), - batch * nchunks, - threads=threads) as (bz, bx, by): + with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as ( + bz, + bx, + by, + ): acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) acc_o_shared = T.alloc_shared((block_M, block_N), dtype) cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn") @@ -257,27 +248,32 @@ def chunk_scan_fwd(batch, m_idx = bx // T.ceildiv(headdim, block_N) n_idx = bx % T.ceildiv(headdim, block_N) - T.annotate_layout({ - acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared), - cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), - x_residual_shared: 
tilelang.layout.make_swizzled_layout(x_residual_shared) - }) + T.annotate_layout( + { + acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared), + cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), + x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared), + } + ) T.no_set_max_nreg() - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M], - dA_cs_m_shared) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared) T.copy(dA_cs_m_shared, dA_cs_m_local) T.clear(acc_o) for i in T.Parallel(block_M): scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) T.copy( - C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared) - T.copy( - prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N, - 0:block_Dstate], prev_state_shared) + C[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz // (nheads // ngroups), + 0:block_Dstate, + ], + C_shared, + ) + T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared) T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) for i, j in T.Parallel(block_M, block_N): acc_o[i, j] *= scale_m_local[i] @@ -286,34 +282,47 @@ def chunk_scan_fwd(batch, for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy( - cb[batch_idx, chunk_idx, bz // (nheads // ngroups), - m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K], - cb_shared) + cb[ + batch_idx, + chunk_idx, + bz // (nheads // ngroups), + m_idx * block_M : (m_idx + 1) * block_M, + k * block_K : (k + 1) * block_K, + ], + cb_shared, + ) T.copy(cb_shared, cb_local) - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], - dA_cs_k_shared) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared) T.copy(dA_cs_k_shared, dA_cs_k_local) for i, j in T.Parallel(block_M, block_K): - cb_local[i, - j] = cb_local[i, - j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) - T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared) + cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) T.copy(dt_shared, dt_local) for i, j in T.Parallel(block_M, block_K): cb_local[i, j] *= dt_local[j] for i, j in T.Parallel(block_M, block_K): - cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, - cb_local[i, j], 0) + cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0) T.copy( - x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + - (k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared) + x[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_shared, + ) T.gemm(cb_local, x_shared, acc_o) D_local[0] = D[bz] T.copy( - x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N], - x_residual_shared) + x[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_residual_shared, + ) 
T.copy(x_residual_shared, x_residual_local) for i, j in T.Parallel(block_M, block_N): acc_o[i, j] += x_residual_local[i, j] * D_local[0] @@ -321,24 +330,37 @@ def chunk_scan_fwd(batch, T.copy(acc_o, acc_o_shared) T.copy( acc_o_shared, - Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N]) + Output[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + ) return main if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=80, help='heads') - parser.add_argument('--groups', type=int, default=1, help='groups') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--chunk_size', type=int, default=256, help='chunk size') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--dstate', type=int, default=128, help='dstate') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=80, help="heads") + parser.add_argument("--groups", type=int, default=1, help="groups") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--chunk_size", type=int, default=256, help="chunk size") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--dstate", type=int, default=128, help="dstate") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate + batch, heads, groups, seq_len, chunk_size, dim, dstate = ( + args.batch, + args.heads, + args.groups, + args.seq_len, + args.chunk_size, + args.dim, + args.dstate, + ) nchunks = math.ceil(seq_len / chunk_size) total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate @@ -360,8 +382,7 @@ if __name__ == "__main__": D = torch.randn(heads).half().cuda() print("Benchmarking Triton...") - triton_latency = do_bench( - lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10) + triton_latency = do_bench(lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10) print(f"Triton TFlops: {total_flops / triton_latency * 1e-9}") print("Benchmarking Helion...") diff --git a/benchmark/matmul/benchmark_matmul.py b/benchmark/matmul/benchmark_matmul.py index c64f4fabf8e1bfb3c5c1566cd66359710c70d552..6ca1402d7a691abe6bb5ea2dcac0b2874ac2105c 100644 --- a/benchmark/matmul/benchmark_matmul.py +++ b/benchmark/matmul/benchmark_matmul.py @@ -6,6 +6,7 @@ import tilelang import tilelang.language as T from tilelang.autotuner import autotune from tilelang import jit + # Configure logger logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -101,9 +102,7 @@ def get_configs(args, kwargs): policy=[T.GemmWarpPolicy.Square], enable_rasteration=[True, False], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in 
itertools.product(*iter_params.values())] return configs @@ -112,7 +111,9 @@ def get_configs(args, kwargs): warmup=3, rep=20, ) -@jit(out_idx=[2],) +@jit( + out_idx=[2], +) def matmul( M, N, @@ -159,9 +160,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): """ The compiled TVM function for block-level matrix multiplication. @@ -176,7 +177,6 @@ def matmul( # Bind x-dimension to block index in N, # y-dimension to block index in M. with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - # Allocate shared memory for A sub-block of shape (block_M, block_K) A_shared = T.alloc_shared((block_M, block_K), dtype) # Allocate shared memory for B sub-block of shape (block_N, block_K) diff --git a/benchmark/matmul/benchmark_matmul_intrinsic.py b/benchmark/matmul/benchmark_matmul_intrinsic.py index 94e36b385baf111709be26bdc4b379af10f5fceb..010ce87f7cc6c56d452dd55609afb7d159572ed5 100644 --- a/benchmark/matmul/benchmark_matmul_intrinsic.py +++ b/benchmark/matmul/benchmark_matmul_intrinsic.py @@ -6,7 +6,8 @@ import tilelang as tl import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func from tilelang.autotuner import autotune import itertools @@ -103,12 +104,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -116,10 +116,12 @@ def tl_matmul( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10, enable=enable_rasteration) @@ -127,7 +129,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -137,7 +138,6 @@ def tl_matmul( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a(A_local, A_shared, ki) @@ -223,7 +223,6 @@ def get_configs(args, kwargs): for config in configs: print(config) else: - iter_params = dict( block_row_warps=[1, 2, 4], block_col_warps=[1, 2, 4], @@ -233,9 +232,7 @@ def get_configs(args, kwargs): stage=[0, 2], enable_rasteration=[True, False], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] return 
configs @@ -247,7 +244,9 @@ def get_configs(args, kwargs): ref_prog=ref_program, skip_check=True, ) -@tl.jit(out_idx=[2],) +@tl.jit( + out_idx=[2], +) def matmul( M, N, @@ -291,13 +290,8 @@ if __name__ == "__main__": parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument( - "--with_roller", - type=bool, - default=False, - help="Whether to use roller to deduce search spaces") - parser.add_argument( - "--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type") + parser.add_argument("--with_roller", type=bool, default=False, help="Whether to use roller to deduce search spaces") + parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type") args = parser.parse_args() M, N, K = args.m, args.n, args.k diff --git a/benchmark/matmul/benchmark_matmul_sp.py b/benchmark/matmul/benchmark_matmul_sp.py index 0ff3cd0b69926d7b6d6dbd7806762f2df22555a8..22b5d13cfcc97c8d44f24a1f82f16a7201705253 100644 --- a/benchmark/matmul/benchmark_matmul_sp.py +++ b/benchmark/matmul/benchmark_matmul_sp.py @@ -70,7 +70,8 @@ def get_configs(M, N, K): thread_num, policy, enable_rasterization, - )) + ) + ) configs = [ { @@ -81,7 +82,8 @@ def get_configs(M, N, K): "thread_num": c[4], "policy": c[5], "enable_rasterization": c[6], # keep param name for backward-compat - } for c in _configs + } + for c in _configs ] return configs @@ -126,7 +128,9 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype): warmup=3, rep=20, ) - @jit(out_idx=[2],) + @jit( + out_idx=[2], + ) def kernel( block_M=None, block_N=None, @@ -165,10 +169,10 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype): @T.prim_func def main( - A_sparse: T.Tensor((M, K // 2), in_dtype), - E: T.Tensor((M, K // e_factor), e_dtype), - B: T.Tensor((K, N), in_dtype), - C: T.Tensor((M, N), accum_dtype), + A_sparse: T.Tensor((M, K // 2), in_dtype), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), in_dtype), + C: T.Tensor((M, N), accum_dtype), ): """ The compiled TVM function for block-level matrix multiplication. @@ -182,9 +186,7 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype): """ # Bind x-dimension to block index in N, # y-dimension to block index in M. 
- with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): # Allocate shared memory for A sub-block of shape (block_M, block_K) A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype) # Allocate shared memory for B sub-block of shape (block_N, block_K) @@ -201,12 +203,12 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype): T.disable_warp_group_reg_alloc() T.use_swizzle(panel_size=10, enable=enable_rasterization) - T.annotate_layout({ - E: - make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K), - E_shared: - make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K), + } + ) # Loop over sub-blocks in K dimension, pipelined by num_stages for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): # Load a sub-block of A from global memory into A_shared @@ -241,18 +243,13 @@ if __name__ == "__main__": parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") parser.add_argument("--disable_cache", action="store_true") - parser.add_argument( - "--accum_dtype", - type=str, - default="float", - choices=["float", "float16"], - help="Accumulation datatype") + parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype") parser.add_argument( "--bench_torch_sparse", type=str, - choices=['cutlass', 'cusparselt'], + choices=["cutlass", "cusparselt"], default=None, - help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported" + help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported", ) args = parser.parse_args() @@ -274,7 +271,8 @@ if __name__ == "__main__": if args.bench_torch_sparse is not None: from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor - if args.bench_torch_sparse == 'cutlass': + + if args.bench_torch_sparse == "cutlass": SparseSemiStructuredTensor._FORCE_CUTLASS = True A_sp = to_sparse_semi_structured(A, transposed=False) torch_sparse_latency = do_bench(lambda: A_sp @ B) @@ -285,8 +283,6 @@ if __name__ == "__main__": print(f"Best config: {best_config}") if args.bench_torch_sparse is not None: - print( - f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}" - ) + print(f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}") print(f"Reference Dense TFlops: {total_flops / ref_latency * 1e-9:.3f}") diff --git a/benchmark/matmul_fp8/benchmark_matmul.py b/benchmark/matmul_fp8/benchmark_matmul.py index 796f7b90badaaa8f286f566a0875085fc4225823..930e8a6d1fe68c2699496c5414e4d67836bcbec6 100644 --- a/benchmark/matmul_fp8/benchmark_matmul.py +++ b/benchmark/matmul_fp8/benchmark_matmul.py @@ -104,9 +104,7 @@ def get_configs(args, kwargs): policy=[T.GemmWarpPolicy.Square], enable_rasteration=[True, False], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] return configs @@ 
-116,7 +114,9 @@ def get_configs(args, kwargs): warmup=3, rep=20, ) -@jit(out_idx=[2],) +@jit( + out_idx=[2], +) def matmul( M, N, @@ -164,9 +164,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): """ The compiled TVM function for block-level matrix multiplication. @@ -181,7 +181,6 @@ def matmul( # Bind x-dimension to block index in N, # y-dimension to block index in M. with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - # Allocate shared memory for A sub-block of shape (block_M, block_K) A_shared = T.alloc_shared((block_M, block_K), dtype) # Allocate shared memory for B sub-block of shape (block_N, block_K) diff --git a/docs/conf.py b/docs/conf.py index 9d52415779bf7816e70015c8ba205c0bbadf3ba3..877b5582e1e28ee75704d5c75a8ff900a61c4cd3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -20,33 +20,27 @@ extensions = [ "autoapi.extension", ] -autoapi_type = 'python' -autoapi_dirs = ['../tilelang'] +autoapi_type = "python" +autoapi_dirs = ["../tilelang"] autoapi_options = [ - 'members', - 'undoc-members', - 'show-inheritance', - 'show-module-summary', - 'special-members', + "members", + "undoc-members", + "show-inheritance", + "show-module-summary", + "special-members", ] autoapi_keep_files = False # Useful for debugging the generated rst files autoapi_generate_api_docs = True -autodoc_typehints = 'description' +autodoc_typehints = "description" autoapi_ignore = ["*language/ast*", "*version*", "*libinfo*", "*parser*"] -source_suffix = { - '.rst': 'restructuredtext', - '.md': 'markdown', -} +source_suffix = {".rst": "restructuredtext", ".md": "markdown"} -myst_enable_extensions = [ - "colon_fence", - "deflist", -] +myst_enable_extensions = ["colon_fence", "deflist"] redirects = {"get_started/try_out": "../index.html#getting-started"} @@ -66,10 +60,7 @@ html_css_files = ["custom.css"] footer_copyright = "© 2025-2026 TileLang" footer_note = " " -html_theme_options = { - "light_logo": "img/logo-v2.png", - "dark_logo": "img/logo-v2.png", -} +html_theme_options = {"light_logo": "img/logo-v2.png", "dark_logo": "img/logo-v2.png"} header_links = [ ("Home", "https://github.com/tile-ai/tilelang"), diff --git a/examples/amd/example_amd_flash_attn_bwd.py b/examples/amd/example_amd_flash_attn_bwd.py index d5c52f9ca5a66096245903affefe958806b16528..a546110311da92300dbea9a352c935ca8fe90bb0 100644 --- a/examples/amd/example_amd_flash_attn_bwd.py +++ b/examples/amd/example_amd_flash_attn_bwd.py @@ -11,22 +11,20 @@ import time def ref_program(Q, K, V, is_causal, groups=1): - assert Q.size( - 2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" - assert Q.size( - 2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" dim = Q.size(-1) K_ref = K.repeat_interleave(groups, dim=2) V_ref = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K_ref) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K_ref) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = 
scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V_ref) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V_ref) lse = torch.logsumexp(scores, dim=-1).float() return output, lse @@ -45,23 +43,23 @@ def get_fwd_configs(): valid_configs = [] - for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q, - threads, num_stages, - enable_rasterization, k_pack, - panel_size, qk_coalesced_width, - v_coalesced_width): - valid_configs.append({ - "block_M": m, - "block_N": n, - "num_split_q": s, - "threads": t, - "num_stages": stages, - "enable_rasterization": r, - "k_pack": k, - "panel_size": p, - "qk_coalesced_width": qkw, - "v_coalesced_width": vw, - }) + for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product( + block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width + ): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "num_split_q": s, + "threads": t, + "num_stages": stages, + "enable_rasterization": r, + "k_pack": k, + "panel_size": p, + "qk_coalesced_width": qkw, + "v_coalesced_width": vw, + } + ) return valid_configs @@ -85,7 +83,7 @@ def fast_flashattn( qk_coalesced_width: int, v_coalesced_width: int, ): - scale = (1.0 / dim)**0.5 + scale = (1.0 / dim) ** 0.5 head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] @@ -97,11 +95,11 @@ def fast_flashattn( @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - LSE: T.Tensor([batch, heads, seq_len], accum_dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + LSE: T.Tensor([batch, heads, seq_len], accum_dtype), ): with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined): T.use_swizzle(panel_size, enable=enable_rasterization) @@ -135,33 +133,21 @@ def fast_flashattn( m_prev = T.alloc_fragment([block_M], accum_dtype) scale_factor = T.alloc_fragment([block_M], accum_dtype) - T.copy( - Q[bz, q_block_offset:q_block_offset + block_M, by, :], - Q_shared, - coalesced_width=vec_size) + T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size) - loop_end_k = ( - T.ceildiv(q_block_offset + - block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) row_sum = T.alloc_fragment([block_M], accum_dtype) for k in T.Pipelined(loop_end_k, num_stages=num_stages): kv_idx = k * block_N - T.copy( - K[bz, kv_idx:kv_idx + block_N, by // groups, :], - K_shared, - coalesced_width=vec_size) - T.copy( - V[bz, kv_idx:kv_idx + block_N, by // groups, :], - V_shared, - coalesced_width=v_vec_size) + T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size) + T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm( @@ -216,8 
+202,7 @@ def fast_flashattn( for i in T.Parallel(block_M): if q_block_offset + i < seq_len: - lse_val = T.if_then_else(l_i[i] > 0, - T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype)) + lse_val = T.if_then_else(l_i[i] > 0, T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype)) LSE[bz, by, q_block_offset + i] = lse_val bx_loop_var = current_bx + num_split_q @@ -234,16 +219,17 @@ def get_bwd_configs(): panel_size = [7, 8, 9, 10] configs = [] - for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads, - enable_rasterization, panel_size): - configs.append({ - "block_M": m, - "block_N": n, - "num_stages": stages, - "threads": t, - "enable_rasterization": r, - "panel_size": p, - }) + for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads, enable_rasterization, panel_size): + configs.append( + { + "block_M": m, + "block_N": n, + "num_stages": stages, + "threads": t, + "enable_rasterization": r, + "panel_size": p, + } + ) return configs @@ -256,8 +242,7 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): blk = 32 @T.prim_func - def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), - Delta: T.Tensor([batch, heads, seq_len], accum_dtype)): + def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), Delta: T.Tensor([batch, heads, seq_len], accum_dtype)): with T.Kernel(batch, heads, T.ceildiv(seq_len, blk)) as (bz, bx, by): o = T.alloc_fragment([blk, blk], dtype) do = T.alloc_fragment([blk, blk], dtype) @@ -265,21 +250,33 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep @tilelang.autotune(configs=get_bwd_configs(), cache_input_tensors=True) @tilelang.jit -def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, block_N: int, - num_stages: int, threads: int, enable_rasterization: bool, panel_size: int): - sm_scale = (1.0 / dim)**0.5 +def flashattn_bwd( + batch, + heads, + seq_len, + dim, + is_causal, + groups, + block_M: int, + block_N: int, + num_stages: int, + threads: int, + enable_rasterization: bool, + panel_size: int, +): + sm_scale = (1.0 / dim) ** 0.5 head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] @@ -287,14 +284,17 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b accum_dtype = "float" @T.prim_func - def flash_bwd_kernel(Q: T.Tensor(q_shape, - dtype), K: T.Tensor(kv_shape, - dtype), V: T.Tensor(kv_shape, dtype), - dO: T.Tensor(q_shape, dtype), lse: T.Tensor([batch, heads, seq_len], - accum_dtype), - Delta: T.Tensor([batch, heads, seq_len], - accum_dtype), dQ: T.Tensor(q_shape, accum_dtype), - dK: T.Tensor(kv_shape, accum_dtype), dV: T.Tensor(kv_shape, accum_dtype)): + def flash_bwd_kernel( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + dO: T.Tensor(q_shape, dtype), + lse: T.Tensor([batch, heads, seq_len], 
accum_dtype), + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), + dQ: T.Tensor(q_shape, accum_dtype), + dK: T.Tensor(kv_shape, accum_dtype), + dV: T.Tensor(kv_shape, accum_dtype), + ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): T.use_swizzle(panel_size, enable=enable_rasterization) @@ -315,8 +315,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b dk = T.alloc_fragment([block_M, dim], accum_dtype) dq = T.alloc_fragment([block_N, dim], accum_dtype) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) @@ -324,22 +324,21 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q_shared) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q_shared) T.clear(qkT) T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): P_acc[i, j] = T.exp(qkT[i, j] * sm_scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, - P_acc[i, j], 0.0) + P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, P_acc[i, j], 0.0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do_shared) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do_shared) T.clear(dP) T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -347,7 +346,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b T.copy(P_acc, p_cast) T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta_shared) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta_shared) for i, j in T.Parallel(block_M, block_N): p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale @@ -378,8 +377,8 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim): def flash_bwd_post(dQ_in: T.Tensor(shape, accum_dtype), dQ_out: T.Tensor(shape, dtype)): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.copy( - dQ_in[bz, bx * blk:(bx + 1) * blk, by, :], - dQ_out[bz, bx * blk:(bx + 1) * blk, by, :], + dQ_in[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], ) return flash_bwd_post @@ -446,22 +445,14 @@ def benchmark_function(func, *args, warmup=10, repeat=100): return np.median(times) -def main(batch: int = 1, - heads: int = 8, - seq_len: int = 4096, - dim: int = 128, - is_causal: bool = False, - groups: int = 1): - +def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1): device = "cuda" dtype = torch.float16 torch.manual_seed(42) torch.cuda.manual_seed(42) - print( - f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}" - ) + print(f"Test configuration: batch={batch}, heads={heads}, 
seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}") flops_per_gemm = 2.0 * batch * heads * seq_len * seq_len * dim total_flops = 5 * flops_per_gemm @@ -517,22 +508,19 @@ def main(batch: int = 1, o_ref.backward(dO) print("Verifying backward pass correctness...") - dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison( - dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05) + dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05) if dq_close: print("dQ is correct.") else: print("dQ mismatch detected.") - dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison( - dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05) + dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05) if dk_close: print("dK is correct.") else: print("dK mismatch detected.") - dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison( - dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05) + dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05) if dv_close: print("dV is correct.") else: @@ -553,9 +541,7 @@ def main(batch: int = 1, torch.cuda.synchronize() ref_latency = benchmark_function(run_reference_fwd_bwd, warmup=10, repeat=100) - print( - f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops" - ) + print(f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops") def run_complete_fwd_bwd(): o_tl_bench, lse_tl_bench = fwd_kernel(q, k, v) @@ -593,12 +579,12 @@ def main(batch: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=8, help='heads') - parser.add_argument('--seq_len', type=int, default=1024, help='sequence length') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--groups', type=int, default=1, help='groups') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=8, help="heads") + parser.add_argument("--seq_len", type=int, default=1024, help="sequence length") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--groups", type=int, default=1, help="groups") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups) diff --git a/examples/amd/example_amd_flash_attn_fwd.py b/examples/amd/example_amd_flash_attn_fwd.py index 9ffa7cbb0b40501aa83ab44c480a1782fb737b3e..e53299a2bba39e7ab11f6b6d52be2f0ef031f2d8 100644 --- a/examples/amd/example_amd_flash_attn_fwd.py +++ b/examples/amd/example_amd_flash_attn_fwd.py @@ -13,10 +13,10 @@ def supply_tensors_gpu(params): """Supply function that creates tensors on GPU for ROCm/HIP.""" tensors = [] for param in params: - if hasattr(param, 'shape') and hasattr(param, 'dtype'): + if hasattr(param, "shape") and hasattr(param, "dtype"): # Force creation on GPU device shape = [int(s) for s in param.shape] - tensor = torch.randn(shape, dtype=param.dtype, device='cuda') + tensor = torch.randn(shape, dtype=param.dtype, device="cuda") tensors.append(tensor) else: 
tensors.append(param) @@ -24,22 +24,20 @@ def supply_tensors_gpu(params): def ref_program(Q, K, V, is_causal, groups=1): - assert Q.size( - 2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" - assert Q.size( - 2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" dim = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -58,23 +56,23 @@ def get_configs(): valid_configs = [] - for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q, - threads, num_stages, - enable_rasterization, k_pack, - panel_size, qk_coalesced_width, - v_coalesced_width): - valid_configs.append({ - "block_M": m, - "block_N": n, - "num_split_q": s, - "threads": t, - "num_stages": stages, - "enable_rasterization": r, - "k_pack": k, - "panel_size": p, - "qk_coalesced_width": qkw, - "v_coalesced_width": vw, - }) + for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product( + block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width + ): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "num_split_q": s, + "threads": t, + "num_stages": stages, + "enable_rasterization": r, + "k_pack": k, + "panel_size": p, + "qk_coalesced_width": qkw, + "v_coalesced_width": vw, + } + ) return valid_configs @@ -98,7 +96,7 @@ def fast_flashattn( qk_coalesced_width: int, v_coalesced_width: int, ): - scale = (1.0 / dim)**0.5 + scale = (1.0 / dim) ** 0.5 head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] @@ -110,10 +108,10 @@ def fast_flashattn( @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined): T.use_swizzle(panel_size, enable=enable_rasterization) @@ -147,32 +145,21 @@ def fast_flashattn( m_prev = T.alloc_fragment([block_M], accum_dtype) scale_factor = T.alloc_fragment([block_M], accum_dtype) - T.copy( - Q[bz, q_block_offset:q_block_offset + block_M, by, :], - Q_shared, - coalesced_width=vec_size) + T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size) - loop_end_k = T.ceildiv(q_block_offset + block_M, - block_N) if is_causal else T.ceildiv(seq_len, block_N) + loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) row_sum = T.alloc_fragment([block_M], 
accum_dtype) for k in T.Pipelined(loop_end_k, num_stages=num_stages): kv_idx = k * block_N - T.copy( - K[bz, kv_idx:kv_idx + block_N, by // groups, :], - K_shared, - coalesced_width=vec_size) - T.copy( - V[bz, kv_idx:kv_idx + block_N, by // groups, :], - V_shared, - coalesced_width=v_vec_size) + T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size) + T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm( @@ -222,13 +209,7 @@ def fast_flashattn( return main -def main(batch: int = 1, - heads: int = 8, - seq_len: int = 4096, - dim: int = 128, - is_causal: bool = False, - groups: int = 1): - +def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1): flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim total_flops = 2 * flops_per_matmul if is_causal: @@ -250,18 +231,16 @@ def main(batch: int = 1, print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops") latency = profiler.do_bench(warmup=100) - print( - f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops" - ) + print(f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops") if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=8, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--groups', type=int, default=1, help='groups') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=8, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--groups", type=int, default=1, help="groups") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups) diff --git a/examples/analyze/example_conv_analyze.py b/examples/analyze/example_conv_analyze.py index 540fcf4b742ccd2a9fccd88051135d364d6ab10e..b90be1435a1058cc6b09d7947df877e9bd2e33b5 100644 --- a/examples/analyze/example_conv_analyze.py +++ b/examples/analyze/example_conv_analyze.py @@ -25,22 +25,7 @@ def check_hopper(): return False -def kernel(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages, - threads, - dtype="float16", - accum_dtype="float"): +def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 @@ -50,13 +35,11 @@ def kernel(N, @T.prim_func def conv( - data: T.Tensor((N, H, W, C), dtype), - kernel: T.Tensor((KH, KW, C, F), dtype), - out: T.Tensor((N, OH, OW, F), dtype), + data: T.Tensor((N, H, 
W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), ): - with T.Kernel( - T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by): data_shared = T.alloc_shared((block_M, block_K), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -65,11 +48,13 @@ def kernel(N, kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) - T.annotate_layout({ - out_shared: make_swizzled_layout(out_shared), - data_shared: make_swizzled_layout(data_shared), - kernel_shared: make_swizzled_layout(kernel_shared), - }) + T.annotate_layout( + { + out_shared: make_swizzled_layout(out_shared), + data_shared: make_swizzled_layout(data_shared), + kernel_shared: make_swizzled_layout(kernel_shared), + } + ) T.clear(out_local) for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): @@ -81,10 +66,8 @@ def kernel(N, m = by * block_M + i access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P access_w = m % OW * S + k // C % KW * D - P - in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and - (access_w < W)) - data_shared[i, j] = T.if_then_else( - in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.gemm(data_shared, kernel_shared, out_local) diff --git a/examples/analyze/example_gemm_analyze.py b/examples/analyze/example_gemm_analyze.py index bfd934f6aa692dd56ef578b9affa27faf0e809fe..e28440e1b953b4f4a3097aa36db5a0825ea940b5 100644 --- a/examples/analyze/example_gemm_analyze.py +++ b/examples/analyze/example_gemm_analyze.py @@ -20,9 +20,9 @@ def kernel( @T.prim_func def matmul( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) diff --git a/examples/attention_sink/benchmark_gqa_sink_fwd.py b/examples/attention_sink/benchmark_gqa_sink_fwd.py index 1b7de6b6f2fdff4d3877abf65158a013d3b2c6f1..3538adc38edcd97bff9e6bd4b3e08e149a21efb3 100644 --- a/examples/attention_sink/benchmark_gqa_sink_fwd.py +++ b/examples/attention_sink/benchmark_gqa_sink_fwd.py @@ -51,8 +51,7 @@ def triton_kernel( q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM]) if BANDWIDTH: - lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M + lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M else: lo, hi = 0, start_q + (start_m + 1) * BLOCK_M @@ -120,7 +119,8 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T BANDWIDTH=window_size, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - start_q=seq_kv - seq_q) + start_q=seq_kv - seq_q, + ) return o @@ -137,12 +137,11 @@ def main( ): torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] if window_size is not None: - print('Using sliding window attention.') + print("Using 
sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -170,15 +169,14 @@ def main( block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) if torch.allclose( - triton_program(Q, K, V, sinks, window_size), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2): + triton_program(Q, K, V, sinks, window_size), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ): print("Checks for triton passed.✅") else: print("Checks for triton failed.❌") @@ -198,20 +196,14 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--groups', type=int, default=8, help='groups') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, - args.dtype, args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/benchmark_mha_sink_fwd.py b/examples/attention_sink/benchmark_mha_sink_fwd.py index f50b945356b2e35ef3d7c889a1e8bd2555907bb1..76997d84b50c8f87faf8afb2fe8e0ab1b737e306 100644 --- a/examples/attention_sink/benchmark_mha_sink_fwd.py +++ b/examples/attention_sink/benchmark_mha_sink_fwd.py @@ -50,8 +50,7 @@ def triton_kernel( q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM]) if BANDWIDTH: - lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M + lo, hi = tl.maximum(0, start_q + start_m * 
BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M else: lo, hi = 0, start_q + (start_m + 1) * BLOCK_M @@ -117,26 +116,28 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T BANDWIDTH=window_size, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - start_q=seq_kv - seq_q) + start_q=seq_kv - seq_q, + ) return o -def main(batch: int = 1, - heads: int = 32, - seq_q: int = 256, - seq_kv: int = 256, - dim: int = 128, - window_size: Optional[int] = None, - dtype: str = "float16", - tune: bool = False): +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -163,15 +164,14 @@ def main(batch: int = 1, block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) torch.testing.assert_close( - kernel(Q, K, V, sinks), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2) + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) print("All checks passed.✅") latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500) @@ -184,19 +184,13 @@ def main(batch: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, - 
args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py index b442505fcb5e0a4b5526e22717ca1d0e707a4f1b..5af787a126fa10fb24e091824522e252d3825080 100644 --- a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py @@ -20,28 +20,30 @@ def get_bwd_configs(): @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd( - batch, - heads, - seq_len, - dim, - groups=1, - window_size=None, # None for full attention - sm_scale=None, - block_M=64, - block_N=64, - num_stages=1, - threads=128, - dtype: str = "float16"): - + batch, + heads, + seq_len, + dim, + groups=1, + window_size=None, # None for full attention + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: str = "float16", +): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) head_kv = heads // groups @@ -51,12 +53,12 @@ def flashattn_fwd( @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(kv_shape, dtype), # type: ignore - V: T.Tensor(kv_shape, dtype), # type: ignore - Output: T.Tensor(q_shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Sinks: T.Tensor([heads], dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(kv_shape, dtype), # type: ignore + V: T.Tensor(kv_shape, dtype), # type: ignore + Output: T.Tensor(q_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -73,7 +75,7 @@ def flashattn_fwd( sinks = T.alloc_fragment([heads], dtype) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -81,22 +83,20 @@ def flashattn_fwd( sinks[i] = Sinks[by] end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) - start = T.max(0, - (bx * block_M - window_size) // block_N) if window_size is not None else 0 + start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0 for k in T.Pipelined(start, end, num_stages=num_stages): - T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared) for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i k_idx = k * block_N + j if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, - 0, -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) else: acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, 
policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_M): @@ -106,8 +106,7 @@ def flashattn_fwd( # NOTE(wt): check_inf is necessary for sliding window attention. for i in T.Parallel(block_M): if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): @@ -124,22 +123,23 @@ def flashattn_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"): accum_dtype = "float" shape = [batch, heads, seq_len, dim] @@ -147,9 +147,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16") @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -158,26 +158,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16") delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do) + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ 
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"): accum_dtype = "float" shape = [batch, heads, seq_len, dim] @@ -185,32 +186,27 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16" @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, by, bx * blk:(bx + 1) * blk, :], - dQ_out[bz, by, bx * blk:(bx + 1) * blk, :], + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd(batch, - heads, - seq_len, - dim, - groups, - window_size=None, - sm_scale=None, - dtype="float16"): # None for full attention +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale=None, dtype="float16"): # None for full attention if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) head_kv = heads // groups @@ -225,15 +221,15 @@ def flashattn_bwd(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(kv_shape, dtype), # type: ignore - V: T.Tensor(kv_shape, dtype), # type: ignore - dO: T.Tensor(q_shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(kv_shape, accum_dtype), # type: ignore - dV: T.Tensor(kv_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(kv_shape, dtype), # type: ignore + V: T.Tensor(kv_shape, dtype), # type: ignore + dO: T.Tensor(q_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(kv_shape, accum_dtype), # type: ignore + dV: T.Tensor(kv_shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -253,44 +249,47 @@ def flashattn_bwd(batch, dv_shared = T.alloc_shared([block_M, dim], accum_dtype) dk_shared = T.alloc_shared([block_M, dim], accum_dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - T.copy(K[bz, bx // groups, by * block_M:(by + 1) * block_M, :], K_shared) - T.copy(V[bz, bx // groups, by * block_M:(by + 1) * block_M, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + T.copy(K[bz, bx // groups, by * block_M : (by + 1) * block_M, :], K_shared) 
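+            # Descriptive note on the GQA indexing above: bx enumerates the `heads` query heads,
+            # and each group of `groups` query heads shares one KV head, so the K/V tiles are read
+            # at head index bx // groups. The dK/dV accumulated for this block_M-row tile are
+            # written back further below with T.atomic_add, because several query heads in the
+            # same group update the same KV head.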
+ T.copy(V[bz, bx // groups, by * block_M : (by + 1) * block_M, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) - loop_ed = T.min( - T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv( - seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N) + loop_ed = ( + T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N)) + if window_size is not None + else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) for i, j in T.Parallel(block_M, block_N): if window_size is not None: qkT[i, j] = T.if_then_else( - by * block_M + i <= k * block_N + j and - by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0) + by * block_M + i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0 + ) else: - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], dst=do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -299,12 +298,12 @@ def flashattn_bwd(batch, T.copy(dsT_cast, dsT_shared) T.clear(dq) T.gemm(dsT_shared, K_shared, dq, transpose_A=True) - T.atomic_add(dQ[bz, bx, k * block_N:(k + 1) * block_N, :], dq) + T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq) T.copy(dv, dv_shared) - T.atomic_add(dV[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dv_shared) + T.atomic_add(dV[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dv_shared) T.copy(dk, dk_shared) - T.atomic_add(dK[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dk_shared) + T.atomic_add(dK[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dk_shared) return flash_bwd @@ -316,10 +315,10 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: str = "float16" @T.prim_func def flash_bwd_dsink( - Sinks: T.Tensor([heads], dtype), # type: ignore - Delta: T.Tensor(shape, accum_dtype), # type: ignore - lse: T.Tensor(shape, accum_dtype), # type: ignore - dsinks: T.Tensor(shape, dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore + Delta: T.Tensor(shape, accum_dtype), # type: ignore + lse: T.Tensor(shape, accum_dtype), # type: ignore + dsinks: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=256) as (bx, by, bz): sink = T.alloc_local([1], dtype) @@ -328,21 +327,18 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: str = "float16" dsink_fragment = T.alloc_fragment([block], dtype) sink[0] = Sinks[bx] - T.copy(lse[bz, bx, by * block:(by + 1) * 
block], lse_fragment) - T.copy(Delta[bz, bx, by * block:(by + 1) * block], delta_fragment) + T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment) + T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment) for i in T.Parallel(block): - dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - - lse_fragment[i]) * delta_fragment[i] - T.copy(dsink_fragment, dsinks[bz, bx, by * block:(by + 1) * block]) + dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) * delta_fragment[i] + T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block]) return flash_bwd_dsink class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, sinks, window_size, groups): - def maybe_contiguous(x): if x.stride(-1) != 1: return x.contiguous() @@ -388,13 +384,14 @@ attention = _attention.apply # Adapted and optimized from # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() batch_size, num_keys, num_key_value_heads, head_dim = key.shape @@ -430,32 +427,31 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def main(BATCH: int = 1, - H: int = 8, - N_CTX: int = 512, - D_HEAD: int = 64, - groups: int = 2, - window_size: Optional[int] = None, - dtype: str = "float16"): +def main( + BATCH: int = 1, + H: int = 8, + N_CTX: int = 512, + D_HEAD: int = 64, + groups: int = 2, + window_size: Optional[int] = None, + dtype: str = "float16", +): torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= N_CTX - flops_per_matmul = 2.0 * BATCH * H * min( - window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation + flops_per_matmul = 2.0 * BATCH * H * min(window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5 total_flops = 5 * flops_per_matmul - Q = (torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()) - K = torch.randn( - BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() + K = torch.randn(BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() V = torch.randn_like(K).requires_grad_() sinks = torch.randn(H, dtype=torch_dtype, device="cuda").requires_grad_() dO = torch.randn_like(Q) @@ -479,16 +475,11 @@ def main(BATCH: int = 1, "float16": (1e-2, 1e-2), "bfloat16": (2e-2, 2e-2), }[dtype] - assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), 
f'O max err: {(O-O_ref).abs().max()}' - assert torch.allclose( - dV, dV_ref, rtol=rtol, atol=atol), f'dV max err: {(dV-dV_ref).abs().max()}' - assert torch.allclose( - dK, dK_ref, rtol=rtol, atol=atol), f'dK max err: {(dK-dK_ref).abs().max()}' - assert torch.allclose( - dQ, dQ_ref, rtol=rtol, atol=atol), f'dq max err: {(dQ-dQ_ref).abs().max()}' - assert torch.allclose( - dsinks, dsinks_ref, rtol=rtol, - atol=atol), f'dsinks max err: {(dsinks-dsinks_ref).abs().max()}' + assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}" + assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}" + assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}" + assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dq max err: {(dQ - dQ_ref).abs().max()}" + assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}" print("All checks passed for tilelang kernels.✅") @@ -509,17 +500,12 @@ def main(BATCH: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='Batch size') - parser.add_argument('--h', type=int, default=64, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=4096, help='Context size') - parser.add_argument('--d_head', type=int, default=128, help='Head dimension') - parser.add_argument('--groups', type=int, default=8, help='Groups') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--batch", type=int, default=1, help="Batch size") + parser.add_argument("--h", type=int, default=64, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=4096, help="Context size") + parser.add_argument("--d_head", type=int, default=128, help="Head dimension") + parser.add_argument("--groups", type=int, default=8, help="Groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.groups, args.window_size, args.dtype) diff --git a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py index 8d1817267fec3c135b9e077aeb475f2792e9e411..feb5844f703746de8f00f09e0b89a3ff5df11e62 100644 --- a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py +++ b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py @@ -23,9 +23,11 @@ def get_configs(): rep=100, ) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn( batch, heads, @@ -41,12 +43,11 @@ def flashattn( threads=256, dtype: str = "float16", ): - if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) head_kv = heads // groups @@ -68,13 +69,12 @@ def flashattn( by: T.int32, bz: T.int32, ): - T.copy(K[bz, by // groups, k * 
block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared) for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i + past_len k_idx = k * block_N + j if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) else: acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -89,18 +89,18 @@ def flashattn( by: T.int32, bz: T.int32, ): - T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -112,8 +112,7 @@ def flashattn( # NOTE(wt): check_inf is necessary for sliding window attention. 
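                # Why check_inf is needed: with a sliding window, a query row can have every score
                # in the current K tile masked to -inf; its running scores_max then stays -inf and
                # exp2(scores_max_prev * scale - scores_max * scale) would produce inf/NaN.
                # Resetting such a row's scores_max to 0 keeps scores_scale finite, and the row
                # still contributes nothing because all of its exponentiated scores are exp2(-inf) = 0.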
for i in T.Parallel(block_M): if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, block_N): @@ -128,19 +127,19 @@ def flashattn( @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - Sinks: T.Tensor([heads], dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -157,58 +156,58 @@ def flashattn( logsum = T.alloc_fragment([block_M], accum_dtype) sinks = T.alloc_fragment([block_M], dtype) - T.annotate_layout({ - Q_shared: make_swizzled_layout(Q_shared), - K_shared: make_swizzled_layout(K_shared), - V_shared: make_swizzled_layout(V_shared), - O_shared: make_swizzled_layout(O_shared), - }) + T.annotate_layout( + { + Q_shared: make_swizzled_layout(Q_shared), + K_shared: make_swizzled_layout(K_shared), + V_shared: make_swizzled_layout(V_shared), + O_shared: make_swizzled_layout(O_shared), + } + ) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) for i in T.Parallel(block_M): sinks[i] = Sinks[by] - end = T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) - start = T.max(0, (bx * block_M + past_len - window_size) // - block_N) if window_size is not None else 0 + start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0 for k in T.Pipelined( - start, - end, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): + start, + end, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, 
bx * block_M : (bx + 1) * block_M, :]) return main # Following functions are adapted and optimized from # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() batch_size, num_keys, num_key_value_heads, head_dim = key.shape @@ -244,23 +243,15 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def gen_inputs( - B, - H, - Sq, - Skv, - D, - groups, - dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda') - key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda') - value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda') - sinks = torch.randn([H], dtype=dtype, device='cuda') +def gen_inputs(B, H, Sq, Skv, D, groups, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, device="cuda") return query, key, value, sinks @@ -277,12 +268,11 @@ def main( ): torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -310,15 +300,14 @@ def main( block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) torch.testing.assert_close( - kernel(Q, K, V, sinks), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2) + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) print("All checks passed.✅") # Benchmark tilelang @@ -329,20 +318,14 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=2048, help='sequence 
length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--groups', type=int, default=8, help='groups') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, - args.dtype, args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/example_mha_sink_bwd_bhsd.py b/examples/attention_sink/example_mha_sink_bwd_bhsd.py index b9fa0fd970eea910463005ef3cfab9da626355fa..155c488e658ae3109e5fe9c6161a7759ba4a1991 100644 --- a/examples/attention_sink/example_mha_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_bwd_bhsd.py @@ -20,27 +20,29 @@ def get_bwd_configs(): @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd( - batch, - heads, - seq_len, - dim, - window_size=None, # None for full attention, - sm_scale=None, - block_M=64, - block_N=64, - num_stages=1, - threads=128, - dtype: str = "float16"): - + batch, + heads, + seq_len, + dim, + window_size=None, # None for full attention, + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: str = "float16", +): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] @@ -48,12 +50,12 @@ def flashattn_fwd( @T.prim_func def flash_fwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - Output: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Sinks: T.Tensor([heads], dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -70,7 +72,7 @@ def flashattn_fwd( sinks = 
T.alloc_fragment([heads], dtype) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -78,22 +80,20 @@ def flashattn_fwd( sinks[i] = Sinks[by] end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) - start = T.max(0, - (bx * block_M - window_size) // block_N) if window_size is not None else 0 + start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0 for k in T.Pipelined(start, end, num_stages=num_stages): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i k_idx = k * block_N + j if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, - 0, -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) else: acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_M): @@ -103,8 +103,7 @@ def flashattn_fwd( # NOTE(wt): check_inf is necessary for sliding window attention. for i in T.Parallel(block_M): if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): @@ -121,22 +120,23 @@ def flashattn_fwd( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"): accum_dtype = "float" shape = [batch, heads, seq_len, dim] @@ -144,9 +144,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16") @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # 
type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -155,26 +155,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16") delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do) + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"): accum_dtype = "float" shape = [batch, heads, seq_len, dim] @@ -182,22 +183,24 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16" @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, by, bx * blk:(bx + 1) * blk, :], - dQ_out[bz, by, bx * blk:(bx + 1) * blk, :], + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def flashattn_bwd( batch, heads, @@ -207,11 +210,10 @@ def flashattn_bwd( sm_scale=None, dtype: str = "float16", ): - block_M, block_N, num_stages, threads = get_bwd_configs() if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] @@ -222,15 +224,15 @@ def flashattn_bwd( @T.prim_func def flash_bwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dK: T.Tensor(shape, dtype), # type: ignore - dV: T.Tensor(shape, dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: 
ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -254,43 +256,46 @@ def flashattn_bwd( dv_shared = T.alloc_shared([block_M, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - T.copy(K[bz, bx, by * block_M:(by + 1) * block_M, :], K_shared) - T.copy(V[bz, bx, by * block_M:(by + 1) * block_M, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + T.copy(K[bz, bx, by * block_M : (by + 1) * block_M, :], K_shared) + T.copy(V[bz, bx, by * block_M : (by + 1) * block_M, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) - loop_ed = T.min( - T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv( - seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N) + loop_ed = ( + T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N)) + if window_size is not None + else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) for i, j in T.Parallel(block_M, block_N): if window_size is not None: qkT[i, j] = T.if_then_else( - by * block_M + i <= k * block_N + j and - by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0) + by * block_M + i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0 + ) else: - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], dst=do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, B=do, C=dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -299,12 +304,12 @@ def flashattn_bwd( T.copy(dsT_cast, dsT_shared) T.clear(dq) T.gemm(dsT_shared, K_shared, dq, transpose_A=True) - T.atomic_add(dQ[bz, bx, k * block_N:(k + 1) * block_N, :], dq) + T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq) T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :]) - T.copy(dk_shared, dK[bz, bx, by * block_M:(by + 1) * block_M, :]) + T.copy(dv_shared, dV[bz, bx, by * block_M : (by + 1) * 
block_M, :]) + T.copy(dk_shared, dK[bz, bx, by * block_M : (by + 1) * block_M, :]) return flash_bwd @@ -316,10 +321,10 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: str = "float16" @T.prim_func def flash_bwd_dsink( - Sinks: T.Tensor([heads], dtype), # type: ignore - Delta: T.Tensor(shape, accum_dtype), # type: ignore - lse: T.Tensor(shape, accum_dtype), # type: ignore - dsinks: T.Tensor(shape, accum_dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore + Delta: T.Tensor(shape, accum_dtype), # type: ignore + lse: T.Tensor(shape, accum_dtype), # type: ignore + dsinks: T.Tensor(shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=128) as (bx, by, bz): sink = T.alloc_local([1], dtype) @@ -328,18 +333,16 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: str = "float16" dsink_fragment = T.alloc_fragment([block], accum_dtype) sink[0] = Sinks[bx] - T.copy(lse[bz, bx, by * block:(by + 1) * block], lse_fragment) - T.copy(Delta[bz, bx, by * block:(by + 1) * block], delta_fragment) + T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment) + T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment) for i in T.Parallel(block): - dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - - lse_fragment[i]) * delta_fragment[i] - T.copy(dsink_fragment, dsinks[bz, bx, by * block:(by + 1) * block]) + dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) * delta_fragment[i] + T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block]) return flash_bwd_dsink class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, sinks, window_size): BATCH, H, N_CTX, D_HEAD = q.shape @@ -383,15 +386,15 @@ attention = _attention.apply # Adapted and optimized from # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - - query = query.transpose(1, 2).contiguous().unsqueeze( - 3) # align with the original function's interface +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() @@ -426,29 +429,22 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def main(BATCH: int = 1, - H: int = 1, - N_CTX: int = 512, - D_HEAD: int = 128, - window_size: Optional[int] = None, - dtype: str = "float16"): +def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, window_size: Optional[int] = None, dtype: str = "float16"): torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= N_CTX - flops_per_matmul = 2.0 * BATCH * H * min( - 
window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation + flops_per_matmul = 2.0 * BATCH * H * min(window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5 total_flops = 5 * flops_per_matmul - Q = (torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()) + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() K = torch.randn_like(Q).requires_grad_() V = torch.randn_like(Q).requires_grad_() sinks = torch.randn(H, dtype=torch_dtype, device=Q.device).requires_grad_() @@ -473,16 +469,11 @@ def main(BATCH: int = 1, "float16": (1e-2, 1e-2), "bfloat16": (2e-2, 2e-2), }[dtype] - assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f'O max err: {(O-O_ref).abs().max()}' - assert torch.allclose( - dV, dV_ref, rtol=rtol, atol=atol), f'dV max err: {(dV-dV_ref).abs().max()}' - assert torch.allclose( - dK, dK_ref, rtol=rtol, atol=atol), f'dK max err: {(dK-dK_ref).abs().max()}' - assert torch.allclose( - dQ, dQ_ref, rtol=rtol, atol=atol), f'dq max err: {(dQ-dQ_ref).abs().max()}' - assert torch.allclose( - dsinks, dsinks_ref, rtol=rtol, - atol=atol), f'dsinks max err: {(dsinks-dsinks_ref).abs().max()}' + assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}" + assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}" + assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}" + assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dq max err: {(dQ - dQ_ref).abs().max()}" + assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}" print("All checks passed for tilelang kernels.✅") @@ -503,16 +494,11 @@ def main(BATCH: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='Batch size') - parser.add_argument('--h', type=int, default=64, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=4096, help='Context size') - parser.add_argument('--d_head', type=int, default=128, help='Head dimension') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--batch", type=int, default=1, help="Batch size") + parser.add_argument("--h", type=int, default=64, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=4096, help="Context size") + parser.add_argument("--d_head", type=int, default=128, help="Head dimension") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.window_size, args.dtype) diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd.py b/examples/attention_sink/example_mha_sink_fwd_bhsd.py index 0ccb69588142b49e4323efd3c89c5e4e2c334d37..78ac443b2b304e956c5fa23c5fc2625437618219 100644 --- a/examples/attention_sink/example_mha_sink_fwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd.py @@ 
-18,27 +18,30 @@ def get_configs(): @autotune(configs=get_configs(), warmup=500, rep=100) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - window_size=None, # None for full attention - sm_scale=None, - block_M=64, - block_N=64, - num_stages=1, - threads=128, - dtype: str = "float16"): + batch, + heads, + seq_q, + seq_kv, + dim, + window_size=None, # None for full attention + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: str = "float16", +): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] @@ -58,13 +61,12 @@ def flashattn( by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i + past_len k_idx = k * block_N + j if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) else: acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -79,18 +81,18 @@ def flashattn( by: T.int32, bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -102,8 +104,7 @@ def flashattn( # NOTE(wt): check_inf is necessary for sliding window attention. 
for i in T.Parallel(block_M): if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, block_N): @@ -118,19 +119,19 @@ def flashattn( @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - Sinks: T.Tensor([heads], dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -147,53 +148,51 @@ def flashattn( logsum = T.alloc_fragment([block_M], accum_dtype) sinks = T.alloc_fragment([block_M], dtype) - T.annotate_layout({ - Q_shared: make_swizzled_layout(Q_shared), - K_shared: make_swizzled_layout(K_shared), - V_shared: make_swizzled_layout(V_shared), - O_shared: make_swizzled_layout(O_shared), - }) + T.annotate_layout( + { + Q_shared: make_swizzled_layout(Q_shared), + K_shared: make_swizzled_layout(K_shared), + V_shared: make_swizzled_layout(V_shared), + O_shared: make_swizzled_layout(O_shared), + } + ) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) for i in T.Parallel(block_M): sinks[i] = Sinks[by] - end = T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) - start = T.max(0, (bx * block_M + past_len - window_size) // - block_N) if window_size is not None else 0 + start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0 for k in T.Pipelined(start, end, num_stages=num_stages): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main # Modified from https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = 
torch.float16) -> torch.Tensor: - - query = query.transpose(1, 2).contiguous().unsqueeze( - 3) # align with the original function's interface +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() @@ -228,41 +227,35 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def gen_inputs( - B, - H, - Sq, - Skv, - D, - dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda') - key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda') - value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda') - sinks = torch.randn([H], dtype=dtype, device='cuda') +def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, device="cuda") return query, key, value, sinks -def main(batch: int = 1, - heads: int = 1, - seq_q: int = 256, - seq_kv: int = 256, - dim: int = 128, - window_size: Optional[int] = None, - dtype: str = "float16", - tune: bool = False): +def main( + batch: int = 1, + heads: int = 1, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -289,19 +282,17 @@ def main(batch: int = 1, block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) torch.testing.assert_close( - kernel(Q, K, V, sinks), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2) + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) print("All checks passed.✅") - latency = do_bench( - lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500) + latency = do_bench(lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500) print("Ref: {:.2f} ms".format(latency)) print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) latency = 
do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) @@ -311,19 +302,13 @@ def main(batch: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, - args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py index 64d6ec6981ca63d9e3f785e05abd7fdd034b9374..decdc8f4f624ed4c1c16fd0d316047290d9947c2 100644 --- a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py @@ -19,28 +19,30 @@ def get_configs(): @autotune(configs=get_configs(), warmup=500, rep=100) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - window_size=None, # None for full attention - sm_scale=None, - block_M=128, - block_N=128, - num_stages=2, - threads=256, - dtype: str = "float16"): - + batch, + heads, + seq_q, + seq_kv, + dim, + window_size=None, # None for full attention + sm_scale=None, + block_M=128, + block_N=128, + num_stages=2, + threads=256, + dtype: str = "float16", +): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" if sm_scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] @@ -61,13 +63,12 @@ def flashattn( by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i + past_len k_idx = k * block_N + j if window_size is not None: - acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = 
T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) else: acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -82,18 +83,18 @@ def flashattn( by: T.int32, bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -105,8 +106,7 @@ def flashattn( # NOTE(wt): check_inf is necessary for sliding window attention. for i in T.Parallel(block_M): if window_size is not None: - scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, - scores_max[i]) + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, block_N): @@ -121,19 +121,19 @@ def flashattn( @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - Sinks: T.Tensor([heads], dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -150,60 +150,59 @@ def flashattn( logsum = T.alloc_fragment([block_M], accum_dtype) sinks = T.alloc_fragment([block_M], dtype) - T.annotate_layout({ - Q_shared: make_swizzled_layout(Q_shared), - K_shared: make_swizzled_layout(K_shared), - V_shared: make_swizzled_layout(V_shared), - O_shared: make_swizzled_layout(O_shared), - }) + T.annotate_layout( + { + Q_shared: make_swizzled_layout(Q_shared), + K_shared: make_swizzled_layout(K_shared), + V_shared: make_swizzled_layout(V_shared), + O_shared: make_swizzled_layout(O_shared), + } + ) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) for i in T.Parallel(block_M): sinks[i] = 
Sinks[by] - end = T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) - start = T.max(0, (bx * block_M + past_len - window_size) // - block_N) if window_size is not None else 0 + start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0 for k in T.Pipelined( - start, - end, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): + start, + end, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i in T.Parallel(block_M): - logsum[i] += T.exp2(sinks[i] * 1.44269504 - - scores_max[i] * scale) # The only change for attention sink + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main # Following functions are adapted and optimized from # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py -def ref_program(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - sliding_window: Optional[int] = None, - dtype: torch.dtype = torch.float16) -> torch.Tensor: - - query = query.transpose(1, 2).contiguous().unsqueeze( - 3) # align with the original function'sinterface +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function'sinterface key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() @@ -238,41 +237,35 @@ def ref_program(query: torch.Tensor, output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) - output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, - head_dim).to(dtype) + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) return output.transpose(1, 2).contiguous() -def gen_inputs( - B, - H, - Sq, - Skv, - D, - dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda') - key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda') - value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda') - sinks = torch.randn([H], dtype=dtype, device='cuda') +def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, 
device="cuda") return query, key, value, sinks -def main(batch: int = 1, - heads: int = 32, - seq_q: int = 256, - seq_kv: int = 256, - dim: int = 128, - window_size: Optional[int] = None, - dtype: str = "float16", - tune: bool = False): +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] if window_size is not None: - print('Using sliding window attention.') + print("Using sliding window attention.") assert window_size <= seq_q - flops_per_matmul = 2.0 * batch * heads * min( - window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation else: - print('Using full attention.') + print("Using full attention.") flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 total_flops = 2 * flops_per_matmul @@ -299,15 +292,14 @@ def main(batch: int = 1, block_N=block_N, num_stages=num_stages, threads=threads, - dtype=dtype) + dtype=dtype, + ) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) torch.testing.assert_close( - kernel(Q, K, V, sinks), - ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), - rtol=1e-2, - atol=1e-2) + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) print("All checks passed.✅") latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) @@ -317,19 +309,13 @@ def main(batch: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query') - parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument( - '--window_size', - type=int, - default=None, - help='window size (default: None, which means full attention)') - parser.add_argument( - '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16") - parser.add_argument('--tune', action='store_true', help='tune') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") args = parser.parse_args() - main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, - args.tune) + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/examples/bitnet-1.58b/benchmark_generate.py b/examples/bitnet-1.58b/benchmark_generate.py index 
d6f21ed502772af054ede58d8edce25a2f5879a0..d678b91a4e1c970e2209d2dfc0a102af4c3cf81b 100644 --- a/examples/bitnet-1.58b/benchmark_generate.py +++ b/examples/bitnet-1.58b/benchmark_generate.py @@ -12,8 +12,7 @@ bitblas.set_log_level("INFO") def generate_text_batch(model, tokenizer, prompts, max_length=100): # Encode the input prompts as a batch - input_ids = tokenizer( - prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device) + input_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device) # Generate cos and sin values (commented out as not used in generation) seq_length = input_ids.size(1) @@ -37,9 +36,7 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100): end_time = time.time() # Decode the output ids to text - generated_texts = [ - tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids - ] + generated_texts = [tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids] generation_time = end_time - start_time num_tokens = sum(len(output_id) for output_id in output_ids) @@ -52,8 +49,8 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100): def profile(model, input_data): - import numpy as np + model = model.cuda() model.eval() @@ -74,25 +71,29 @@ def profile(model, input_data): return np.mean(times) -model_path = '1bitLLM/bitnet_b1_58-3B' +model_path = "1bitLLM/bitnet_b1_58-3B" def main(): parser = argparse.ArgumentParser() - parser.add_argument('--bs', default=16, type=int) - parser.add_argument('--in_seq_len', default=32, type=int) - parser.add_argument('--out_seq_len', default=128, type=int) - parser.add_argument('--bitblas', action='store_true') + parser.add_argument("--bs", default=16, type=int) + parser.add_argument("--in_seq_len", default=32, type=int) + parser.add_argument("--out_seq_len", default=128, type=int) + parser.add_argument("--bitblas", action="store_true") args = parser.parse_args() bs = args.bs in_seq_len = args.in_seq_len out_seq_len = args.out_seq_len is_bitblas = args.bitblas - model = BitnetForCausalLM.from_pretrained( - model_path, - use_flash_attention_2=True, - torch_dtype=torch.float16, - ).cuda().half() + model = ( + BitnetForCausalLM.from_pretrained( + model_path, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) if is_bitblas: with torch.no_grad(): model.quantize() @@ -109,5 +110,5 @@ def main(): print(generate_text_batch(model, tokenizer, prompts, max_length=max_length)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/benchmark_inference_latency.py b/examples/bitnet-1.58b/benchmark_inference_latency.py index 9ce7a3898cba0cb6bdcf09677da2a6123f075a2d..788fc5565d5d58b59ef11a11b33f357e911ba9bc 100644 --- a/examples/bitnet-1.58b/benchmark_inference_latency.py +++ b/examples/bitnet-1.58b/benchmark_inference_latency.py @@ -6,13 +6,14 @@ from modeling_bitnet import BitnetForCausalLM torch.set_grad_enabled(False) parser = argparse.ArgumentParser() -parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) def profile(model, input_data): import time import numpy as np + model = model.cuda() model.eval() @@ -35,8 +36,8 @@ def profile(model, input_data): def main(): model = BitnetForCausalLM.from_pretrained( - '1bitLLM/bitnet_b1_58-3B', - device_map='auto', + "1bitLLM/bitnet_b1_58-3B", + device_map="auto", low_cpu_mem_usage=True, 
use_flash_attention_2=True, torch_dtype=torch.float16, @@ -52,5 +53,5 @@ def main(): print(f"Batch size: {batch_size}, Seq len: {seq_len}, Latency: {latency}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/configuration_bitnet.py b/examples/bitnet-1.58b/configuration_bitnet.py index 5f4937b87bf483e8c75fd4f7ba8725a845faa512..63c499db36d96d50f567794bf80a60882e08114f 100644 --- a/examples/bitnet-1.58b/configuration_bitnet.py +++ b/examples/bitnet-1.58b/configuration_bitnet.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" LLaMA model configuration""" +"""LLaMA model configuration""" from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging @@ -180,16 +180,10 @@ class BitnetConfig(PretrainedConfig): return if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " - f"got {self.rope_scaling}") + raise ValueError(f"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, got {self.rope_scaling}") rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, - float) or rope_scaling_factor <= 1.0: - raise ValueError( - f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + raise ValueError(f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}") + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/examples/bitnet-1.58b/eval_correctness.py b/examples/bitnet-1.58b/eval_correctness.py index ac1e340729fc46b0bed8957bb26aedd20d5af2c9..11d47004b81edf517d442cb0eb2b70e6c583cce0 100644 --- a/examples/bitnet-1.58b/eval_correctness.py +++ b/examples/bitnet-1.58b/eval_correctness.py @@ -47,8 +47,8 @@ def generate_text(model, tokenizer, prompt, max_length=100): def profile(model, input_data): - import numpy as np + model = model.cuda() model.eval() @@ -69,18 +69,22 @@ def profile(model, input_data): return np.mean(times) -model_path = '1bitLLM/bitnet_b1_58-3B' +model_path = "1bitLLM/bitnet_b1_58-3B" def main(): - model = BitnetForCausalLM.from_pretrained( - model_path, - use_flash_attention_2=False, - torch_dtype=torch.float16, - ).cuda().half() + model = ( + BitnetForCausalLM.from_pretrained( + model_path, + use_flash_attention_2=False, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False) - input_id = tokenizer("Hello")['input_ids'] + input_id = tokenizer("Hello")["input_ids"] input_id = torch.tensor(input_id).unsqueeze(0).cuda() print("original model generated text:") @@ -91,5 +95,5 @@ def main(): print(generate_text(model, tokenizer, "Hello", max_length=100)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/eval_gpu_memory.py 
b/examples/bitnet-1.58b/eval_gpu_memory.py index 597cbbfcdaaaa1206ceacae45b235b6709f7f61e..00c914cb31c919fc536d0705f59cacf29a30e287 100644 --- a/examples/bitnet-1.58b/eval_gpu_memory.py +++ b/examples/bitnet-1.58b/eval_gpu_memory.py @@ -6,13 +6,14 @@ from modeling_bitnet import BitnetForCausalLM torch.set_grad_enabled(False) parser = argparse.ArgumentParser() -parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) def profile(model, input_data): import time import numpy as np + model = model.cuda() model.eval() @@ -35,17 +36,17 @@ def profile(model, input_data): def main(): model = BitnetForCausalLM.from_pretrained( - '1bitLLM/bitnet_b1_58-3B', - device_map='auto', + "1bitLLM/bitnet_b1_58-3B", + device_map="auto", low_cpu_mem_usage=True, use_flash_attention_2=True, torch_dtype=torch.float16, ).half() - print(f"gpu memory: {torch.cuda.memory_allocated() / 1024 ** 3} GB") + print(f"gpu memory: {torch.cuda.memory_allocated() / 1024**3} GB") with torch.no_grad(): model._post_process_weights() - print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024 ** 3} GB") + print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024**3} GB") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/eval_ppl.py b/examples/bitnet-1.58b/eval_ppl.py index 61c8488e46707b5e05f3982c8046618ded6b4098..97db2d0f5236f369a33f70ac1b07fe9a8c01df9d 100644 --- a/examples/bitnet-1.58b/eval_ppl.py +++ b/examples/bitnet-1.58b/eval_ppl.py @@ -15,9 +15,9 @@ from tqdm import tqdm torch.set_grad_enabled(False) parser = argparse.ArgumentParser() -parser.add_argument('--seed', default=0, type=int) -parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) -parser.add_argument('--seqlen', default=2048, type=int) +parser.add_argument("--seed", default=0, type=int) +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) +parser.add_argument("--seqlen", default=2048, type=int) def calulate_loss(model, input, loss_fct): @@ -29,12 +29,16 @@ def calulate_loss(model, input, loss_fct): def main(args): - datasets = ['c4', 'wikitext2'] - model = BitnetForCausalLM.from_pretrained( - args.hf_path, - use_flash_attention_2=True, - torch_dtype=torch.float16, - ).cuda().half() + datasets = ["c4", "wikitext2"] + model = ( + BitnetForCausalLM.from_pretrained( + args.hf_path, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) with torch.no_grad(): model._post_process_weights() tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False) @@ -48,9 +52,9 @@ def main(args): for ii in progress: input = torch.Tensor(testdata[ii]).long().cuda().view(1, -1) loss = calulate_loss(model, input, loss_fct) - count += (input.size(-1) - 1) + count += input.size(-1) - 1 acc_loss += loss.item() - progress.set_description(f"avg_loss = {acc_loss/ count / math.log(2)}") + progress.set_description(f"avg_loss = {acc_loss / count / math.log(2)}") avg_loss = acc_loss / count / math.log(2) ppl.append(2**avg_loss) @@ -60,7 +64,7 @@ def main(args): print("Avg PPL:", sum(ppl) / len(ppl)) -if __name__ == '__main__': +if __name__ == "__main__": torch.set_grad_enabled(False) args = parser.parse_args() random.seed(args.seed) diff --git a/examples/bitnet-1.58b/eval_utils.py b/examples/bitnet-1.58b/eval_utils.py index 46241eedf0a4224c7aae8ceb4a6549350d513fc2..72480c392a7cfa40081546d2da19aa31463aab76 100644 --- 
a/examples/bitnet-1.58b/eval_utils.py +++ b/examples/bitnet-1.58b/eval_utils.py @@ -15,21 +15,17 @@ def set_seed(seed): def get_test_dataset(dataset_name, tokenizer, seqlen=2048): if dataset_name == "wikitext2": - testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') - testdata = "".join(testdata['text']).split('\n') + testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + testdata = "".join(testdata["text"]).split("\n") elif dataset_name == "c4": - testdata = load_dataset( - 'allenai/c4', - data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, - split='validation')['text'] + testdata = load_dataset("allenai/c4", data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, split="validation")[ + "text" + ] else: raise NotImplementedError testdata = [item for item in testdata if item != ""] - tokenized_text = [ - tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id] - for item in testdata - ] + tokenized_text = [tokenizer(item, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id] for item in testdata] data, doc = [], [tokenizer.bos_token_id] for sen in tokenized_text: @@ -45,7 +41,6 @@ def get_test_dataset(dataset_name, tokenizer, seqlen=2048): class LMEvalAdaptor(BaseLM): - def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1): super().__init__() @@ -137,5 +132,4 @@ class LMEvalAdaptor(BaseLM): return out def _model_generate(self, context, max_length, eos_token_id): - return self.model.generate( - context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False) + return self.model.generate(context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py index e5af16cc486a107c635ea1d7d88477cd479a4cb9..35a044e504e1b2fa0771e9dec6ee48ceb8eca480 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py @@ -102,17 +102,17 @@ def bitnet_158_int8xint2_decode( @T.prim_func def kernel( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, storage_dtype), - C: T.Buffer(C_shape, out_dtype), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer(C_shape, out_dtype), ): with T.Kernel( - T.ceildiv(N, n_partition), - M, - threads=(reduce_thread, n_partition), + T.ceildiv(N, n_partition), + M, + threads=(reduce_thread, n_partition), ) as ( - bx, - by, + bx, + by, ): A_local = T.alloc_local((micro_size_k,), in_dtype) B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype) @@ -133,8 +133,7 @@ def bitnet_158_int8xint2_decode( for v in T.vectorized(micro_size_k_compressed): B_quant_local[v] = B[ bx * n_partition + ni, - ko * (reduce_thread * micro_size_k_compressed) + - kr * micro_size_k_compressed + v, + ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v, ] T.call_extern( @@ -156,9 +155,9 @@ def bitnet_158_int8xint2_decode( accum_res[0] += A_local[ki] * B_dequantize_local[ki] with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -168,7 +167,8 @@ 
def bitnet_158_int8xint2_decode( reduced_accum_res[0], kr, dtype="handle", - )) + ) + ) if kr == 0: C[by, bx * n_partition + ni] = reduced_accum_res[0] @@ -234,13 +234,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): return new_qweight.view(np.int8) -def assert_bitnet_158_int8xint2_decode_correctness(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - fast_decoding=True): +def assert_bitnet_158_int8xint2_decode_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True): program = bitnet_158_int8xint2_decode(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding) print(program) kernel = tilelang.compile(program) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py index d8b1f6228e7d0a386e71db3d4482dfe3e5290873..d68a01286d2f86be9ce7d46de7b15bc66a405404 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py @@ -8,11 +8,13 @@ import tilelang.language as T from tilelang import tvm as tvm from tvm import DataType from tilelang.intrinsics.mma_layout import ( - make_mma_swizzle_layout as make_swizzle_layout,) + make_mma_swizzle_layout as make_swizzle_layout, +) import numpy as np from tilelang.intrinsics.mma_macro_generator import ( - INT4TensorCoreIntrinEmitter,) + INT4TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func torch.manual_seed(42) @@ -181,38 +183,36 @@ def bitnet_158_int8xint2_prefill( @T.prim_func def main( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, storage_dtype), - C: T.Buffer((M, N), out_dtype), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), out_dtype), ): """ - GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C. + GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C. - This kernel: - - Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory. - - Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine. - - Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages. - - Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing. + This kernel: + - Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory. + - Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine. + - Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages. + - Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing. - Parameters: - A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations. - B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel. 
- C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C). + Parameters: + A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations. + B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel. + C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C). - Side effects: - Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation. + Side effects: + Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation. """ with T.Kernel( - T.ceildiv(N, block_N), - T.ceildiv(M, block_M), - threads=threads, - prelude=decode_i2s_to_i8s, + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=threads, + prelude=decode_i2s_to_i8s, ) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) - B_dequantize_shared = T.alloc_shared( - B_dequantize_shared_shape, in_dtype, scope=shared_scope) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype) B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype) @@ -223,10 +223,12 @@ def bitnet_158_int8xint2_prefill( thread_bindings = T.thread_binding(0, threads, "threadIdx.x") - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_dequantize_shared: make_swizzle_layout(B_dequantize_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_dequantize_shared: make_swizzle_layout(B_dequantize_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -234,7 +236,6 @@ def bitnet_158_int8xint2_prefill( T.clear(C_frag) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -243,12 +244,9 @@ def bitnet_158_int8xint2_prefill( for j, k in T.Parallel(block_N, block_K // num_elems_per_byte): B_shared[j, k] = B[bx * block_N + j, ko * (block_K // num_elems_per_byte) + k] - for i in T.serial(block_N * block_K // num_elems_per_byte // - (threads * local_size_compressed)): + for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)): for v in T.vectorized(0, local_size_compressed): - index = ( - i * threads * local_size_compressed + - thread_bindings * local_size_compressed + v) + index = i * threads * local_size_compressed + thread_bindings * local_size_compressed + v vi, vj = T.index_to_coordinates(index, B_shared_shape) B_local[v] = B_shared[vi, vj] @@ -260,12 +258,11 @@ def bitnet_158_int8xint2_prefill( ) for v in T.vectorized(0, local_size): - index = (i * threads * local_size + thread_bindings * local_size + v) + index = i * threads * local_size + thread_bindings * local_size + v vi, vj = T.index_to_coordinates(index, B_dequantize_shared_shape) B_dequantize_shared[vi, vj] = B_dequantize_local[v] for ki in T.serial(0, (block_K // micro_size_k)): 
- # Load A into fragment mma_emitter.ldmatrix_a( A_frag, @@ -360,13 +357,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): return new_qweight.view(np.int8) -def assert_bitnet_158_int8xint2_prefill_correctness(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - fast_decoding=True): +def assert_bitnet_158_int8xint2_prefill_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True): program = bitnet_158_int8xint2_prefill(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding) print(program) kernel = tilelang.compile(program) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py b/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py index 986463598846fafffc4fc38b35a4137e23edbeaa..f2a0e2e7ef3e28e52aa347fc1aaebf683f87701e 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py @@ -6,7 +6,8 @@ from tvm import tl as TL import tvm.tl.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.tl.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from bitblas.base import simplify_prim_func torch.manual_seed(0) @@ -101,12 +102,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, in_dtype), - C: T.Buffer((M, N), out_dtype), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -116,10 +116,12 @@ def tl_matmul( thread_bindings = T.thread_binding(0, threads, "threadIdx.x") - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -127,7 +129,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -137,7 +138,6 @@ def tl_matmul( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, diff --git a/examples/bitnet-1.58b/load_from_quantized.py b/examples/bitnet-1.58b/load_from_quantized.py index 26a32f9747f611f26a63d2a24b35343b1e55d9f0..8c775aa4c8e819ee3ac800fce4ebe0452fac54be 100644 --- a/examples/bitnet-1.58b/load_from_quantized.py +++ b/examples/bitnet-1.58b/load_from_quantized.py @@ -49,7 +49,13 @@ def generate_text(model, tokenizer, prompt, max_length=100): def main(): # load quantized model - qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half() + qmodel = ( + BitnetForCausalLM.from_quantized( + saved_model_path, + ) + .cuda() + .half() + ) tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False) # print("original model generated text:") # print(generate_text(model, tokenizer, "Hi, ", max_length=100)) diff --git a/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py b/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py index 
1e29a553abf75e4a73937a80564d33815f0b983f..2604ef38770fa58fa80cf87709e0b205eae26ecd 100644 --- a/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py +++ b/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py @@ -25,9 +25,9 @@ parser.add_argument("--saved_model_path", type=str, default=None) args = parser.parse_args() model_name_or_path = args.model_name_or_path -saved_model_path = os.path.join( - dirpath, "models", - f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path +saved_model_path = ( + os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path +) def generate_text(model, tokenizer, prompt, max_length=100): @@ -67,7 +67,10 @@ def main(): model_name_or_path, use_flash_attention_2=False, torch_dtype=torch.float16, - ).cuda().half()) + ) + .cuda() + .half() + ) tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False) # print("original model generated text:") @@ -112,10 +115,16 @@ def main(): file_path = cached_file(model_name_or_path, file) os.system(f"cp {file_path} {saved_model_path}") # load quantized model - qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half() + qmodel = ( + BitnetForCausalLM.from_quantized( + saved_model_path, + ) + .cuda() + .half() + ) print("quantized model generated text:") print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/bitnet-1.58b/modeling_bitnet.py b/examples/bitnet-1.58b/modeling_bitnet.py index 6e3c42b6f9968bb2f76e20609fd48a1d9994f171..1830995ee6177536089fe517646b290c18bb05f2 100644 --- a/examples/bitnet-1.58b/modeling_bitnet.py +++ b/examples/bitnet-1.58b/modeling_bitnet.py @@ -64,8 +64,7 @@ def find_layers(module, layers=None, name=""): return {name: module} res = {} for name1, child in module.named_children(): - res.update( - find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1)) + res.update(find_layers(child, layers=layers, name=name + "." 
+ name1 if name != "" else name1)) return res @@ -87,7 +86,6 @@ def _get_unpad_data(attention_mask): class BitnetRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): """ BitnetRMSNorm is equivalent to T5LayerNorm @@ -108,34 +106,23 @@ ALL_LAYERNORM_LAYERS.append(BitnetRMSNorm) class BitnetRotaryEmbedding(nn.Module): - - def __init__(self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): super().__init__() self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / ( - self.base - **(torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - "_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) - self.register_buffer( - "_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) @property def sin_cached(self): @@ -156,14 +143,12 @@ class BitnetRotaryEmbedding(nn.Module): @torch.no_grad() def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = self.inv_freq[None, :, - None].float().expand(position_ids.shape[0], -1, 1) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type - device_type = device_type if isinstance(device_type, - str) and device_type != "mps" else "cpu" + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) @@ -174,8 +159,8 @@ class BitnetRotaryEmbedding(nn.Module): def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) @@ -207,7 +192,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class BitnetMLP(nn.Module): - def __init__(self, config): super().__init__() self.config = config @@ -245,7 +229,6 @@ class BitnetMLP(nn.Module): class BitnetMLPFuseGateUp(nn.Module): - def __init__(self, config): super().__init__() self.config = config @@ -272,8 +255,7 @@ class BitnetMLPFuseGateUp(nn.Module): def 
from_bit_mlp(cls, bit_mlp: BitnetMLP): module = cls(bit_mlp.config) # assign the weights - module.gate_up_proj.weight = nn.Parameter( - torch.cat([bit_mlp.gate_proj.weight, bit_mlp.up_proj.weight], dim=0)) + module.gate_up_proj.weight = nn.Parameter(torch.cat([bit_mlp.gate_proj.weight, bit_mlp.up_proj.weight], dim=0)) module.down_proj = bit_mlp.down_proj module.ffn_layernorm = bit_mlp.ffn_layernorm return module @@ -295,8 +277,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, - head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -311,7 +292,8 @@ class BitnetAttention(nn.Module): logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class.") + "when creating this class." + ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size @@ -325,8 +307,8 @@ class BitnetAttention(nn.Module): if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads}).") + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads})." + ) self.q_proj = BitLinear( self.hidden_size, @@ -387,10 +369,8 @@ class BitnetAttention(nn.Module): value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) past_key_value = getattr(self, "past_key_value", past_key_value) cos, sin = self.rotary_emb(value_states, position_ids) @@ -399,30 +379,24 @@ class BitnetAttention(nn.Module): if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, - self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( - self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 - 
attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=self.attention_dropout, training=self.training) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}") + raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() @@ -448,7 +422,8 @@ class BitnetAttentionQKVFused(nn.Module): logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class.") + "when creating this class." + ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size @@ -462,8 +437,8 @@ class BitnetAttentionQKVFused(nn.Module): if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads}).") + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads})." + ) self.qkv_proj = BitLinear( self.hidden_size, @@ -497,17 +472,12 @@ class BitnetAttentionQKVFused(nn.Module): module = cls(bit_attention.config, bit_attention.layer_idx) # assign the weights module.qkv_proj.weight = nn.Parameter( - torch.cat([ - bit_attention.q_proj.weight, bit_attention.k_proj.weight, - bit_attention.v_proj.weight - ], - dim=0)) + torch.cat([bit_attention.q_proj.weight, bit_attention.k_proj.weight, bit_attention.v_proj.weight], dim=0) + ) if bit_attention.q_proj.bias is not None and bit_attention.k_proj.bias is not None and bit_attention.v_proj.bias is not None: module.qkv_proj.bias = nn.Parameter( - torch.cat([ - bit_attention.q_proj.bias, bit_attention.k_proj.bias, bit_attention.v_proj.bias - ], - dim=0)) + torch.cat([bit_attention.q_proj.bias, bit_attention.k_proj.bias, bit_attention.v_proj.bias], dim=0) + ) module.o_proj = bit_attention.o_proj module.inner_attn_ln = bit_attention.inner_attn_ln if bit_attention.config.rope_scaling is None: @@ -528,16 +498,13 @@ class BitnetAttentionQKVFused(nn.Module): bsz, q_len, _ = hidden_states.size() qkv_states = self.qkv_proj(hidden_states) query_states, key_states, value_states = torch.split( - qkv_states, [ - self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim, - self.num_key_value_heads * self.head_dim - ], - dim=-1) + qkv_states, + [self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim, self.num_key_value_heads * self.head_dim], + dim=-1, + ) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, 
self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) past_key_value = getattr(self, "past_key_value", past_key_value) cos, sin = self.rotary_emb(value_states, position_ids) @@ -546,30 +513,24 @@ class BitnetAttentionQKVFused(nn.Module): if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, - self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( - self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=self.attention_dropout, training=self.training) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}") + raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() @@ -622,10 +583,8 @@ class BitnetFlashAttention2(BitnetAttention): # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -635,8 +594,7 @@ class BitnetFlashAttention2(BitnetAttention): if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, - self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires 
the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -665,14 +623,14 @@ class BitnetFlashAttention2(BitnetAttention): logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}.") + f" {target_dtype}." + ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate) + attn_output = self._flash_attention_forward(query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.inner_attn_ln(attn_output) @@ -683,14 +641,9 @@ class BitnetFlashAttention2(BitnetAttention): return attn_output, attn_weights, past_key_value - def _flash_attention_forward(self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None): + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. @@ -720,7 +673,8 @@ class BitnetFlashAttention2(BitnetAttention): if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length) + query_states, key_states, value_states, attention_mask, query_length + ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens @@ -740,13 +694,7 @@ class BitnetFlashAttention2(BitnetAttention): attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal) + attn_output = flash_attn_func(query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal) return attn_output @@ -754,28 +702,24 @@ class BitnetFlashAttention2(BitnetAttention): indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k) + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k) cu_seqlens_q = cu_seqlens_k 
max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, - device=query_layer.device) # There is a memcpy here, that is very bad. + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. indices_q = cu_seqlens_q[:-1] query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( - query_layer, attention_mask) + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, @@ -794,13 +738,11 @@ LLAMA_ATTENTION_CLASSES = { class BitnetDecoderLayer(nn.Module): - def __init__(self, config: BitnetConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation]( - config=config, layer_idx=layer_idx) + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) self.mlp = BitnetMLP(config) self.input_layernorm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -834,7 +776,8 @@ class BitnetDecoderLayer(nn.Module): if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`", - stacklevel=2) + stacklevel=2, + ) residual = hidden_states @@ -925,8 +868,7 @@ class BitnetPreTrainedModel(PreTrainedModel): dtype = self.config._pre_quantization_dtype else: dtype = layer.self_attn.o_proj.weight.dtype - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype) + layer.self_attn.past_key_value = cache_cls(self.config, max_batch_size, max_cache_len, device=device, dtype=dtype) def _reset_cache(self): for layer in self.model.layers: @@ -1025,9 +967,7 @@ class BitnetModel(BitnetPreTrainedModel): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([ - BitnetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList([BitnetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) self.norm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -1055,21 +995,15 @@ class BitnetModel(BitnetPreTrainedModel): cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one") if 
self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.") use_cache = False if inputs_embeds is None: @@ -1083,10 +1017,7 @@ class BitnetModel(BitnetPreTrainedModel): if cache_position is None: if isinstance(past_key_values, StaticCache): raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device) + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device) if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -1143,12 +1074,9 @@ class BitnetModel(BitnetPreTrainedModel): next_cache = None if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() - if isinstance(next_decoder_cache, Cache) else next_decoder_cache) + next_cache = next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None) + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1172,14 +1100,9 @@ class BitnetModel(BitnetPreTrainedModel): if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache target_length = self.config.max_position_embeddings else: # dynamic cache - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1) - - causal_mask = torch.full((sequence_length, target_length), - fill_value=min_dtype, - dtype=dtype, - device=device) + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) @@ -1188,10 +1111,8 @@ class BitnetModel(BitnetPreTrainedModel): causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[..., :mask_length].eq( - 0.0) * attention_mask[:, None, None, :].eq(0.0) - causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill( - padding_mask, min_dtype) + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) elif attention_mask.dim() == 4: # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. 
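# A minimal standalone sketch (not part of the patch) of the additive causal + padding
# mask pattern that the BitnetModel._update_causal_mask hunk above reformats: fill a
# square mask with the dtype minimum, zero the lower triangle so each token can see
# itself and the past, then re-block key positions that the 2D attention_mask marks as
# padding. It assumes no KV cache (query length == key length); the helper name is
# illustrative, not taken from the model code.
import torch

def build_causal_mask(attention_mask: torch.Tensor, dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding.
    batch, seq_len = attention_mask.shape
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((seq_len, seq_len), fill_value=min_dtype, dtype=dtype)
    mask = torch.triu(mask, diagonal=1)                      # min_dtype strictly above the diagonal, 0 elsewhere
    mask = mask[None, None, :, :].expand(batch, 1, -1, -1).clone()
    padding = mask.eq(0.0) & attention_mask[:, None, None, :].eq(0)
    return mask.masked_fill(padding, min_dtype)              # additive mask: 0 = visible, min_dtype = blocked

# Example: build_causal_mask(torch.tensor([[1, 1, 1, 0]])) blocks the padded last key
# position for every query row while keeping the usual causal structure.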
@@ -1201,8 +1122,7 @@ class BitnetModel(BitnetPreTrainedModel): offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[:mask_shape[0], :mask_shape[1], - offset:mask_shape[2] + offset, :mask_shape[3]] = mask_slice + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = mask_slice return causal_mask @@ -1279,9 +1199,7 @@ class BitnetForCausalLM(BitnetPreTrainedModel): "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -1327,13 +1245,9 @@ class BitnetForCausalLM(BitnetPreTrainedModel): attentions=outputs.attentions, ) - def prepare_inputs_for_generation(self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - **kwargs): + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs + ): # With static cache, the `past_key_values` is None # TODO joao: standardize interface for the different Cache classes and remove of this if has_static_cache = False @@ -1344,13 +1258,13 @@ class BitnetForCausalLM(BitnetPreTrainedModel): past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - past_length = cache_position[ - 0] if cache_position is not None else past_key_values.get_seq_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() max_cache_length = ( torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None else None) - cache_length = past_length if max_cache_length is None else torch.min( - max_cache_length, past_length) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] @@ -1361,7 +1275,7 @@ class BitnetForCausalLM(BitnetPreTrainedModel): # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):] + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: @@ -1369,8 +1283,7 @@ class BitnetForCausalLM(BitnetPreTrainedModel): # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
- if (max_cache_length is not None and attention_mask is not None and - cache_length + input_ids.shape[1] > max_cache_length): + if max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length: attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids") @@ -1379,7 +1292,7 @@ class BitnetForCausalLM(BitnetPreTrainedModel): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1]:] + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: @@ -1392,39 +1305,38 @@ class BitnetForCausalLM(BitnetPreTrainedModel): input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] if cache_position is None: - cache_position = torch.arange( - past_length, past_length + input_length, device=input_ids.device) + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) else: cache_position = cache_position[-input_length:] if has_static_cache: past_key_values = None - model_inputs.update({ - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - }) + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) return model_inputs @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: - reordered_past += (tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past),) + reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),) return reordered_past @staticmethod def recursive_set(model, name, attr): - ''' - set layers.25.mlp.up_proj to attr - ''' + """ + set layers.25.mlp.up_proj to attr + """ - names = name.split('.') + names = name.split(".") obj = model for n in names[:-1]: obj = getattr(obj, n) @@ -1521,6 +1433,7 @@ class BitnetForCausalLM(BitnetPreTrainedModel): fuse_gateup = quant_config.get("fuse_gateup", True) import accelerate + if checkpoint_format == "bitblas": model = cls(config) for name, module in model.named_modules(): @@ -1567,7 +1480,6 @@ class BitnetForCausalLM(BitnetPreTrainedModel): LLAMA_START_DOCSTRING, ) class BitnetForSequenceClassification(BitnetPreTrainedModel): - def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels @@ -1631,8 +1543,7 @@ class BitnetForSequenceClassification(BitnetPreTrainedModel): else: if input_ids is not None: # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, - self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 sequence_lengths = sequence_lengths % input_ids.shape[-1] sequence_lengths = sequence_lengths.to(logits.device) else: @@ -1646,8 +1557,7 @@ class BitnetForSequenceClassification(BitnetPreTrainedModel): if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" - elif 
self.num_labels > 1 and (labels.dtype == torch.long or - labels.dtype == torch.int): + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" diff --git a/examples/bitnet-1.58b/tokenization_bitnet.py b/examples/bitnet-1.58b/tokenization_bitnet.py index 6fea3252a9d85c76150888c0187b0a429d4cdfa8..2adfd6dee10e6d0fba443e14c7b828e73b378554 100644 --- a/examples/bitnet-1.58b/tokenization_bitnet.py +++ b/examples/bitnet-1.58b/tokenization_bitnet.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tokenization classes for LLaMA.""" + import os from shutil import copyfile from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple @@ -37,12 +38,10 @@ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "hf-internal-testing/llama-tokenizer": - "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", }, "tokenizer_file": { - "hf-internal-testing/llama-tokenizer": - "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", }, } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { @@ -159,14 +158,10 @@ class BitnetTokenizer(PreTrainedTokenizer): **kwargs, ): self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs - bos_token = AddedToken( - bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token - eos_token = AddedToken( - eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token - unk_token = AddedToken( - unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token - pad_token = AddedToken( - pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token + bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token if legacy is None: logger.warning_once( @@ -174,7 +169,8 @@ class BitnetTokenizer(PreTrainedTokenizer): " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." " If you want to use the new behavior, set `legacy=False`. 
This should only be set if you understand what it" " means, and thoroughly read the reason why this was added as explained in" - " https://github.com/huggingface/transformers/pull/24565") + " https://github.com/huggingface/transformers/pull/24565" + ) legacy = True self.legacy = legacy @@ -214,8 +210,7 @@ class BitnetTokenizer(PreTrainedTokenizer): with open(self.vocab_file, "rb") as f: sp_model = f.read() - model_pb2 = import_protobuf( - f"The new behavior of {self.__class__.__name__} (with `self.legacy = False`)") + model_pb2 = import_protobuf(f"The new behavior of {self.__class__.__name__} (with `self.legacy = False`)") model = model_pb2.ModelProto.FromString(sp_model) normalizer_spec = model_pb2.NormalizerSpec() normalizer_spec.add_dummy_prefix = False @@ -261,8 +256,7 @@ class BitnetTokenizer(PreTrainedTokenizer): tokens = super().tokenize(text, **kwargs) - if len(tokens - ) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: tokens = tokens[1:] return tokens @@ -284,7 +278,7 @@ class BitnetTokenizer(PreTrainedTokenizer): # 1. Encode string + prefix ex: " Hey" tokens = self.sp_model.encode(self.unk_token + text, out_type=str) # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] - return tokens[self.unk_token_length:] if len(tokens) >= self.unk_token_length else tokens + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens def _convert_token_to_id(self, token): """Converts a token (str) in an id using the vocab.""" @@ -332,12 +326,9 @@ class BitnetTokenizer(PreTrainedTokenizer): if not os.path.isdir(save_directory): logger.error(f"Vocabulary path ({save_directory}) should be a directory") return - out_vocab_file = os.path.join(save_directory, - (filename_prefix + "-" if filename_prefix else "") + - VOCAB_FILES_NAMES["vocab_file"]) + out_vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile( - self.vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) elif not os.path.isfile(self.vocab_file): with open(out_vocab_file, "wb") as fi: @@ -357,10 +348,9 @@ class BitnetTokenizer(PreTrainedTokenizer): return output - def get_special_tokens_mask(self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None, - already_has_special_tokens: bool = False) -> List[int]: + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: """ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer `prepare_for_model` method. @@ -377,20 +367,16 @@ class BitnetTokenizer(PreTrainedTokenizer): `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 
""" if already_has_special_tokens: - return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True) + return super().get_special_tokens_mask(token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True) bos_token_id = [1] if self.add_bos_token else [] eos_token_id = [1] if self.add_eos_token else [] if token_ids_1 is None: return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id - return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + - ([0] * len(token_ids_1)) + eos_token_id) + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + ([0] * len(token_ids_1)) + eos_token_id - def create_token_type_ids_from_sequences(self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None) -> List[int]: + def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]: """ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT sequence pair mask has the following format: @@ -473,9 +459,9 @@ class BitnetTokenizer(PreTrainedTokenizer): "{% elif message['role'] == 'assistant' %}" "{{ ' ' + content.strip() + ' ' + eos_token }}" "{% endif %}" - "{% endfor %}") - template = template.replace("USE_DEFAULT_PROMPT", - "true" if self.use_default_system_prompt else "false") + "{% endfor %}" + ) + template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) diff --git a/examples/bitnet-1.58b/utils_quant.py b/examples/bitnet-1.58b/utils_quant.py index 5f5db5dbc04919546ac7137f6dc83c26377b56d5..5a50edb392ead6d55c9e34f19409cfb94848f13a 100644 --- a/examples/bitnet-1.58b/utils_quant.py +++ b/examples/bitnet-1.58b/utils_quant.py @@ -24,15 +24,14 @@ def weight_quant(weight, num_bits=1): def activation_quant(x, num_bits=8): dtype = x.dtype x = x.float() - Qn = -(2**(num_bits - 1)) - Qp = 2**(num_bits - 1) - 1 + Qn = -(2 ** (num_bits - 1)) + Qp = 2 ** (num_bits - 1) - 1 s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) result = (x * s).round().clamp(Qn, Qp) / s return result.type(dtype) class BitLinearBitBLAS(nn.Module): - def __init__( self, in_features: int, @@ -68,7 +67,7 @@ class BitLinearBitBLAS(nn.Module): self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, ENABLE_TUNING) self.format = "bitnet" - self.Qp = 2**(self.input_bits - 1) - 1 + self.Qp = 2 ** (self.input_bits - 1) - 1 def _get_or_create_bitblas_operator(self, config, enable_tuning): if global_operator_cache.size() == 0: @@ -99,8 +98,7 @@ class BitLinearBitBLAS(nn.Module): @classmethod def from_bit_linear(cls, bitlinear, weight_group=1): - bitblas_linear = cls( - bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8) + bitblas_linear = cls(bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8) sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight, weight_group) bitblas_linear.register_buffer("qweight", qweight) bitblas_linear.register_buffer("sw", sw) @@ -158,8 +156,8 @@ class BitLinearBitBLAS(nn.Module): @torch.compile def activation_quant(self, x, num_bits=8): x = x.float() - Qn = -(2**(num_bits - 1)) - Qp = 2**(num_bits - 1) - 1 + Qn = -(2 ** (num_bits - 1)) + Qp = 2 ** (num_bits - 1) - 1 s = Qp / x.abs().max(dim=-1, 
keepdim=True).values.clamp(min=1e-5) result = (x * s).round().clamp(Qn, Qp) return result.type(torch.int8), s @@ -173,9 +171,8 @@ class BitLinearBitBLAS(nn.Module): # for the correctness evaluation. def native_forward(self, input): - quant_input = (input + (activation_quant(input, self.input_bits) - input).detach()) - quant_weight = ( - self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()) + quant_input = input + (activation_quant(input, self.input_bits) - input).detach() + quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach() out = nn.functional.linear(quant_input, quant_weight) if self.bias is not None: @@ -214,7 +211,6 @@ class BitLinearBitBLAS(nn.Module): # Naive BitLinear from HuggingFace class BitLinear(nn.Linear): - def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs): super(BitLinear, self).__init__(*kargs, **kwargs) """ @@ -224,10 +220,8 @@ class BitLinear(nn.Linear): self.input_bits = input_bits def forward(self, input): - quant_input = input + (activation_quant(input, self.input_bits) - input).detach() - quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - - self.weight).detach() + quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach() out = nn.functional.linear(quant_input, quant_weight) if self.bias is not None: diff --git a/examples/bitnet-1.58b/vllm_workspace/conftest.py b/examples/bitnet-1.58b/vllm_workspace/conftest.py index 951f3899148e999d7409a914a0cbad0809586373..e9e2997ef67c5c22b26235d00000332dfe20910f 100644 --- a/examples/bitnet-1.58b/vllm_workspace/conftest.py +++ b/examples/bitnet-1.58b/vllm_workspace/conftest.py @@ -20,7 +20,7 @@ from transformers import ( from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.config import TokenizerPoolConfig -from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel) +from vllm.distributed import destroy_distributed_environment, destroy_model_parallel from vllm.inputs import TextPrompt from vllm.logger import init_logger from vllm.sequence import SampleLogprobs @@ -56,12 +56,13 @@ else: class _ImageAssets(_ImageAssetsBase): - def __init__(self) -> None: - super().__init__([ - ImageAsset("stop_sign"), - ImageAsset("cherry_blossom"), - ]) + super().__init__( + [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), + ] + ) def prompts(self, prompts: _ImageAssetPrompts) -> List[str]: """ @@ -136,7 +137,6 @@ _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding) class HfRunner: - def wrap_device(self, input: _T) -> _T: if not is_cpu(): return input.to("cuda") @@ -166,7 +166,8 @@ class HfRunner: SentenceTransformer( model_name, device="cpu", - ).to(dtype=torch_dtype)) + ).to(dtype=torch_dtype) + ) else: if is_vision_model: auto_cls = AutoModelForVision2Seq @@ -184,7 +185,8 @@ class HfRunner: torch_dtype=torch_dtype, trust_remote_code=True, **model_kwargs, - )) + ) + ) self.tokenizer = AutoTokenizer.from_pretrained( model_name, @@ -204,8 +206,7 @@ class HfRunner: ) except Exception: logger.warning( - "Unable to auto-load processor from HuggingFace for " - "model %s. Using tokenizer instead.", + "Unable to auto-load processor from HuggingFace for model %s. 
Using tokenizer instead.", model_name, ) self.processor = self.tokenizer @@ -362,7 +363,7 @@ class HfRunner: last_hidden_states, self.model.get_output_embeddings().weight.t(), ) - if (getattr(self.model.get_output_embeddings(), "bias", None) is not None): + if getattr(self.model.get_output_embeddings(), "bias", None) is not None: logits += self.model.get_output_embeddings().bias.unsqueeze(0) logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) seq_logprobs.append(logprobs) @@ -389,8 +390,7 @@ class HfRunner: all_output_strs.append(self.tokenizer.decode(output_ids)) outputs = zip(all_output_ids, all_output_strs, all_logprobs) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] + return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: return self.model.encode(prompts) @@ -409,7 +409,6 @@ def hf_runner(): class VllmRunner: - def __init__( self, model_name: str, @@ -514,12 +513,10 @@ class VllmRunner: num_logprobs: int, images: Optional[List[Image.Image]] = None, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: - greedy_logprobs_params = SamplingParams( - temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs) + greedy_logprobs_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs) outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params, images=images) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] + return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] def generate_beam_search( self, diff --git a/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py b/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py index 55a24543e3ebaee7e7f6a1278dd86d8501dcc0c9..ea18239cbc8fc00aaf65297a77fd5db0bf27e6ac 100644 --- a/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py +++ b/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py @@ -32,15 +32,14 @@ args = parser.parse_args() ckpt_path = args.ckpt_path with VllmRunner( - ckpt_path, - dtype="half", - quantization="bitblas", - # set enforce_eager = False to enable cuda graph - # set enforce_eager = True to disable cuda graph - enforce_eager=False, + ckpt_path, + dtype="half", + quantization="bitblas", + # set enforce_eager = False to enable cuda graph + # set enforce_eager = True to disable cuda graph + enforce_eager=False, ) as bitnet_model: - bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], - max_tokens=1024) + bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=1024) print("bitnet inference:") print(bitbnet_outputs[0][0]) print(bitbnet_outputs[0][1]) diff --git a/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py b/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py index 4f5f87f6ffe47ef2b8cbd9e2d2839d6b3b961cb0..f631fb306772408b17d71c35a5ae8bc1084e10d9 100644 --- a/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py +++ b/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py @@ -33,13 +33,13 @@ args = parser.parse_args() ckpt_path = args.ckpt_path with VllmRunner( - ckpt_path, - dtype="half", - quantization="bitnet_bitblas", - gpu_memory_utilization=0.5, - # set enforce_eager = False to enable cuda 
graph - # set enforce_eager = True to disable cuda graph - enforce_eager=False, + ckpt_path, + dtype="half", + quantization="bitnet_bitblas", + gpu_memory_utilization=0.5, + # set enforce_eager = False to enable cuda graph + # set enforce_eager = True to disable cuda graph + enforce_eager=False, ) as bitnet_model: bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=128) print("bitnet inference output:") diff --git a/examples/bitnet-1.58b/vllm_workspace/utils.py b/examples/bitnet-1.58b/vllm_workspace/utils.py index daa9d8f52bddf6eb815e1b1ece851666bcf2a8a4..e96b19e28ca9e21af070bdd187e4b026aca26bc7 100644 --- a/examples/bitnet-1.58b/vllm_workspace/utils.py +++ b/examples/bitnet-1.58b/vllm_workspace/utils.py @@ -3,8 +3,7 @@ from typing import Dict, List, Tuple TokensText = Tuple[List[int], str] -def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText], - name_0: str, name_1: str): +def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText], name_0: str, name_1: str): """ Compare the two sequences generated by different models, which should be equal. @@ -15,19 +14,14 @@ def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[Tok output_ids_0, output_str_0 = outputs_0 output_ids_1, output_str_1 = outputs_1 - assert output_str_0 == output_str_1, (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") - assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") + assert output_str_0 == output_str_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" + assert output_ids_0 == output_ids_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]] -def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], - outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str): +def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str): """ Compare the logprobs of two sequences generated by different models, which should be similar but not necessarily equal. @@ -41,16 +35,11 @@ def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], # Loop through generated tokens. for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): - # If generated tokens don't match, then if output_id_0 != output_id_1: # Each predicted token must be in top N logprobs of the other - assert output_id_0 in logprobs_1[idx], (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") - assert output_id_1 in logprobs_0[idx], (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") + assert output_id_0 in logprobs_1[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" + assert output_id_1 in logprobs_0[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" # Break out since sequences will now diverge. 
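# A small self-contained sketch (not part of the patch) of the comparison rule that
# check_logprobs_close above applies: two greedy generations may diverge, but at the
# first point of divergence each model's chosen token must still appear in the other
# model's top-N logprob set; after that the sequences are allowed to differ. The
# function name and sample values below are illustrative.
def logprobs_close(output_ids_0, logprobs_0, output_ids_1, logprobs_1) -> bool:
    for idx, (tok_0, tok_1) in enumerate(zip(output_ids_0, output_ids_1)):
        if tok_0 == tok_1:
            continue
        # First divergence: cross-check each token against the other model's top-N candidates.
        if tok_0 not in logprobs_1[idx] or tok_1 not in logprobs_0[idx]:
            return False
        break  # sequences legitimately diverge from here on
    return True

# Example (hypothetical per-step {token_id: logprob} dicts):
# logprobs_close([1, 2], [{1: -0.1, 3: -2.0}, {2: -0.2, 4: -1.5}],
#                [1, 4], [{1: -0.1}, {4: -0.3, 2: -1.0}])  # -> True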
break diff --git a/examples/blocksparse_attention/block_sparse_attn_triton.py b/examples/blocksparse_attention/block_sparse_attn_triton.py index 014f0c5fcbce0d96abf4c374029a4f0400c57741..1794836342197de8c16bfa2eb515e872c94c663b 100644 --- a/examples/blocksparse_attention/block_sparse_attn_triton.py +++ b/examples/blocksparse_attention/block_sparse_attn_triton.py @@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -56,7 +53,6 @@ def _fwd_kernel_inner( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) # print @@ -73,8 +69,7 @@ def _fwd_kernel_inner( # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N if LAST_K_BLOCK: - qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, - float('-inf')) + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf")) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] @@ -154,7 +149,7 @@ def _fwd_kernel( v_ptrs = V + off_v mask_ptrs = block_mask_ptr + start_m * stride_bmm - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) @@ -192,24 +187,12 @@ def _fwd_kernel( acc = acc * l_recip acc = acc.to(Out.dtype.element_ty) - off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[ - None, :] * stride_od + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) -def _forward(ctx, - q, - k, - v, - block_sparse_mask, - sm_scale, - BLOCK_M=64, - BLOCK_N=64, - num_warps=None, - num_stages=1, - out=None): - +def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None): assert q.shape[-1] == k.shape[-1] == v.shape[-1] assert k.shape[2] == v.shape[2] o = out if out is not None else torch.empty_like(q).contiguous() @@ -254,7 +237,6 @@ def _forward(ctx, class _sparse_attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, block_sparse_dense, sm_scale): # shape constraints @@ -278,9 +260,9 @@ def test_topk_sparse_attention(): torch.manual_seed(0) # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) @@ -288,9 +270,7 @@ def test_topk_sparse_attention(): 
downsample_len = math.ceil(SEQ_LEN / downsample_factor) print("downsample_len", downsample_len) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 print("x_ds.shape", x_ds.shape) block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -302,22 +282,21 @@ def test_topk_sparse_attention(): # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) # print("ref_output", ref_output) # print("triton_output", triton_output) # Verify accuracy - assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ - "Triton output doesn't match reference" + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference" print("Pass topk sparse attention test with qlen == klen") @@ -329,9 +308,9 @@ def test_topk_sparse_attention_qlt_kl(): torch.manual_seed(0) # Create inputs. - q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) # softmax scale sm_scale = 1.0 / (D_HEAD**0.5) @@ -339,8 +318,7 @@ def test_topk_sparse_attention_qlt_kl(): print("downsample_factor", downsample_factor) downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension print("downsample_len", downsample_len) - x_ds = torch.randn( - BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16) + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16) # Force the first column to be high so that the first block is always selected. 
x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -351,26 +329,25 @@ def test_topk_sparse_attention_qlt_kl(): past_len = K_LEN - Q_LEN - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale - full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool() + full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool() full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) - causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN) + causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN) final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) - attn = attn.masked_fill(~final_mask, float('-inf')) + attn = attn.masked_fill(~final_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) # Verify accuracy. - assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ - "Triton output doesn't match reference when qlen < klen" + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen" print("Pass topk sparse attention test with qlen < klen") diff --git a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py index 7e90db7e5f036f0ef92d2ef8f6689e270f722045..afb4cc888ab9e2449e599a64930a0789792433df 100644 --- a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py +++ b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py @@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -30,15 +27,17 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F @tilelang.jit( - out_idx=[4], pass_configs={ + out_idx=[4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal): block_M = 64 block_N = 64 num_stages = 1 threads = 128 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] block_mask_shape = [batch, heads, downsample_len, downsample_len] @@ -47,7 +46,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_mask_dtype = "bool" def kernel_func(block_M, block_N, num_stages, threads): - @T.macro def MMA0( K: T.Tensor(shape, dtype), @@ -59,11 +57,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], 
K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -78,18 +75,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) by: T.int32, bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -113,22 +110,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def blocksparse_flashattn( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(shape, dtype), ): - with T.Kernel( - T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -143,7 +139,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) logsum = T.alloc_fragment([block_M], accum_dtype) block_mask = T.alloc_local([downsample_len], block_mask_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -152,20 +148,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_mask[vj] = BlockSparseMask[bz, by, bx, vj] loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else 
T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): if block_mask[k] != 0: MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return blocksparse_flashattn @@ -180,18 +175,16 @@ def test_topk_sparse_attention(): torch.manual_seed(0) # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -202,15 +195,15 @@ def test_topk_sparse_attention(): # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) print("ref_output", ref_output) print("tilelang_output", tilelang_output) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py index 1c4b847de1797d679bf8be6f7769e5b0a1855184..99418d5fdd959dcad064e20881da47f918976325 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py @@ -13,17 +13,20 @@ from heuristic import num_splits_heuristic def flashattn(batch, heads, heads_kv, dim, dim_v): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) dtype = "float16" accum_dtype = "float" kv_group_num = heads // heads_kv @tilelang.jit( - out_idx=[-1], pass_configs={ + 
out_idx=[-1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) - def kernel_func(block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages, - max_num_blocks_per_seq, max_selected_blocks): + }, + ) + def kernel_func( + block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages, max_num_blocks_per_seq, max_selected_blocks + ): shape_q = [batch, heads, dim] shape_k = [num_pages, page_block_size, heads_kv, dim] shape_v = [num_pages, page_block_size, heads_kv, dim_v] @@ -37,17 +40,16 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.macro def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - block_table: T.Tensor(shape_block_table, "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, "int32"), + cache_seqlens: T.Tensor([batch], "int32"), + block_table: T.Tensor(shape_block_table, "int32"), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), ): - with T.Kernel( - batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim_v], dtype) @@ -67,7 +69,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): sid = bz cur_kv_head = hid // (kv_group_num // valid_block_H) - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -75,7 +77,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): num_blocks = max_selected_blocks blocks_per_split = T.floordiv(num_blocks, num_split) remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) start = blocks_per_split * sid + T.min(sid, remaining_blocks) has_valid_block = False for k in T.Pipelined(loop_range, num_stages=num_stages): @@ -85,29 +87,20 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): block_table_idx = T.floordiv(logical_block_idx, block_ratio) block_tile_idx = T.floormod(logical_block_idx, block_ratio) physical_block_idx = block_table[bid, block_table_idx] - T.copy( - K[physical_block_idx, - block_tile_idx * block_N:(block_tile_idx + 1) * block_N, - cur_kv_head, :], K_shared) + T.copy(K[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared) T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.if_then_else( - logical_block_idx * block_N + j >= cache_seqlens[bid], - -T.infinity(accum_dtype), acc_s[i, j]) + logical_block_idx * block_N + j >= 
cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j] + ) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) @@ -116,10 +109,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim_v): acc_o[i, j] *= scores_scale[i] - T.copy( - V[physical_block_idx, - block_tile_idx * block_N:(block_tile_idx + 1) * block_N, - cur_kv_head, :], V_shared) + T.copy(V[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) if has_valid_block: for i, j in T.Parallel(block_H, dim_v): @@ -138,9 +128,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.macro def combine( - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): with T.Kernel(heads, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dim_v], accum_dtype) @@ -151,17 +141,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): scale_local = T.alloc_local([1], accum_dtype) max_split = T.alloc_local([1], "int32") - T.annotate_layout({ - lse_logsum_local: - T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) T.clear(lse_logsum_local) T.clear(o_accum_local) lse_max_local[0] = -T.infinity(accum_dtype) for k in T.serial(num_split): lse_local_split[0] = glse[bz, by, k] - if (lse_local_split[0] != 0): + if lse_local_split[0] != 0: max_split[0] = k lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) @@ -183,18 +174,17 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.prim_func def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - block_table: T.Tensor(shape_block_table, "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, "int32"), + cache_seqlens: T.Tensor([batch], "int32"), + block_table: T.Tensor(shape_block_table, "int32"), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): - flash_attn_split(Q, K, V, block_indices, cache_seqlens, block_table, glse, - Output_partial) + flash_attn_split(Q, K, V, block_indices, cache_seqlens, block_table, glse, Output_partial) combine(glse, Output_partial, Output) return main @@ -203,7 +193,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): class SparseFlashAttn(torch.nn.Module): - def 
__init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_pages): super(SparseFlashAttn, self).__init__() self.batch = batch @@ -249,18 +238,11 @@ class SparseFlashAttn(torch.nn.Module): num_sm = self.num_sm num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) - - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") output = self.kernel( query, @@ -275,14 +257,13 @@ class SparseFlashAttn(torch.nn.Module): return output -def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens, - block_table, page_block_size, block_size): +def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens, block_table, page_block_size, block_size): """ Paged version of sparse attention reference implementation. - + Args: query: [batch, heads, dim] - key_cache: [num_pages, page_block_size, heads_kv, dim] + key_cache: [num_pages, page_block_size, heads_kv, dim] value_cache: [num_pages, page_block_size, heads_kv, dim] block_indices: [batch, heads_kv, max_selected_blocks] - logical block indices cache_seqlens: [batch] - actual sequence lengths @@ -298,12 +279,8 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_ # Reconstruct the full key and value tensors from paged cache max_cache_seqlen = max(cache_seqlens).item() - key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim), - dtype=key_cache.dtype, - device=key_cache.device) - value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v), - dtype=value_cache.dtype, - device=value_cache.device) + key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim), dtype=key_cache.dtype, device=key_cache.device) + value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v), dtype=value_cache.dtype, device=value_cache.device) # Reconstruct full tensors from paged cache using block_table for b in range(batch): @@ -319,20 +296,14 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_ actual_block_size = end_token - start_token # Copy from paged cache to full tensors - key_full[b, :, start_token:end_token, :] = key_cache[ - physical_block_idx, :actual_block_size, :, :].transpose(0, 1) - value_full[b, :, start_token:end_token, :] = value_cache[ - physical_block_idx, :actual_block_size, :, :].transpose(0, 1) + key_full[b, :, start_token:end_token, :] = key_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1) + value_full[b, :, start_token:end_token, :] = value_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1) # Reshape query for grouped attention - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] # Compute attention scores - scores = einsum( - query, key_full, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, 
key_full, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] # Create sparse mask based on block_indices sparse_mask = torch.zeros_like(scores) @@ -348,24 +319,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_ sparse_mask[b, :, h, start_pos:end_pos] = 1 # Apply sparse mask - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) # Apply causal mask based on actual sequence lengths range_len = torch.arange(scores.shape[-1], device=scores.device).unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) + scores = scores.masked_fill(pad_mask, float("-inf")) # Compute attention weights attention = F.softmax(scores / scale, dim=-1) # Apply attention to values - out = einsum(attention, value_full, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] + out = einsum(attention, value_full, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] # Reshape output back to original format - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -373,17 +343,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_ def ref_program_fa(query, kcache, vcache, cache_seqlens, block_table): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) - output = flash_attn_with_kvcache( - query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table) + output = flash_attn_with_kvcache(query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table) output = output.squeeze(1) return output def main(args): - - batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = ( + args.batch, + args.heads, + args.heads_kv, + args.max_cache_seqlen, + args.dim, + args.dim_v, + ) sparse_ratio = args.sparse_ratio block_N = args.block_N page_block_size = args.page_block_size @@ -395,35 +371,30 @@ def main(args): dtype = torch.float16 # Generate random inputs - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - cache_seqlens = torch.randint( - max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda") print("cache_seqlens: ", cache_seqlens) - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") # Create paged KV cache - K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device='cuda') - V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), - dtype=dtype, - device='cuda') + 
K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device="cuda") + V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda") # Create block table and block indices for dense case (all blocks selected) max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size)) print("max_num_blocks_per_seq: ", max_num_blocks_per_seq) - block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device='cuda') - block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), - dtype=torch.int32, - device='cuda') + block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda") + block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), dtype=torch.int32, device="cuda") # Fill block table and block indices and cache # Create a pool of available physical blocks - total_blocks_needed = sum( - int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch)) + total_blocks_needed = sum(int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch)) available_blocks = list(range(total_blocks_needed)) import random + random.seed(42) # For reproducibility random.shuffle(available_blocks) @@ -458,10 +429,8 @@ def main(args): actual_block_size = end_token - start_token # Copy K and V data to the paged cache - K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, - start_token:end_token, :, :] - V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, - start_token:end_token, :, :] + K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, start_token:end_token, :, :] + V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, start_token:end_token, :, :] # Fill block_indices for sparse attention # For dense case (verification), we select all blocks in reverse order @@ -496,10 +465,9 @@ def main(args): remaining_blocks = [b for b in all_blocks if b not in selected_blocks] if remaining_blocks: import random + random.seed(42) # For reproducibility - additional_blocks = random.sample( - remaining_blocks, - min(num_selected - recent_blocks, len(remaining_blocks))) + additional_blocks = random.sample(remaining_blocks, min(num_selected - recent_blocks, len(remaining_blocks))) selected_blocks.extend(additional_blocks) # Sort selected blocks in reverse order (most recent first) @@ -512,25 +480,20 @@ def main(args): block_indices[seq_idx, head_idx, i] = -1 # Initialize sparse attention module - sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, - num_blocks) - output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, - block_table) + sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_blocks) + output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table) import flash_attn # noqa: F401 - output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens, - block_table, page_block_size, block_N) + output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table, page_block_size, block_N) output_ref_fa = ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table) # Check correctness if sparse_ratio == 0.0: max_diff = torch.max(torch.abs(output_sparse - output_ref_fa)).item() mean_diff = torch.mean(torch.abs(output_sparse - output_ref_fa)).item() - assert 
torch.allclose( - output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!" + assert torch.allclose(output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!" else: - max_diff = torch.max(torch.abs(output_sparse - output_ref_torch)).item() mean_diff = torch.mean(torch.abs(output_sparse - output_ref_torch)).item() @@ -574,16 +537,15 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.0, help='sparse ratio') - parser.add_argument('--block_N', type=int, default=64, help='block_N') - parser.add_argument('--page_block_size', type=int, default=256, help='block size of pages') - parser.add_argument('--num_pages', type=int, default=1024, help='total number of pages') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.0, help="sparse ratio") + parser.add_argument("--block_N", type=int, default=64, help="block_N") + parser.add_argument("--page_block_size", type=int, default=256, help="block size of pages") + parser.add_argument("--num_pages", type=int, default=1024, help="total number of pages") args = parser.parse_args() main(args) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py index b30875228f3587723c0f06b82d0a7405fb3994fd..8b5cde38d23ca5cdf2b3713eea9def6390ae7139 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py @@ -10,17 +10,18 @@ from heuristic import num_splits_heuristic def flashattn(batch, heads, heads_kv, dim, dim_v): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) dtype = "float16" accum_dtype = "float" kv_group_num = heads // heads_kv @tilelang.jit( - out_idx=[-1], pass_configs={ + out_idx=[-1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) - def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, - max_selected_blocks): + }, + ) + def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, max_selected_blocks): shape_q = [batch, heads, dim] shape_k = [batch, max_cache_seqlen, heads_kv, dim] shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] @@ -31,17 +32,16 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.macro def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: 
T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - # actual_num_blocks: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, "int32"), + cache_seqlens: T.Tensor([batch], "int32"), + # actual_num_blocks: T.Tensor([batch], "int32"), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), ): - with T.Kernel( - batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim_v], dtype) @@ -62,7 +62,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): sid = bz cur_kv_head = hid // (kv_group_num // valid_block_H) - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -70,7 +70,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): num_blocks = max_selected_blocks blocks_per_split = T.floordiv(num_blocks, num_split) remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) start = blocks_per_split * sid + T.min(sid, remaining_blocks) has_valid_block = False @@ -78,26 +78,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): i_s = block_indices[bid, cur_kv_head, start + k] if i_s >= 0: has_valid_block = True - T.copy(K[bid, i_s * block_N:(i_s + 1) * block_N, cur_kv_head, :], K_shared) + T.copy(K[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], K_shared) T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition for i, j in T.Parallel(block_H, block_N): - acc_s[i, - j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], - -T.infinity(accum_dtype), acc_s[i, j]) + acc_s[i, j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j]) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) @@ -106,7 +98,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim_v): acc_o[i, j] *= scores_scale[i] - T.copy(V[bid, i_s * block_N:(i_s + 1) * block_N, cur_kv_head, :], V_shared) + T.copy(V[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, 
policy=T.GemmWarpPolicy.FullRow) if has_valid_block: for i, j in T.Parallel(block_H, dim_v): @@ -125,9 +117,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.macro def combine( - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): with T.Kernel(heads, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dim_v], accum_dtype) @@ -138,17 +130,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): scale_local = T.alloc_local([1], accum_dtype) max_split = T.alloc_local([1], "int32") - T.annotate_layout({ - lse_logsum_local: - T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) T.clear(lse_logsum_local) T.clear(o_accum_local) lse_max_local[0] = -T.infinity(accum_dtype) for k in T.serial(num_split): lse_local_split[0] = glse[bz, by, k] - if (lse_local_split[0] != 0): + if lse_local_split[0] != 0: max_split[0] = k lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) @@ -170,15 +163,15 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.prim_func def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - # actual_num_blocks: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, "int32"), + cache_seqlens: T.Tensor([batch], "int32"), + # actual_num_blocks: T.Tensor([batch], "int32"), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): # flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial) flash_attn_split(Q, K, V, block_indices, cache_seqlens, glse, Output_partial) @@ -190,7 +183,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): class SparseFlashAttn(torch.nn.Module): - def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): super(SparseFlashAttn, self).__init__() self.batch = batch @@ -209,7 +201,8 @@ class SparseFlashAttn(torch.nn.Module): num_stages=2, threads=128, max_cache_seqlen=T.dynamic("max_cache_seqlen"), - max_selected_blocks=T.dynamic("max_selected_blocks")) + max_selected_blocks=T.dynamic("max_selected_blocks"), + ) props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -232,25 +225,17 @@ class SparseFlashAttn(torch.nn.Module): num_sm = self.num_sm num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) - - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, 
device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") output = self.kernel(query, key, value, block_indices, cache_seqlens, glse, output_partial) return output -def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seqlens, - max_cache_seqlen, block_size): +def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, block_size): """ Args: query: [batch, heads, dim] @@ -272,31 +257,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql block_H = 64 actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32) - actual_num_blocks = actual_num_blocks[:, - 0] #[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks + actual_num_blocks = actual_num_blocks[ + :, 0 + ] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks # get num_split num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H - num_n_blocks = max_selected_blocks #(kv_seqlen + block_size - 1 ) // block_size + num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks num_sm = 132 num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) - - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - Output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( block_N=block_size, block_H=block_H, @@ -304,29 +282,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql num_stages=2, threads=128, max_cache_seqlen=T.dynamic("max_cache_seqlen"), - max_selected_blocks=T.dynamic("max_selected_blocks")) + max_selected_blocks=T.dynamic("max_selected_blocks"), + ) output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial) return output -def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): - +def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): batch, heads, dim = query.shape heads_kv = key.shape[2] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, 
num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] sparse_mask = torch.zeros_like(scores) # Assign mask values based on block_indices @@ -335,28 +308,26 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache valid_indices = block_indices[b, h] # Extract indices for this batch and head for idx in valid_indices: if idx >= 0: - sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) - range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out -def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): +def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = output.squeeze(1) @@ -368,23 +339,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3): print(name + " all_close={}".format(all_close)) if not all_close: diff = (expect - actual).abs() - print("all_close={}, max={}, min={}, mean={}".format(all_close, - diff.max().item(), - diff.min().item(), - diff.mean().item())) + print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item())) max_indices = torch.nonzero(diff == diff.max().item()) first_index = tuple(max_indices[0].tolist()) print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}") -def main(batch=8, - heads=32, - heads_kv=8, - max_cache_seqlen=8192, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32): +def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v sparse_ratio = sparse_ratio block_size = block_size @@ -392,10 +353,10 @@ def 
main(batch=8, print("max_selected_blocks: ", max_selected_blocks) dtype = torch.float16 - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') - cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') # # Ensure at least one element equals cache_seqlen # random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index @@ -406,10 +367,7 @@ def main(batch=8, max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() print("max_valid_num_blocks: ", max_valid_num_blocks) # Initialize block_indices with -1 (for padding blocks) - block_indices = torch.full((batch, heads_kv, max_selected_blocks), - -1, - dtype=torch.int32, - device='cuda') + block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") # max_num_blocks = int((max_cache_seqlen + block_size - 1)/ block_size) # block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda') @@ -418,10 +376,9 @@ def main(batch=8, max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch if max_valid_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): - valid_indices = torch.randperm( - max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks] + valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] # valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks] - block_indices[b, h, :len(valid_indices)] = valid_indices + block_indices[b, h, : len(valid_indices)] = valid_indices # Sort indices within each batch-group for consistency block_indices, _ = block_indices.sort(dim=-1, descending=True) @@ -434,8 +391,7 @@ def main(batch=8, print("max_num_blocks: ", max_num_blocks) # parity reference - ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, - block_size) + ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) @@ -445,13 +401,11 @@ def main(batch=8, ## latency reference for _ in range(10): - ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, - max_num_blocks, block_size) + ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) torch.cuda.synchronize() start = time.time() for _ in range(100): - ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, - max_num_blocks, block_size) + ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) torch.cuda.synchronize() print("dense time: ", (time.time() - start) / 100 * 1000) @@ -469,15 +423,13 @@ def main(batch=8, 
if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") args = parser.parse_args() - main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, - args.sparse_ratio, args.block_size) + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py index 3417bd7f8386e395ddd2f1cc76eacb74b01484cb..0d759211adf7310215e1155dfefde0e58c8a8596 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py @@ -12,15 +12,17 @@ from heuristic import num_splits_heuristic def flashattn(batch, heads, heads_kv, dim, dim_v): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) dtype = "float16" accum_dtype = "float" kv_group_num = heads // heads_kv @tilelang.jit( - out_idx=[-1], pass_configs={ + out_idx=[-1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, + ) def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks): shape_q = [batch, heads, dim] shape_k = [batch, max_cache_seqlen, heads_kv, dim] @@ -32,16 +34,15 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.macro def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_mask: T.Tensor(shape_mask, "bool"), - cache_seqlens: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_mask: T.Tensor(shape_mask, "bool"), + cache_seqlens: T.Tensor([batch], "int32"), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), ): - with T.Kernel( - batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, 
dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim_v], dtype) @@ -62,39 +63,31 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): sid = bz cur_kv_head = hid // (kv_group_num // valid_block_H) - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) blocks_per_split = T.floordiv(num_blocks, num_split) remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) start = blocks_per_split * sid + T.min(sid, remaining_blocks) has_valid_block = False for k in T.Pipelined(loop_range, num_stages=num_stages): if block_mask[bid, hid, start + k]: has_valid_block = True - T.copy( - K[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :], - K_shared) + T.copy(K[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], K_shared) T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else((start + k) * block_N + j - >= cache_seqlens[bx], - -T.infinity(accum_dtype), acc_s[i, j]) + acc_s[i, j] = T.if_then_else( + (start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j] + ) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) @@ -103,9 +96,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim_v): acc_o[i, j] *= scores_scale[i] - T.copy( - V[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :], - V_shared) + T.copy(V[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) if has_valid_block: for i, j in T.Parallel(block_H, dim_v): @@ -123,9 +114,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.macro def combine( - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): with T.Kernel(heads, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dim_v], accum_dtype) @@ -135,10 +126,11 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): lse_max_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype) - T.annotate_layout({ - lse_logsum_local: - T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) 
T.clear(lse_logsum_local) T.clear(o_accum_local) @@ -161,14 +153,14 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): @T.prim_func def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_mask: T.Tensor(shape_mask, "bool"), - cache_seqlens: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_mask: T.Tensor(shape_mask, "bool"), + cache_seqlens: T.Tensor([batch], "int32"), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): flash_attn_split(Q, K, V, block_mask, cache_seqlens, glse, Output_partial) combine(glse, Output_partial, Output) @@ -179,7 +171,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): class SparseFlashAttn(torch.nn.Module): - def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): super(SparseFlashAttn, self).__init__() self.batch = batch @@ -198,7 +189,8 @@ class SparseFlashAttn(torch.nn.Module): num_stages=2, threads=128, max_cache_seqlen=T.dynamic("max_cache_seqlen"), - num_blocks=T.dynamic("num_blocks")) + num_blocks=T.dynamic("num_blocks"), + ) props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -217,24 +209,16 @@ class SparseFlashAttn(torch.nn.Module): num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_n_blocks = max_selected_blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks # num_sm = 132 num_sm = self.num_sm num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) # print("num_split: ", num_split) - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - Output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") output = self.kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) return output @@ -259,26 +243,21 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, block_H = 64 actual_num_blocks = torch.sum(block_mask, dim=-1).to(torch.int32) - actual_num_blocks = actual_num_blocks[:, - 0] #[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks + actual_num_blocks = actual_num_blocks[ + :, 0 + ] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks max_selected_blocks = actual_num_blocks.max().item() # get num_split num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H - num_n_blocks = max_selected_blocks #(kv_seqlen + block_size - 1 ) // block_size + num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size # num_n_blocks = torch.sum(actual_num_blocks, 
dim=-1).item() * heads_kv # total number of blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks num_sm = 132 num_split = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( block_N=block_size, @@ -287,11 +266,10 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, num_stages=2, threads=128, max_cache_seqlen=T.dynamic("max_cache_seqlen"), - num_blocks=T.dynamic("num_blocks")) - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - Output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') + num_blocks=T.dynamic("num_blocks"), + ) + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") # print(kernel.get_kernel_source()) output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) @@ -299,24 +277,18 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, return output -def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): - +def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size): batch, heads, dim = query.shape heads_kv = key.shape[2] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] sparse_mask = torch.zeros_like(scores) # Assign mask values @@ -324,29 +296,27 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se for h in range(heads_kv): for idx in range(num_blocks): if block_mask[b, h, idx]: - sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) - range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, 
float('-inf')) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out -def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): +def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = output.squeeze(1) @@ -360,23 +330,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3): # print(expect[3, 28]) # print(actual[3, 28]) diff = (expect - actual).abs() - print("all_close={}, max={}, min={}, mean={}".format(all_close, - diff.max().item(), - diff.min().item(), - diff.mean().item())) + print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item())) max_indices = torch.nonzero(diff == diff.max().item()) first_index = tuple(max_indices[0].tolist()) print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}") -def main(batch=8, - heads=32, - heads_kv=8, - max_cache_seqlen=8192, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32): +def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v sparse_ratio = sparse_ratio block_size = block_size @@ -384,14 +344,13 @@ def main(batch=8, print("max_selected_blocks: ", max_selected_blocks) dtype = torch.float16 - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') - cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") # Ensure at least one element equals cache_seqlen - random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index - cache_seqlens[ - random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, 
device='cuda') print("cache_seqlens: ", cache_seqlens) @@ -403,7 +362,7 @@ def main(batch=8, max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() print("max_valid_num_blocks: ", max_valid_num_blocks) # Initialize block_mask with false (for padding blocks) - block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda') + block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda") # Assign valid indices while ensuring no duplicates within each batch-group for b in range(batch): @@ -411,13 +370,12 @@ def main(batch=8, valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch if valid_num_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): - perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block] + perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] block_mask[b, h, perm] = True # print("block_mask: ", block_mask) # parity reference - ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size) + ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size) model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) out = model(Q, K, V, block_mask, cache_seqlens) @@ -427,13 +385,11 @@ def main(batch=8, ## latency reference for _ in range(10): - ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size) + ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) torch.cuda.synchronize() start = time.time() for _ in range(100): - ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size) + ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) torch.cuda.synchronize() print("dense time: ", (time.time() - start) / 100 * 1000) @@ -452,15 +408,13 @@ def main(batch=8, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") args = parser.parse_args() - main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, - args.sparse_ratio, args.block_size) + 
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py index 85b72b775e4b8d038f36931d1fd3b20274b8f07a..b61d52fa092f4d8cd115905d71cde59a99ca88dc 100644 --- a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py @@ -12,12 +12,8 @@ from heuristic import num_splits_heuristic @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4]\ - for num_stages in [1, 2, 3, 4, 7] - ], - key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_H", "BLOCK_N", "BLOCK_D"], ) @triton.jit def _split_kernel( @@ -79,16 +75,11 @@ def _split_kernel( loop_range = blocks_per_split q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h - k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[ - None, :] * stride_k_s + offs_d[:, None] * stride_k_d - v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, - None] * stride_v_s + offs_d[ - None, :] * stride_v_d + k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d + v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h - q = tl.load( - q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, - mask=offs_h[:, None] < gqa_group_size) + q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size) start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks) for i in range(loop_range): block_idx = tl.load(mask_ptr + (start + i) * stride_mask_s) @@ -119,23 +110,18 @@ def _split_kernel( acc = acc * l_recip acc = acc.to(o_partial_ptr.dtype.element_ty) - lse_partial_ptr += batch_idx * stride_lse_b + ( - head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split + lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size) - o_partial_ptr += batch_idx * stride_o_b + ( - head_idx_q + - offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + o_partial_ptr += ( + batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + ) tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size) @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4]\ - for num_stages in [1, 2, 3, 4, 7] - ], - key=['BLOCK_D'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_D"], ) @triton.jit def _merge_kernel( @@ -163,18 +149,15 @@ def _merge_kernel( offs_d = tl.arange(0, BLOCK_D) lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h - lse = tl.load( - lse_offsets + offs_splits * 
lse_partial_stride_split, - mask=offs_splits < num_splits, - other=float("-inf")) + lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf")) lse_max = tl.max(lse) o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h o_partial = tl.load( - o_offsets + offs_splits[:, None] * o_partial_stride_split + - offs_d[None, :] * o_partial_stride_d, - mask=offs_splits[:, None] < num_splits) + o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d, + mask=offs_splits[:, None] < num_splits, + ) sumexp_normalized_splitk = tl.exp(lse - lse_max) sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0) @@ -209,19 +192,13 @@ def block_sparse_flash_decode_gqa_indice_triton( num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_n_blocks = max_selected_blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks num_sm = 64 # num_sm = self.num_sm num_splits = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) # print("num_splits:", num_splits, "num_blocks:", num_n_blocks) @@ -295,24 +272,18 @@ def block_sparse_flash_decode_gqa_indice_triton( return output -def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): - +def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): batch, heads, dim = query.shape heads_kv = key.shape[2] dim_v = value.shape[-1] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] sparse_mask = torch.zeros_like(scores) # Assign mask values based on block_indices @@ -321,42 +292,33 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache valid_indices = block_indices[b, h] # Extract indices for this batch and head for idx in valid_indices: if idx >= 0: - sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) - range_len = torch.arange(scores.shape[-1], 
device='cuda').unsqueeze(0) + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out def ref_program_fa(query, key, value, cache_seqlens): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = output.squeeze(1) return output -def main(batch=64, - heads=32, - heads_kv=8, - max_cache_seqlen=8192, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32): - +def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v sparse_ratio = sparse_ratio block_size = block_size @@ -369,34 +331,29 @@ def main(batch=64, dtype = torch.float16 block_H = 64 - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') - cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') # Ensure at least one element equals cache_seqlen - random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index - cache_seqlens[ - random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence print("cache_seqlens: ", cache_seqlens) max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() print("max_valid_num_blocks: ", max_valid_num_blocks) # Initialize block_indices with -1 (for padding blocks) - block_indices = torch.full((batch, heads_kv, max_selected_blocks), - -1, - dtype=torch.int32, - device='cuda') + block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") # Assign valid indices while ensuring no duplicates within each 
batch-group for b in range(batch): max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch if max_valid_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): - valid_indices = torch.randperm( - max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks] - block_indices[b, h, :len(valid_indices)] = valid_indices + valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] + block_indices[b, h, : len(valid_indices)] = valid_indices # Sort indices within each batch-group for consistency block_indices, _ = block_indices.sort(dim=-1, descending=True) @@ -408,8 +365,7 @@ def main(batch=64, max_num_blocks = torch.max(max_valid_num_blocks).item() print("max_num_blocks: ", max_num_blocks) - ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, - block_size) + ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) triton_out = block_sparse_flash_decode_gqa_indice_triton( Q, @@ -423,8 +379,7 @@ def main(batch=64, ) print("max difference: ", torch.max(torch.abs(ref - triton_out))) - assert torch.allclose( - ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" + assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" print("Passed the ref test!") # Measure performance @@ -466,15 +421,13 @@ def main(batch=64, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=64, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') + parser.add_argument("--batch", type=int, default=64, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") args = parser.parse_args() - main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, - args.sparse_ratio, args.block_size) + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py index 348572526583b33679d54f626e1d702dde3507ca..c05b3777952fddc834cc46377a823a4c14e0e999 100644 --- a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py @@ -11,12 
+11,8 @@ from heuristic import num_splits_heuristic @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4]\ - for num_stages in [1, 2, 3, 4, 7] - ], - key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_H", "BLOCK_N", "BLOCK_D"], ) @triton.jit def _split_kernel( @@ -77,16 +73,11 @@ def _split_kernel( loop_range = blocks_per_split q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h - k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[ - None, :] * stride_k_s + offs_d[:, None] * stride_k_d - v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, - None] * stride_v_s + offs_d[ - None, :] * stride_v_d + k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d + v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h - q = tl.load( - q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, - mask=offs_h[:, None] < gqa_group_size) + q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size) start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks) for block_idx in range(loop_range): start_n = (start + block_idx) * BLOCK_N @@ -117,23 +108,18 @@ def _split_kernel( acc = acc * l_recip acc = acc.to(o_partial_ptr.dtype.element_ty) - lse_partial_ptr += batch_idx * stride_lse_b + ( - head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split + lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size) - o_partial_ptr += batch_idx * stride_o_b + ( - head_idx_q + - offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + o_partial_ptr += ( + batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + ) tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size) @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4]\ - for num_stages in [1, 2, 3, 4, 7] - ], - key=['BLOCK_D'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_D"], ) @triton.jit def _merge_kernel( @@ -161,18 +147,15 @@ def _merge_kernel( offs_d = tl.arange(0, BLOCK_D) lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h - lse = tl.load( - lse_offsets + offs_splits * lse_partial_stride_split, - mask=offs_splits < num_splits, - other=float("-inf")) + lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf")) lse_max = tl.max(lse) o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h o_partial = tl.load( - o_offsets + offs_splits[:, None] * o_partial_stride_split + - offs_d[None, :] * o_partial_stride_d, - mask=offs_splits[:, None] < num_splits) + o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d, + 
mask=offs_splits[:, None] < num_splits, + ) sumexp_normalized_splitk = tl.exp(lse - lse_max) sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0) @@ -207,19 +190,13 @@ def block_sparse_flash_decode_gqa_mask_triton( num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H num_n_blocks = max_selected_blocks - size_one_kv_head = max_selected_blocks * block_size * ( - dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2 + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks num_sm = 64 # num_sm = self.num_sm num_splits = num_splits_heuristic( - total_mblocks, - num_sm, - num_n_blocks, - num_m_blocks, - size_one_kv_head, - is_causal_or_local=True, - max_splits=128) + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) # print("num_splits:", num_splits, "num_blocks:", num_n_blocks) @@ -292,24 +269,18 @@ def block_sparse_flash_decode_gqa_mask_triton( return output -def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size): - +def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size): batch, heads, dim = query.shape heads_kv = key.shape[2] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] sparse_mask = torch.zeros_like(scores) # Assign mask values @@ -317,43 +288,34 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se for h in range(heads_kv): for idx in range(num_blocks): if block_mask[b, h, idx]: - sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1 + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 - scores = scores.masked_fill(sparse_mask == 0, float('-inf')) + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) - range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0) + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) cache_seqlens_expanded = cache_seqlens.unsqueeze(1) pad_mask = range_len >= cache_seqlens_expanded pad_mask = pad_mask[:, None, None, :] - scores = scores.masked_fill(pad_mask, float('-inf')) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim] - out = rearrange(out, 'b g h d -> 
b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out def ref_program_fa(query, key, value, cache_seqlens): # latency reference # from flash_attn_interface import flash_attn_with_kvcache # fa3 - from flash_attn import flash_attn_with_kvcache #fa2 + from flash_attn import flash_attn_with_kvcache # fa2 + query = query.unsqueeze(1) output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) output = output.squeeze(1) return output -def main(batch=64, - heads=32, - heads_kv=8, - max_cache_seqlen=8192, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32): - +def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v block_size = block_size sparse_ratio = sparse_ratio @@ -363,14 +325,13 @@ def main(batch=64, dtype = torch.float16 - Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') - K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') - V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') - cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") # Ensure at least one element equals cache_seqlen - random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index - cache_seqlens[ - random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence num_blocks = (max_cache_seqlen + block_size - 1) // block_size @@ -379,7 +340,7 @@ def main(batch=64, max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() print("max_valid_num_blocks: ", max_valid_num_blocks) # Initialize block_mask with false (for padding blocks) - block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda') + block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda") # Assign valid indices while ensuring no duplicates within each batch-group for b in range(batch): @@ -387,11 +348,10 @@ def main(batch=64, valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch if valid_num_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): - perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block] + perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] block_mask[b, h, perm] = True - ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, - block_size) + ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) triton_out = block_sparse_flash_decode_gqa_mask_triton( Q, @@ -404,8 +364,7 @@ def main(batch=64, ) # print("max difference: ", torch.max(torch.abs(ref - 
triton_out))) - assert torch.allclose( - ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" + assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" print("Passed the ref test!") # Measure performance @@ -448,15 +407,13 @@ def main(batch=64, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=64, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') + parser.add_argument("--batch", type=int, default=64, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") args = parser.parse_args() - main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, - args.sparse_ratio, args.block_size) + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/heuristic.py b/examples/blocksparse_attention/heuristic.py index b60a81dc353aa908bf74f5a0aada4bf1a178cda1..0e6fc528196e3f111924b7d16b34d0c9af8c3800 100644 --- a/examples/blocksparse_attention/heuristic.py +++ b/examples/blocksparse_attention/heuristic.py @@ -1,8 +1,7 @@ import math -def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, - is_causal_or_local, max_splits): +def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local, max_splits): """ Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency. 
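For reference, a minimal sketch of how the decode examples above derive the inputs to num_splits_heuristic, assuming heuristic.py from this directory is importable; the concrete sizes below are illustrative assumptions, while the derived quantities and the call signature mirror the call sites in this patch.

from heuristic import num_splits_heuristic  # examples/blocksparse_attention/heuristic.py

# Illustrative problem size (mirrors the example defaults; max_selected_blocks is
# data-dependent in the examples, here assumed for a mostly-sparse 8K-token cache).
batch, heads, heads_kv = 8, 32, 8
dim, dim_v = 128, 128
block_size, block_H = 32, 64
max_selected_blocks = 51  # assumed number of selected KV blocks per sequence
num_sm = 132              # SM count; the examples hardcode this or query torch.cuda device properties

# Quantities computed the same way as the call sites above.
num_m_blocks = (heads // heads_kv + block_H - 1) // block_H               # GQA query-head tiles per KV head
num_n_blocks = max_selected_blocks                                        # KV blocks a split may have to visit
size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2   # fp16 K+V bytes touched per KV head
total_mblocks = batch * heads_kv * num_m_blocks

num_split = num_splits_heuristic(
    total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head,
    is_causal_or_local=True, max_splits=128,
)
print("num_split:", num_split)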
diff --git a/examples/blocksparse_attention/test_example_blocksparse_attention.py b/examples/blocksparse_attention/test_example_blocksparse_attention.py index adda1f0f15764393dbc08f138e6f492d7e14c5ec..dd33f46c4ef9705350bc2cc8894cb715d4444346 100644 --- a/examples/blocksparse_attention/test_example_blocksparse_attention.py +++ b/examples/blocksparse_attention/test_example_blocksparse_attention.py @@ -25,26 +25,14 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask(): def test_example_triton_sparse_gqa_decode_varlen_indice(): example_triton_sparse_gqa_decode_varlen_indice.main( - batch=8, - heads=8, - heads_kv=4, - max_cache_seqlen=2048, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32) + batch=8, heads=8, heads_kv=4, max_cache_seqlen=2048, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32 + ) def test_example_triton_sparse_gqa_decode_varlen_mask(): example_triton_sparse_gqa_decode_varlen_mask.main( - batch=16, - heads=16, - heads_kv=8, - max_cache_seqlen=1024, - dim=128, - dim_v=128, - sparse_ratio=0.8, - block_size=32) + batch=16, heads=16, heads_kv=8, max_cache_seqlen=1024, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32 + ) if __name__ == "__main__": diff --git a/examples/blocksparse_gemm/example_blocksparse_gemm.py b/examples/blocksparse_gemm/example_blocksparse_gemm.py index 8cd3a8218a816d2d78ba799e34d239bfa016c529..0cbef5e0cc50aca495c0c7fc19aa1f926ffc6ea3 100644 --- a/examples/blocksparse_gemm/example_blocksparse_gemm.py +++ b/examples/blocksparse_gemm/example_blocksparse_gemm.py @@ -19,8 +19,7 @@ parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M") parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N") parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)") -parser.add_argument( - "--use_autotune", action="store_true", default=False, help="Whether to use autotune") +parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune") args, _ = parser.parse_known_args() M, N, K = args.m, args.n, args.k @@ -41,17 +40,19 @@ def get_configs(): thread_num = [128, 256] enable_rasterization = [True, False] - _configs = list( - itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization)) + _configs = list(itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization)) - return [{ - "block_M": c[0], - "block_N": c[1], - "block_K": c[2], - "num_stages": c[3], - "thread_num": c[4], - "enable_rasteration": c[5], - } for c in _configs] + return [ + { + "block_M": c[0], + "block_N": c[1], + "block_K": c[2], + "num_stages": c[3], + "thread_num": c[4], + "enable_rasteration": c[5], + } + for c in _configs + ] def ref_program(A, B, BlockMask, block_M, block_N, block_K): @@ -61,12 +62,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K): accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device) for k in range(K // block_K): if BlockMask[i, j, k]: - accu += ( - A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to( - torch.float32) @ B[k * block_K:(k + 1) * block_K, - j * block_N:(j + 1) * block_N].to(torch.float32)) - ref_c[i * block_M:(i + 1) * block_M, - j * block_N:(j + 1) * block_N] = accu.to(torch.float16) + accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[ + k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N + 
].to(torch.float32) + ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16) return ref_c @@ -89,28 +88,21 @@ def supply_program(params: List[KernelParam]): return input_tensors -@tilelang.autotune(configs=get_configs(),) +@tilelang.autotune( + configs=get_configs(), +) @tilelang.jit(out_idx=[-1]) -def blocksparse_matmul(M, - N, - K, - block_M, - block_N, - block_K, - num_stages, - thread_num, - enable_rasteration, - dtype="float16", - accum_dtype="float"): - +def blocksparse_matmul( + M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float" +): block_mask_shape = (M // block_M, N // block_N, K // block_K) @T.prim_func def block_sparse_matmul( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -134,7 +126,6 @@ def blocksparse_matmul(M, def main(): - # Initialize input matrices A and B on the GPU with half precision a = torch.randn(M, K).cuda().half() b = torch.randn(K, N).cuda().half() @@ -147,8 +138,7 @@ def main(): best_config = kernel.config best_latency = kernel.latency - block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config[ - "block_K"] + block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config["block_K"] print(f"Best Config: {best_config}") print(f"Sparsity Ratio: {sparsity}") @@ -163,7 +153,8 @@ def main(): block_K=DEFAULT_BLOCK_K, num_stages=DEFAULT_NUM_STAGES, thread_num=DEFAULT_THREAD_NUM, - enable_rasteration=DEFAULT_ENABLE_RASTERIZATION) + enable_rasteration=DEFAULT_ENABLE_RASTERIZATION, + ) block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})") # Create block mask with desired sparsity diff --git a/examples/cast/example_group_per_split_token_cast_to_fp8.py b/examples/cast/example_group_per_split_token_cast_to_fp8.py index 102ac20213b78ad011a358fc714f2ee9eb84a624..ec15b292e7d22f1cbb60e70cb73cbfcc9447cbfa 100644 --- a/examples/cast/example_group_per_split_token_cast_to_fp8.py +++ b/examples/cast/example_group_per_split_token_cast_to_fp8.py @@ -16,11 +16,13 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m): fp8_max = 448.0 @T.prim_func - def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor( - (BG,), "int32"), X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"), X_amax: T.Tensor( - (BG, M_max, T.ceildiv(N, group_size)), accum_dtype)): - with T.Kernel( - T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz): + def group_per_split_token_cast( + X: T.Tensor((M, N), dtype), + batch_sizes: T.Tensor((BG,), "int32"), + X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"), + X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)), accum_dtype), + ): + with T.Kernel(T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz): row = bx row_g_id = by bg = bz @@ -31,36 +33,32 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m): y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3") row_offset = T.alloc_fragment((1,), "int32") - 
T.annotate_layout({ - y_local: - T.Fragment( - y_local.shape, - forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), - }) + T.annotate_layout( + { + y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), + } + ) row_offset[0] = 0 for i in T.serial(bg): row_offset[0] += batch_sizes[i] T.copy( - X[row_offset[0] + row * blk_m:row_offset[0] + (row + 1) * blk_m, - row_g_id * group_size:(row_g_id + 1) * group_size], y_local) + X[row_offset[0] + row * blk_m : row_offset[0] + (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], + y_local, + ) T.reduce_absmax(y_local, y_amax_local, dim=1) for i in T.Parallel(blk_m): y_amax_local[i] = T.max(y_amax_local[i], 1e-4) - y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg], - y_amax_local[i] / fp8_max, 0) + y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_amax_local[i] / fp8_max, 0) for i, j in T.Parallel(blk_m, group_size): y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max) T.copy(y_q_local, y_q_local_fp8) for i, j in T.Parallel(blk_m, group_size): - y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], - y_q_local[i, j], 0) + y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_q_local[i, j], 0) for i in T.Parallel(blk_m): X_amax[bg, row * blk_m + i, row_g_id] = y_s_local[i] - T.copy( - y_q_local_fp8, X_fp8[bg, row * blk_m:(row + 1) * blk_m, - row_g_id * group_size:(row_g_id + 1) * group_size]) + T.copy(y_q_local_fp8, X_fp8[bg, row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size]) return group_per_split_token_cast @@ -127,8 +125,7 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: return x.squeeze(0) if remove_dim else x # Normal layout requires transposing - aligned_x = torch.transpose( - torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) + aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) aligned_x[:, :m, :] = x aligned_x = aligned_x[:, :m, :] return aligned_x.squeeze(0) if remove_dim else aligned_x @@ -146,15 +143,17 @@ def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous() return x_fp8, (x_amax / 448.0).view(m, -1) -def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \ - Tuple[torch.Tensor, torch.Tensor]: + +def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # assert x.shape[0] == batch_sizes.sum() M_max = ceil_div(batch_sizes.max(), 128) * 128 split_x = torch.split(x, batch_sizes.tolist(), dim=0) padded_x = [torch.nn.functional.pad(t, (0, 0, 0, M_max - t.shape[0])) for t in split_x] num_groups, m, n = batch_sizes.shape[0], M_max, x.shape[1] - x_fp8 = (torch.empty((num_groups, m, n), device='cuda', dtype=torch.float8_e4m3fn), - torch.empty((num_groups, m, n // 128), device='cuda', dtype=torch.float)) + x_fp8 = ( + torch.empty((num_groups, m, n), device="cuda", dtype=torch.float8_e4m3fn), + torch.empty((num_groups, m, n // 128), device="cuda", dtype=torch.float), + ) for i in range(num_groups): x_fp8[0][i], x_fp8[1][i] = ref_per_token_cast_to_fp8(padded_x[i]) x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) diff --git a/examples/cast/example_per_token_cast_to_fp8.py b/examples/cast/example_per_token_cast_to_fp8.py index 484a092f09b76a63890f032d36522205d21f6d0f..45281ab1471c1134d3d42ecdf6c1e848c7276238 100644 --- 
a/examples/cast/example_per_token_cast_to_fp8.py +++ b/examples/cast/example_per_token_cast_to_fp8.py @@ -13,8 +13,9 @@ def per_token_cast_to_fp8(M, N, blk_m): fp8_max = 448.0 @T.prim_func - def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"), - X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)): + def per_token_cast( + X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype) + ): with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by): row = bx row_g_id = by @@ -24,16 +25,13 @@ def per_token_cast_to_fp8(M, N, blk_m): y_q_local = T.alloc_fragment((blk_m, group_size), dtype) y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3") - T.annotate_layout({ - y_local: - T.Fragment( - y_local.shape, - forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), - }) + T.annotate_layout( + { + y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), + } + ) - T.copy( - X[row * blk_m:(row + 1) * blk_m, row_g_id * group_size:(row_g_id + 1) * group_size], - y_local) + T.copy(X[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], y_local) T.reduce_absmax(y_local, y_amax_local, dim=1) for i in T.Parallel(blk_m): y_amax_local[i] = T.max(y_amax_local[i], 1e-4) @@ -43,9 +41,7 @@ def per_token_cast_to_fp8(M, N, blk_m): T.copy(y_q_local, y_q_local_fp8) for i in T.Parallel(blk_m): X_amax[row * blk_m + i, row_g_id] = y_s_local[i] - T.copy( - y_q_local_fp8, X_fp8[row * blk_m:(row + 1) * blk_m, - row_g_id * group_size:(row_g_id + 1) * group_size]) + T.copy(y_q_local_fp8, X_fp8[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size]) return per_token_cast @@ -105,8 +101,7 @@ def main(M=8192, N=8192, blk_m=8): from example_triton_cast_to_fp8 import per_token_group_quant_fp8 def run_triton(): - x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8( - x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False) + x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False) return x_fp8_triton_, x_amax_triton_ x_fp8_triton, x_amax_triton = run_triton() diff --git a/examples/cast/example_triton_cast_to_fp8.py b/examples/cast/example_triton_cast_to_fp8.py index cc56defe774b0d1467a39b0d199a2c015cfbf13b..1859433f10b6f6bd438846473b5661718c34fe4f 100644 --- a/examples/cast/example_triton_cast_to_fp8.py +++ b/examples/cast/example_triton_cast_to_fp8.py @@ -128,9 +128,7 @@ def per_token_group_quant_fp8( Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. 
""" - assert (x.shape[-1] % - group_size == 0), (f"the last dimension of `x` {x.shape[-1]} must be divisible " - f"by `group_size` {group_size}") + assert x.shape[-1] % group_size == 0, f"the last dimension of `x` {x.shape[-1]} must be divisible by `group_size` {group_size}" assert x.stride(-1) == 1, "`x` groups must be contiguous" finfo = torch.finfo(dtype) diff --git a/examples/cast/test_example_cast.py b/examples/cast/test_example_cast.py index 1ca000eb2cf9bfa916fb952dbfc499ddb413e4ed..e8b10a7979cf6506ec93c21bc8e9d3ddec2cc214 100644 --- a/examples/cast/test_example_cast.py +++ b/examples/cast/test_example_cast.py @@ -4,8 +4,7 @@ import example_per_token_cast_to_fp8 def test_example_group_per_split_token_cast_to_fp8(): - example_group_per_split_token_cast_to_fp8.main( - M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896]) + example_group_per_split_token_cast_to_fp8.main(M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896]) def test_example_per_token_cast_to_fp8(): diff --git a/examples/compile_flags/usecase.py b/examples/compile_flags/usecase.py index 8451b04fcf6ac3879ed1d98e1c6cf76143396db0..80e2b784b25e803368efabd60e9e1f05a3df118d 100644 --- a/examples/compile_flags/usecase.py +++ b/examples/compile_flags/usecase.py @@ -4,12 +4,11 @@ import tilelang.language as T # @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"]) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -36,8 +35,7 @@ block_K = 32 func = matmul(M, N, K, block_M, block_N, block_K) -jit_kernel = tilelang.compile( - func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr") +jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr") # or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"]) # or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"]) diff --git a/examples/conftest.py b/examples/conftest.py index 9f49d40a9b50e14e41915811589d0011d3c2c910..4010e0d83ae84c641151d6dd56dbf40ee42e301f 100644 --- a/examples/conftest.py +++ b/examples/conftest.py @@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): "warnings", "error", } - if (sum( - len(terminalreporter.stats.get(k, [])) - for k in known_types.difference({"skipped", "deselected"})) == 0): + if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0: terminalreporter.write_sep( "!", - (f"Error: No tests were collected. " - f"{dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"), + (f"Error: No tests were collected. 
{dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"), ) pytest.exit("No tests were collected.", returncode=5) diff --git a/examples/convolution/example_convolution.py b/examples/convolution/example_convolution.py index b2696ba8f537d2ff6638be9fa247f56034435be0..a84e5878af21de55bb7f44cbe86e661fa1c9f35b 100644 --- a/examples/convolution/example_convolution.py +++ b/examples/convolution/example_convolution.py @@ -14,7 +14,6 @@ def check_hopper(): def ref_program(stride, padding, dilation): - def main(A, B): A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W @@ -26,22 +25,7 @@ def ref_program(stride, padding, dilation): @tilelang.jit(out_idx=[2]) -def convolution(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages, - threads, - dtype="float16", - accum_dtype="float"): +def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 @@ -51,13 +35,11 @@ def convolution(N, @T.prim_func def main( - data: T.Tensor((N, H, W, C), dtype), - kernel: T.Tensor((KH, KW, C, F), dtype), - out: T.Tensor((N, OH, OW, F), dtype), + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), ): - with T.Kernel( - T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by): data_shared = T.alloc_shared((block_M, block_K), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -66,11 +48,13 @@ def convolution(N, kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) - T.annotate_layout({ - out_shared: tilelang.layout.make_swizzled_layout(out_shared), - data_shared: tilelang.layout.make_swizzled_layout(data_shared), - kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), - }) + T.annotate_layout( + { + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + data_shared: tilelang.layout.make_swizzled_layout(data_shared), + kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), + } + ) T.clear(out_local) for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): @@ -82,10 +66,8 @@ def convolution(N, m = by * block_M + i access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P access_w = m % OW * S + k // C % KW * D - P - in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and - (access_w < W)) - data_shared[i, j] = T.if_then_else( - in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.gemm(data_shared, kernel_shared, out_local) @@ -97,15 +79,15 @@ def convolution(N, def main(argv=None): parser = argparse.ArgumentParser() - parser.add_argument('--n', type=int, default=128, help='n') - parser.add_argument('--c', type=int, default=128, help='c') - parser.add_argument('--h', type=int, default=64, help='h') - parser.add_argument('--w', type=int, default=64, help='w') - parser.add_argument('--f', 
type=int, default=128, help='f') - parser.add_argument('--k', type=int, default=3, help='k') - parser.add_argument('--s', type=int, default=1, help='s') - parser.add_argument('--d', type=int, default=1, help='d') - parser.add_argument('--p', type=int, default=1, help='p') + parser.add_argument("--n", type=int, default=128, help="n") + parser.add_argument("--c", type=int, default=128, help="c") + parser.add_argument("--h", type=int, default=64, help="h") + parser.add_argument("--w", type=int, default=64, help="w") + parser.add_argument("--f", type=int, default=128, help="f") + parser.add_argument("--k", type=int, default=3, help="k") + parser.add_argument("--s", type=int, default=1, help="s") + parser.add_argument("--d", type=int, default=1, help="d") + parser.add_argument("--p", type=int, default=1, help="p") args = parser.parse_args(argv) N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p diff --git a/examples/convolution/example_convolution_autotune.py b/examples/convolution/example_convolution_autotune.py index 393677489b7f0fb13ea3b1d3ae688043e5847a31..600b608a3cf74bf043c2cac072d1342c93dc9e39 100644 --- a/examples/convolution/example_convolution_autotune.py +++ b/examples/convolution/example_convolution_autotune.py @@ -14,7 +14,6 @@ def check_hopper(): def ref_program(stride, padding, dilation): - def main(A, B): A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W @@ -40,7 +39,8 @@ def get_configs(): num_stages, thread_num, enable_rasterization, - )) + ) + ) configs = [ { @@ -50,7 +50,8 @@ def get_configs(): "num_stages": c[3], "thread_num": c[4], "enable_rasteration": c[5], # keep param name for backward-compat - } for c in _configs + } + for c in _configs ] return configs @@ -64,53 +65,18 @@ def get_heuristic_config() -> dict: sm_version = sm_major * 10 + sm_minor print(f"CUDA device capability: {sm_version}") if sm_version in {80}: - return { - "block_M": 128, - "block_N": 256, - "block_K": 32, - "num_stages": 2, - "thread_num": 128, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True} elif sm_version in {90}: - return { - "block_M": 128, - "block_N": 256, - "block_K": 64, - "num_stages": 3, - "thread_num": 256, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True} else: - return { - "block_M": 128, - "block_N": 256, - "block_K": 32, - "num_stages": 0, - "thread_num": 128, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True} @tilelang.autotune(configs=get_configs()) @tilelang.jit(out_idx=[2]) -def convolution(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages, - thread_num, - enable_rasteration, - dtype="float16", - accum_dtype="float"): +def convolution( + N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float" +): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 @@ -120,13 +86,11 @@ def convolution(N, @T.prim_func def main( - data: T.Tensor((N, H, W, C), dtype), - kernel: T.Tensor((KH, KW, C, F), dtype), - out: T.Tensor((N, OH, OW, F), dtype), + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), 
dtype), + out: T.Tensor((N, OH, OW, F), dtype), ): - with T.Kernel( - T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), - threads=thread_num) as (bx, by): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=thread_num) as (bx, by): data_shared = T.alloc_shared((block_M, block_K), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -136,9 +100,11 @@ def convolution(N, out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) if is_hopper: - T.annotate_layout({ - out_shared: tilelang.layout.make_swizzled_layout(out_shared), - }) + T.annotate_layout( + { + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + } + ) T.clear(out_local) for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): @@ -150,10 +116,8 @@ def convolution(N, m = by * block_M + i access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P access_w = m % OW * S + k // C % KW * D - P - in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and - (access_w < W)) - data_shared[i, j] = T.if_then_else( - in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.gemm(data_shared, kernel_shared, out_local) @@ -166,17 +130,19 @@ def convolution(N, return main -def main(n: int = 128, - c: int = 128, - h: int = 64, - w: int = 64, - f: int = 128, - k: int = 3, - s: int = 1, - d: int = 1, - p: int = 1, - use_autotune: bool = False, - with_roller: bool = True): +def main( + n: int = 128, + c: int = 128, + h: int = 64, + w: int = 64, + f: int = 128, + k: int = 3, + s: int = 1, + d: int = 1, + p: int = 1, + use_autotune: bool = False, + with_roller: bool = True, +): N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p ref_prog = ref_program(S, P, D) @@ -196,25 +162,16 @@ def main(n: int = 128, if __name__ == "__main__": parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") - parser.add_argument('--n', type=int, default=128, help='n') - parser.add_argument('--c', type=int, default=128, help='c') - parser.add_argument('--h', type=int, default=64, help='h') - parser.add_argument('--w', type=int, default=64, help='w') - parser.add_argument('--f', type=int, default=128, help='f') - parser.add_argument('--k', type=int, default=3, help='k') - parser.add_argument('--s', type=int, default=1, help='s') - parser.add_argument('--d', type=int, default=1, help='d') - parser.add_argument('--p', type=int, default=1, help='p') - parser.add_argument( - "--use_autotune", - action="store_true", - default=False, - help="Whether to use autotune for matmul configs") - parser.add_argument( - "--with_roller", - action="store_true", - default=True, - help="Whether to enable BitBLAS roller for search space") + parser.add_argument("--n", type=int, default=128, help="n") + parser.add_argument("--c", type=int, default=128, help="c") + parser.add_argument("--h", type=int, default=64, help="h") + parser.add_argument("--w", type=int, default=64, help="w") + parser.add_argument("--f", type=int, default=128, help="f") + parser.add_argument("--k", type=int, default=3, help="k") + parser.add_argument("--s", type=int, default=1, help="s") + parser.add_argument("--d", type=int, default=1, help="d") + parser.add_argument("--p", type=int, default=1, 
help="p") + parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs") + parser.add_argument("--with_roller", action="store_true", default=True, help="Whether to enable BitBLAS roller for search space") args = parser.parse_args() - main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, - args.with_roller) + main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, args.with_roller) diff --git a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py index 715f09a9b143f55d793f2ccbf206db9e7d666d0e..8aba9140656ffb2bdf45c9b94645a6e40824343f 100644 --- a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py +++ b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py @@ -41,14 +41,13 @@ def tl_gemm( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - scales_a: T.Tensor(Scales_A_shape, "float32"), - scales_b: T.Tensor(Scales_B_shape, "float32"), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + scales_a: T.Tensor(Scales_A_shape, "float32"), + scales_b: T.Tensor(Scales_B_shape, "float32"), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_shared = T.alloc_shared(C_shared_shape, out_dtype) @@ -93,21 +92,18 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: m, n = x.shape x_view = x.view(m, -1, 128) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view( - m, n), (x_amax / 448.0).view(m, -1) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros( - ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device) + x_padded = torch.zeros(ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device) x_padded[:m, :n] = x x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view( - x_view.size(0), x_view.size(2)) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype): @@ -127,13 +123,14 @@ def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype): c_acc.zero_() for k in range(ceildiv(K, 128)): c = torch._scaled_mm( - A_fp8[i * 128:(i + 1) * 128, k * 128:(k + 1) * 128], - B_fp8[j * 128:(j + 1) * 128, k * 128:(k + 1) * 128].T, + A_fp8[i * 128 : (i + 1) * 128, k * 128 : (k + 1) * 128], + B_fp8[j * 128 : (j + 1) * 128, k * 128 : (k + 1) * 128].T, scale_a=A_scales[i, k].view(128, 1).contiguous(), scale_b=B_scales[j, k].view(1, 128).contiguous(), - out_dtype=torch.bfloat16) + out_dtype=torch.bfloat16, + ) c_acc += c.to(torch.float32) - C[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = c_acc.to(out_dtype) + C[i * 128 : (i + 1) * 128, j * 128 : (j 
+ 1) * 128] = c_acc.to(out_dtype) return C diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py index 61c3b63c0b5d81c72dd983236f269107b33423f1..49958379888978f6f344611cf0daa39d54876ed4 100644 --- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py @@ -8,6 +8,7 @@ import argparse def get_configs(): import itertools + BLOCK_N = [16, 32, 64, 128] BLOCK_H = [16, 32, 64, 128] num_split = [1, 2, 4, 8, 16, 32] @@ -15,30 +16,26 @@ def get_configs(): _configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, threads)) - return [{ - "block_N": c[0], - "block_H": c[1], - "num_split": c[2], - "threads": c[3], - } for c in _configs] + return [ + { + "block_N": c[0], + "block_H": c[1], + "num_split": c[2], + "threads": c[3], + } + for c in _configs + ] @tilelang.autotune(configs=get_configs()) @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashmla_decode(batch, - heads, - kv_head_num, - seqlen_kv, - dim, - pe_dim, - block_N, - block_H, - num_split, - threads=128): - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) + }, +) +def flashmla_decode(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, threads=128): + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) dtype = "float16" accum_dtype = "float" kv_group_num = heads // kv_head_num @@ -47,11 +44,11 @@ def flashmla_decode(batch, @T.macro def flash_attn( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=threads) as (bx, by): Q_local = T.alloc_fragment([block_H, dim], dtype) @@ -70,24 +67,19 @@ def flashmla_decode(batch, cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_local) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_local) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_local) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = T.ceildiv(seqlen_kv, block_N) for k in T.Pipelined(loop_range, num_stages=0): - T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared) - T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared) + T.copy(KV[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) T.clear(acc_s) T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.gemm( - Q_pe_local, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(scores_max, scores_max_prev) T.fill(scores_max, 
-T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -107,20 +99,18 @@ def flashmla_decode(batch, T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) + T.copy(acc_o, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) @T.macro def flash_attn_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), ): - with T.Kernel( - batch, heads // min(block_H, kv_group_num), num_split, - threads=threads) as (bx, by, bz): + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=threads) as (bx, by, bz): Q_local = T.alloc_fragment([block_H, dim], dtype) Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype) KV_shared = T.alloc_shared([block_N, dim], dtype) @@ -136,8 +126,8 @@ def flashmla_decode(batch, cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_local) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_local) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_local) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -150,12 +140,7 @@ def flashmla_decode(batch, T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared) T.clear(acc_s) T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.gemm( - Q_pe_local, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -176,14 +161,14 @@ def flashmla_decode(batch, acc_o[i, j] /= logsum[i] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz]) - T.copy(acc_o, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :]) + T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz]) + T.copy(acc_o, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :]) @T.macro def combine( - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(heads, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dim], dtype) @@ -193,9 +178,11 @@ def flashmla_decode(batch, lse_max_local = 
T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype) - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) T.clear(lse_logsum_local) T.clear(o_accum_local) @@ -218,26 +205,26 @@ def flashmla_decode(batch, @T.prim_func def main_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) combine(glse, Output_partial, Output) @T.prim_func def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): flash_attn(Q, Q_pe, KV, K_pe, Output) @@ -262,43 +249,36 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # 
[batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=128, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') - parser.add_argument('--autotune', action='store_true', help='auto tune') + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") + parser.add_argument("--autotune", action="store_true", help="auto tune") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim enable_autotune = args.autotune @@ -314,17 +294,7 @@ if __name__ == "__main__": if enable_autotune: kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim) else: - kernel = flashmla_decode( - batch, - heads, - kv_heads, - kv_ctx, - dim, - pe_dim, - BLOCK_N, - BLOCK_H, - num_split, - threads=threads) + kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, threads=threads) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) input_tensors = profiler._get_inputs() tilelang_output = kernel(*input_tensors) diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py index 0006d946875a890782e50d2bee558a0842dbfaf9..18c0a5f86d7625af022832d36f58b123c0feb0f8 100644 --- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py @@ -32,8 +32,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @torch.inference_mode() -def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] def ref_mla(): @@ -94,8 +93,7 @@ def _mla_attn_kernel( offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) - offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[ - None, :] + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] 
q_nope = tl.load(Q_nope + offs_q_nope) offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) @@ -141,9 +139,7 @@ def _mla_attn_kernel( e_sum = e_sum * re_scale + tl.sum(p, 1) e_max = n_e_max - offs_o = cur_batch * stride_o_b + cur_head[:, - None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[ - None, :] + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] tl.store(O + offs_o, acc / e_sum[:, None]) offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV tl.store(O + offs_o_1, e_max + tl.log(e_sum)) @@ -309,24 +305,30 @@ def mla_decode_triton( @torch.inference_mode() -def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() def flash_mla_triton(): num_kv_splits = 32 o = torch.empty([b * s_q, h_q, dv]) attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) mla_decode_triton( - q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv), - blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits, - num_kv_splits, 1 / math.sqrt(d), block_size) + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), + o, + block_table, + cache_seqlens, + attn_logits, + num_kv_splits, + 1 / math.sqrt(d), + block_size, + ) return o.view([b, s_q, h_q, dv]) out_flash = flash_mla_triton() @@ -362,14 +364,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" if target not in ["flash_mla_triton"]: @@ -377,21 +380,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf 
{baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s" - ) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - print( - f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" - ) + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") torch.set_default_dtype(dtype) device = torch.device("cuda:0") torch.set_default_device(device) @@ -408,19 +404,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_b @@ -429,26 +422,22 @@ available_targets = [ "flash_mla_triton", ] -shape_configs = [{ - "b": - batch, - "s_q": - 1, - "cache_seqlens": - torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), - "h_q": - head, - "h_kv": - 1, - "d": - 512 + 64, - "dv": - 512, - "causal": - True, - "dtype": - torch.float16 -} for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 16384] for head in [128]] +shape_configs = [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.float16, + } + for batch in [128] + for seqlen in [1024, 2048, 4096, 8192, 16384] + for head in [128] +] def get_args(): @@ -470,26 +459,54 @@ if __name__ == "__main__": for shape in shape_configs: if args.all: for target in available_targets: - perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + 
shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) elif args.compare: - perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], - shape["cache_seqlens"], shape["h_q"], shape["h_kv"], - shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n' + f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" ) elif args.one: - perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py index 644f97da15c12ddc1ccb1fd17b3be8b41a80b106..861e841c4ec8b68851cd4bfdbfdce0fede87960f 100644 --- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py @@ -29,8 +29,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @torch.inference_mode() -def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] def ref_mla(): @@ -91,8 +90,7 @@ def _mla_attn_kernel( offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) - offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[ - None, :] + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] q_nope = tl.load(Q_nope + offs_q_nope) offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) @@ -138,9 +136,7 @@ def _mla_attn_kernel( e_sum = e_sum * re_scale + tl.sum(p, 1) e_max = n_e_max - offs_o = cur_batch * stride_o_b + cur_head[:, - None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[ - None, :] + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] tl.store(O + offs_o, acc / e_sum[:, None]) offs_o_1 = cur_batch * stride_o_b + 
cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV tl.store(O + offs_o_1, e_max + tl.log(e_sum)) @@ -306,24 +302,30 @@ def mla_decode_triton( @torch.inference_mode() -def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() def flash_mla_triton(): num_kv_splits = 32 o = torch.empty([b * s_q, h_q, dv]) attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) mla_decode_triton( - q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv), - blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits, - num_kv_splits, 1 / math.sqrt(d), block_size) + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), + o, + block_table, + cache_seqlens, + attn_logits, + num_kv_splits, + 1 / math.sqrt(d), + block_size, + ) return o.view([b, s_q, h_q, dv]) out_flash = flash_mla_triton() @@ -359,14 +361,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" if target not in ["flash_mla_triton"]: @@ -374,21 +377,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s" - ) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s") + print(f"perf {target}: 
{perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - print( - f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" - ) + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") torch.set_default_dtype(dtype) device = torch.device("cuda:0") torch.set_default_device(device) @@ -405,19 +401,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_b @@ -426,26 +419,22 @@ available_targets = [ "flash_mla_triton", ] -shape_configs = [{ - "b": - batch, - "s_q": - 1, - "cache_seqlens": - torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), - "h_q": - head, - "h_kv": - 1, - "d": - 512 + 64, - "dv": - 512, - "causal": - True, - "dtype": - torch.float16 -} for batch in [64, 128] for seqlen in [1024, 2048, 4096, 8192, 16384] for head in [128]] +shape_configs = [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.float16, + } + for batch in [64, 128] + for seqlen in [1024, 2048, 4096, 8192, 16384] + for head in [128] +] def get_args(): @@ -467,26 +456,54 @@ if __name__ == "__main__": for shape in shape_configs: if args.all: for target in available_targets: - perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) elif args.compare: - perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], - shape["cache_seqlens"], shape["h_q"], 
shape["h_kv"], - shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n' + f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" ) elif args.one: - perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) diff --git a/examples/deepseek_mla/benchmark_mla.py b/examples/deepseek_mla/benchmark_mla.py index a542ff611d1be481732696baaff09623ac283f2d..544b5e1285c173e1521f049e1de9521baa53afee 100644 --- a/examples/deepseek_mla/benchmark_mla.py +++ b/examples/deepseek_mla/benchmark_mla.py @@ -33,8 +33,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @torch.inference_mode() -def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] def ref_mla(): @@ -61,8 +60,7 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, @torch.inference_mode() -def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): from flash_mla import flash_mla_with_kvcache, get_mla_metadata blocked_v = blocked_k[..., :dv] @@ -87,14 +85,13 @@ def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, @torch.inference_mode() -def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, - h_q, h_kv, d, dv, causal, dtype): +def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): # pip install flashinfer-python import flashinfer + assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() kv_indptr = [0] kv_indices = [] @@ -111,8 +108,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q kv_indptr = 
torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) - mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( - torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3") + mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3") mla_wrapper.plan( q_indptr, kv_indptr, @@ -129,12 +125,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q ) def flashinfer(): - output, lse = mla_wrapper.run( - q_nope.view(-1, h_q, dv), - q_pe.view(-1, h_q, d - dv), - blocked_k_nope, - blocked_k_pe, - return_lse=True) + output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope, blocked_k_pe, return_lse=True) return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1) out_flash, lse_flash = flashinfer() @@ -177,8 +168,7 @@ def _mla_attn_kernel( offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) - offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[ - None, :] + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] q_nope = tl.load(Q_nope + offs_q_nope) offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) @@ -224,9 +214,7 @@ def _mla_attn_kernel( e_sum = e_sum * re_scale + tl.sum(p, 1) e_max = n_e_max - offs_o = cur_batch * stride_o_b + cur_head[:, - None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[ - None, :] + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] tl.store(O + offs_o, acc / e_sum[:, None]) offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV tl.store(O + offs_o_1, e_max + tl.log(e_sum)) @@ -393,24 +381,30 @@ def mla_decode_triton( @torch.inference_mode() -def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): blocked_v = blocked_k[..., :dv] assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() def flash_mla_triton(): num_kv_splits = 32 o = torch.empty([b * s_q, h_q, dv]) attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) mla_decode_triton( - q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv), - blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits, - num_kv_splits, 1 / math.sqrt(d), block_size) + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), + o, + block_table, + cache_seqlens, + attn_logits, + num_kv_splits, + 1 / math.sqrt(d), + block_size, + ) return o.view([b, s_q, h_q, dv]) out_flash = flash_mla_triton() @@ -419,13 +413,10 @@ def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, @torch.inference_mode() -def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - +def run_flash_mla_tilelang(q, block_table, blocked_k, 
max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() dpe = d - dv num_kv_splits = 1 @@ -434,8 +425,7 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) - kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, - num_kv_splits, block_size) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size) def flash_mla_tilelang(): out = kernel( @@ -486,38 +476,31 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" - if target not in ["flashinfer", "flash_mla_triton", "tilelang" - ] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]: + if target not in ["flashinfer", "flash_mla_triton", "tilelang"] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]: # flashinfer has a different lse return value # flash_mla_triton and flash_mla_tilelang doesn't return lse torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s" - ) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): - print( - f"{target}: {b=}, {s_q=}, 
mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" - ) + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") torch.set_default_dtype(dtype) device = torch.device("cuda:0") torch.set_default_device(device) @@ -534,19 +517,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(dtype).bits // 8) - print( - f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") return bytes / 10**6 / perf_b @@ -558,26 +538,22 @@ available_targets = [ "flash_mla_triton", ] -shape_configs = [{ - "b": - batch, - "s_q": - 1, - "cache_seqlens": - torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), - "h_q": - head, - "h_kv": - 1, - "d": - 512 + 64, - "dv": - 512, - "causal": - True, - "dtype": - torch.float16 -} for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] for head in [128]] +shape_configs = [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.float16, + } + for batch in [128] + for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] + for head in [128] +] def get_args(): @@ -599,26 +575,54 @@ if __name__ == "__main__": for shape in shape_configs: if args.all: for target in available_targets: - perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) elif args.compare: - perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], - shape["cache_seqlens"], shape["h_q"], shape["h_kv"], - shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], 
+ shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n' + f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" ) elif args.one: - perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], - shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], - shape["causal"], shape["dtype"]) + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) fout.write( - f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" ) diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py index 3932d112e649b3803edb9532fbeeb13d69eef213..733ae3c460480b4d01ed96da58967ddd4c184f69 100644 --- a/examples/deepseek_mla/example_mla_decode.py +++ b/examples/deepseek_mla/example_mla_decode.py @@ -8,11 +8,12 @@ import argparse @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, - softmax_scale): + }, +) +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale): scale = float(softmax_scale * 1.44269504) # log2(e) dtype = "float16" accum_dtype = "float" @@ -22,11 +23,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.macro def flash_attn( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -44,33 +45,24 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ logsum = T.alloc_fragment([block_H], accum_dtype) cur_kv_head = hid // (kv_group_num // block_H) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + } + ) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range 
= T.ceildiv(seqlen_kv, block_N) for k in T.Pipelined(loop_range, num_stages=2): - T.copy(KV[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared) - T.copy(K_pe[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared) - T.gemm( - Q_shared, - KV_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol, - clear_accum=True) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.copy(KV[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -90,20 +82,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :]) + T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :]) @T.macro def flash_attn_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), ): - with T.Kernel( - batch, heads // min(block_H, kv_group_num), num_split, - threads=256) as (bid, hid, bz): + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bid, hid, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) S_shared = T.alloc_shared([block_H, block_N], dtype) Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) @@ -121,13 +111,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ cur_kv_head = hid // (kv_group_num // block_H) T.use_swizzle(10) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - S_shared: tilelang.layout.make_swizzled_layout(S_shared), - }) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + } + ) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -139,14 +131,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared) T.clear(acc_s) - T.gemm( - Q_shared, KV_shared, acc_s, transpose_B=True, 
policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -168,16 +154,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ acc_o[i, j] /= logsum[i] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, bz]) + T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz]) T.copy(acc_o, O_shared) - T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - bz, :]) + T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, :]) @T.macro def combine( - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(heads, batch, threads=128) as (hid, bz): po_local = T.alloc_fragment([dim], dtype) @@ -187,9 +172,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ lse_max_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype) - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) T.clear(lse_logsum_local) T.clear(o_accum_local) @@ -212,26 +199,26 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.prim_func def main_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) combine(glse, Output_partial, Output) @T.prim_func def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + 
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): flash_attn(Q, Q_pe, KV, K_pe, Output) @@ -256,31 +243,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -298,10 +278,9 @@ def main( BLOCK_N = 64 BLOCK_H = min(64, heads // kv_heads) num_split = 1 - softmax_scale = (dim + pe_dim)**-0.5 + softmax_scale = (dim + pe_dim) ** -0.5 - kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, - softmax_scale) + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4) latency = profiler.do_bench(warmup=500) @@ -311,12 +290,12 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=132, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=132, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv 
heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) diff --git a/examples/deepseek_mla/example_mla_decode_paged.py b/examples/deepseek_mla/example_mla_decode_paged.py index d23ff00c40ab137f1d30f481feeeb1b049fb8258..dee05c1e99c644353ff7e49e00bf5716896ad6bd 100644 --- a/examples/deepseek_mla/example_mla_decode_paged.py +++ b/examples/deepseek_mla/example_mla_decode_paged.py @@ -8,22 +8,14 @@ import math @tilelang.jit( - out_idx=[8], pass_configs={ + out_idx=[8], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def mla_decode_tilelang(batch, - h_q, - h_kv, - max_seqlen_pad, - dv, - dpe, - block_N, - block_H, - num_split, - block_size, - softmax_scale=None): + }, +) +def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split, block_size, softmax_scale=None): if softmax_scale is None: - softmax_scale = (dv + dpe)**-0.5 + softmax_scale = (dv + dpe) ** -0.5 scale = float(softmax_scale * 1.44269504) # log2(e) dtype = "float16" accum_dtype = "float" @@ -34,13 +26,13 @@ def mla_decode_tilelang(batch, @T.macro def flash_mla_kernel( - Q: T.Tensor([batch, h_q, dv], dtype), - Q_pe: T.Tensor([batch, h_q, dpe], dtype), - KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), - K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - CACHE_SEQLENS: T.Tensor([batch], "int32"), - Output: T.Tensor([batch, h_q, dv], dtype), + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), + CACHE_SEQLENS: T.Tensor([batch], "int32"), + Output: T.Tensor([batch, h_q, dv], dtype), ): with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by): Q_shared = T.alloc_shared([block_H, dv], dtype) @@ -59,13 +51,15 @@ def mla_decode_tilelang(batch, cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - S_shared: tilelang.layout.make_swizzled_layout(S_shared), - }) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + } + ) - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -73,25 +67,17 @@ def mla_decode_tilelang(batch, loop_range = T.ceildiv(CACHE_SEQLENS[bx], block_N) for kr in T.Pipelined(loop_range, num_stages=2): k = loop_range - 1 - kr - kv_start = BLOCK_TABLE[bx, (k * block_N) // - block_size] * block_size + (k * block_N) % block_size - T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared) - T.copy(K_pe[kv_start:kv_start + block_N, 
cur_kv_head, :], K_pe_shared) + kv_start = BLOCK_TABLE[bx, (k * block_N) // block_size] * block_size + (k * block_N) % block_size + T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared) T.clear(acc_s) - T.gemm( - Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) if kr == 0: for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], - -T.infinity(accum_dtype), acc_s[i, j]) + acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j]) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) @@ -109,21 +95,20 @@ def mla_decode_tilelang(batch, for i, j in T.Parallel(block_H, dv): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) + T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) @T.macro def flash_mla_split_kv_kernel( - Q: T.Tensor([batch, h_q, dv], dtype), - Q_pe: T.Tensor([batch, h_q, dpe], dtype), - KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), - K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - CACHE_SEQLENS: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, h_q, num_split], dtype), - Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), + CACHE_SEQLENS: T.Tensor([batch], "int32"), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), ): - with T.Kernel( - batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz): + with T.Kernel(batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dv], dtype) S_shared = T.alloc_shared([block_H, block_N], dtype) Q_pe_shared = T.alloc_shared([block_H, dpe], dtype) @@ -141,13 +126,15 @@ def mla_decode_tilelang(batch, cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - S_shared: tilelang.layout.make_swizzled_layout(S_shared), - }) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + } + ) - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -155,28 
+142,20 @@ def mla_decode_tilelang(batch, total_blocks = T.ceildiv(CACHE_SEQLENS[bx], block_N) blocks_per_split = T.floordiv(total_blocks, num_split) remaining_blocks = T.floormod(total_blocks, num_split) - loop_range = (blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0)) + loop_range = blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0) start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N for k in T.Pipelined(loop_range, num_stages=2): - kv_start = BLOCK_TABLE[bx, (start + k * block_N) // - block_size] * block_size + (k * block_N) % block_size - T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared) - T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared) + kv_start = BLOCK_TABLE[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size + T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared) T.clear(acc_s) - T.gemm( - Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], - -T.infinity(accum_dtype), acc_s[i, j]) + acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j]) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) @@ -196,15 +175,15 @@ def mla_decode_tilelang(batch, acc_o[i, j] /= logsum[i] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz]) + T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz]) T.copy(acc_o, O_shared) - T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :]) + T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :]) @T.macro def combine( - glse: T.Tensor([batch, h_q, num_split], dtype), - Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), - Output: T.Tensor([batch, h_q, dv], dtype), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), ): with T.Kernel(h_q, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dv], dtype) @@ -214,9 +193,11 @@ def mla_decode_tilelang(batch, lse_max_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype) - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) T.clear(lse_logsum_local) T.clear(o_accum_local) @@ -239,31 +220,30 @@ def mla_decode_tilelang(batch, @T.prim_func def main_split( - Q: T.Tensor([batch, h_q, dv], dtype), - Q_pe: T.Tensor([batch, h_q, dpe], dtype), - KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), - K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - 
block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, h_q, num_split], dtype), - Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), - Output: T.Tensor([batch, h_q, dv], dtype), + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), + cache_seqlens: T.Tensor([batch], "int32"), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), ): - flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, - Output_partial) + flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, Output_partial) combine(glse, Output_partial, Output) @T.prim_func def main_no_split( - Q: T.Tensor([batch, h_q, dv], dtype), - Q_pe: T.Tensor([batch, h_q, dpe], dtype), - KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), - K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - glse: T.Tensor([batch, h_q, num_split], dtype), - Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), - Output: T.Tensor([batch, h_q, dv], dtype), + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), + cache_seqlens: T.Tensor([batch], "int32"), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), ): flash_mla_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, Output) @@ -284,8 +264,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): s_q = query.shape[-2] s_k = key.shape[-2] attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device) - temp_mask = torch.ones( - s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) attn_weight += attn_bias @@ -295,8 +274,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @torch.inference_mode() -def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): # q: [b, s_q, h_q, d] # block_table: [b, max_seqlen_pad // block_size] # blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d] @@ -325,13 +303,10 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, return out_torch -def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, - h_q, h_kv, d, dv, causal, dtype): - +def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): assert d > dv, "mla with rope dim 
should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() dpe = d - dv num_kv_splits = 1 @@ -341,8 +316,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) - kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, - num_kv_splits, block_size, softmax_scale) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size, softmax_scale) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) def flash_mla_tilelang(): @@ -360,8 +334,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s out_flash = flash_mla_tilelang() t = do_bench(flash_mla_tilelang) - out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) torch.testing.assert_close(out_flash, out_ref, rtol=0.01, atol=0.01) print("All close") return out_flash, t @@ -369,12 +342,12 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=128, help='batch size') - parser.add_argument('--h_q', type=int, default=128, help='q heads number') - parser.add_argument('--h_kv', type=int, default=1, help='kv heads number') - parser.add_argument('--cache_seqlen', type=int, default=8192, help='kv cache context length') - parser.add_argument('--d', type=int, default=576, help='query/key head dim, d = dv + dpe') - parser.add_argument('--dv', type=int, default=512, help='value head dim') + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--h_q", type=int, default=128, help="q heads number") + parser.add_argument("--h_kv", type=int, default=1, help="kv heads number") + parser.add_argument("--cache_seqlen", type=int, default=8192, help="kv cache context length") + parser.add_argument("--d", type=int, default=576, help="query/key head dim, d = dv + dpe") + parser.add_argument("--dv", type=int, default=512, help="value head dim") args = parser.parse_args() b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv @@ -383,9 +356,7 @@ if __name__ == "__main__": s_q = 1 # for decode, s_q = 1 block_size = 64 - cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], - dtype=torch.int32, - device=device) + cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], dtype=torch.int32, device=device) dpe = d - dv causal = True @@ -397,12 +368,11 @@ if __name__ == "__main__": total_flops = s_q * total_seqlens * h_q * d * 2 q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device) - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32, - device=device).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32, device=device).view(b, max_seqlen_pad // 
block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device) - out_flash, latency = run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_flash, latency = run_tilelang_mla( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) diff --git a/examples/deepseek_mla/example_mla_decode_persistent.py b/examples/deepseek_mla/example_mla_decode_persistent.py index 2f896f2652a1420303e4c6a41285cf2587c12bbf..305fd30ed3d54f0bec1ba8c4e7297bac820d0ba7 100644 --- a/examples/deepseek_mla/example_mla_decode_persistent.py +++ b/examples/deepseek_mla/example_mla_decode_persistent.py @@ -9,11 +9,13 @@ import argparse @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split): - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) dtype = "float16" accum_dtype = "float" kv_group_num = heads // kv_head_num @@ -23,13 +25,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.prim_func def main_split_persistent( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(sm_num, threads=256) as (block_id): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -53,11 +55,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ lse_max_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype) - T.annotate_layout({ - # O_shared: tilelang.layout.make_swizzled_layout(O_shared), - S_shared: tilelang.layout.make_swizzled_layout(S_shared), - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + T.annotate_layout( + { + # O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) T.use_swizzle(10) total_tiles = batch * (heads // min(block_H, kv_group_num)) * num_split @@ -70,8 +74,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ cur_kv_head = hid // (kv_group_num // block_H) if bid < batch and hid * VALID_BLOCK_H < heads and sid < num_split: - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * 
VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -83,26 +87,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared) T.clear(acc_s) - T.gemm( - Q_shared, - KV_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) @@ -117,11 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ acc_o[i, j] /= logsum[i] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, sid]) + T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid]) # T.copy(acc_o, O_shared) - T.copy( - acc_o, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - sid, :]) + T.copy(acc_o, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid, :]) T.sync_grid() waves = T.ceildiv(heads * batch, sm_num) @@ -167,42 +158,35 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, 
seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out def main(): parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=128, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) diff --git a/examples/deepseek_mla/example_mla_decode_ws.py b/examples/deepseek_mla/example_mla_decode_ws.py index fcd427efaa27ac6e9c0b80721b70b1d45b892326..3fb90a556d2232eff10ee8069856b05118ebcf7b 100644 --- a/examples/deepseek_mla/example_mla_decode_ws.py +++ b/examples/deepseek_mla/example_mla_decode_ws.py @@ -13,14 +13,19 @@ import argparse tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }, compile_flags=[ - "-O3", "-Wno-deprecated-declarations", "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", - "--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG" + "-O3", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--ptxas-options=-v,--register-usage-level=10", + "-DNDEBUG", ], ) -def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, - softmax_scale): +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale): sm_scale = float(softmax_scale * 1.44269504) # log2(e) dtype = "float16" accum_dtype = "float" @@ -30,11 +35,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.macro def flash_attn( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, 
heads, dim], dtype), ): with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=384) as (hid, bid): Q_shared_l = T.alloc_shared([block_H, dim // 2], dtype) @@ -75,16 +80,16 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ tx = T.get_thread_binding() - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, 0:dim // 2], Q_shared_l) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, dim // 2:dim], Q_shared_r) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2], Q_shared_l) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim], Q_shared_r) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) T.barrier_arrive(bar_q) if tx < 128: T.set_max_nreg(240, 1) T.fill(sumexp, 0) - T.fill(m_i, -2**30) # avoid -inf - inf to cause nan + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan T.fill(acc_o_l, 0) T.barrier_wait(bar_q, 0) @@ -166,8 +171,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ for h_i in T.Parallel(block_H): sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale T.copy(acc_o_l, O_shared_l) - T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - 0:dim // 2]) + T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2]) elif tx >= 128 and tx < 256: T.set_max_nreg(168, 1) @@ -197,8 +201,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] T.copy(acc_o_r, O_shared_r) - T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - dim // 2:dim]) + T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim]) elif tx >= 256: # producer @@ -211,19 +214,17 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_0_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_0_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, dim // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = K_pe[bid, kv_indices, cur_kv_head, - (tx - 256) % 8 * 8 + v] + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 @@ -233,33 +234,29 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_1_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_1_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, 
kv_indices, cur_kv_head, dim // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = K_pe[bid, kv_indices, cur_kv_head, - (tx - 256) % 8 * 8 + v] + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_1_ready[0]) @T.macro def flash_attn_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), ): - with T.Kernel( - batch, heads // min(block_H, kv_group_num), num_split, - threads=384) as (bid, hid, bz): + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=384) as (bid, hid, bz): Q_shared_l = T.alloc_shared([block_H, dim // 2], dtype) Q_shared_r = T.alloc_shared([block_H, dim // 2], dtype) Q_tail_shared = T.alloc_shared([block_H, pe_dim], dtype) @@ -298,16 +295,16 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ tx = T.get_thread_binding() - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, 0:dim // 2], Q_shared_l) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, dim // 2:dim], Q_shared_r) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2], Q_shared_l) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim], Q_shared_r) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) T.barrier_arrive(bar_q) if tx < 128: T.set_max_nreg(240, 1) T.fill(sumexp, 0) - T.fill(m_i, -2**30) # avoid -inf - inf to cause nan + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan T.fill(acc_o_l, 0) T.barrier_wait(bar_q, 0) @@ -389,10 +386,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ for h_i in T.Parallel(block_H): sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale T.copy(acc_o_l, O_shared_l) - T.copy( - O_shared_l, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - bz, 0:dim // 2]) - T.copy(sumexp, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, bz]) + T.copy(O_shared_l, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, 0 : dim // 2]) + T.copy(sumexp, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz]) elif tx >= 128 and tx < 256: T.set_max_nreg(168, 1) @@ -422,9 +417,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] T.copy(acc_o_r, 
O_shared_r) - T.copy( - O_shared_r, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - bz, dim // 2:dim]) + T.copy(O_shared_r, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, dim // 2 : dim]) elif tx >= 256: # producer @@ -433,54 +426,48 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ # Buffer 0 T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - kv_indices = (seqlen_kv // num_split) * bz + ( - i_i * 2) * block_N + r * 16 + (tx - 256) // 8 + kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2) * block_N + r * 16 + (tx - 256) // 8 with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_0_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_0_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, dim // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = K_pe[bid, kv_indices, cur_kv_head, - (tx - 256) % 8 * 8 + v] + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - kv_indices = (seqlen_kv // num_split) * bz + ( - i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8 + kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8 with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_1_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_1_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[bid, kv_indices, cur_kv_head, dim // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = K_pe[bid, kv_indices, cur_kv_head, - (tx - 256) % 8 * 8 + v] + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_1_ready[0]) @T.macro def combine( - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(heads, batch, threads=128) as (hid, bz): po_local = T.alloc_fragment([dim], dtype) @@ 
-490,9 +477,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ lse_max_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype) - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - }) + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) T.clear(lse_logsum_local) T.clear(o_accum_local) @@ -515,26 +504,26 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.prim_func def main_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) combine(glse, Output_partial, Output) @T.prim_func def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): flash_attn(Q, Q_pe, KV, K_pe, Output) @@ -559,31 +548,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, 
seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -601,10 +583,9 @@ def main( BLOCK_N = 64 BLOCK_H = min(64, heads // kv_heads) num_split = 1 - softmax_scale = (dim + pe_dim)**-0.5 + softmax_scale = (dim + pe_dim) ** -0.5 - kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, - softmax_scale) + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4) latency = profiler.do_bench(warmup=500) @@ -614,12 +595,12 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=132, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=132, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) diff --git a/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py b/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py index b141822fe0a01e11197f08bf8d8371325f0f9ee6..4a1a84cf1ec62ea7b9559d24823fc4aa60af9e3b 100644 --- a/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py +++ b/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py @@ -8,11 +8,13 @@ import argparse @tilelang.jit( - out_idx=[-1], pass_configs={ + out_idx=[-1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H): - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) dtype = "float16" q_dtype = "float8_e4m3" accum_dtype = "float" @@ -22,11 +24,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.prim_func def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - 
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], q_dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], q_dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=256) as (bx, by): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -46,31 +48,27 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ cur_kv_head = by // (kv_group_num // block_H) T.use_swizzle(10) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) - - T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + } + ) + + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) T.disable_warp_group_reg_alloc() loop_range = T.ceildiv(seqlen_kv, block_N) for k in T.Pipelined(loop_range, num_stages=2): - T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], qKV_shared) - T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared) + T.copy(KV[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], qKV_shared) + T.copy(K_pe[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) T.copy(qKV_shared, KV_shared) T.clear(acc_s) - T.gemm( - Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - Q_pe_shared, - K_pe_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -90,7 +88,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) + T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) return main_no_split @@ -108,42 +106,35 @@ def ref_program(q, q_pe, kv, k_pe): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # 
[batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=128, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) diff --git a/examples/deepseek_mla/torch_refs.py b/examples/deepseek_mla/torch_refs.py index 4b4c888cd2ac2db732948940599a845d5a663f20..aae6c7cd2b619afee90f39058cfd9a4a6a71e49e 100644 --- a/examples/deepseek_mla/torch_refs.py +++ b/examples/deepseek_mla/torch_refs.py @@ -11,7 +11,7 @@ def flash_split_ref(Q, Q_pe, KV, K_pe): block_N = 64 seqlen_kv = KV.size(1) - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) acc_s = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float) acc_s_cast = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float16) acc_o = torch.empty((batch, nheads, dim), device="cuda", dtype=torch.float) @@ -31,18 +31,20 @@ def flash_split_ref(Q, Q_pe, KV, K_pe): for ks in range(num_split): acc_o.fill_(0) logsum.fill_(0) - scores_max.fill_(float('-inf')) - scores_max_prev.fill_(float('-inf')) + scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) for i in range(int((seqlen_kv // num_split) / block_N)): acc_s.fill_(0) - acc_s = torch.einsum('bhd,bkhd->bhk', Q_, - KV_[:, (seqlen_kv // num_split) * ks + - i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) # [batch, nheads, block_N] + acc_s = torch.einsum( + "bhd,bkhd->bhk", + Q_, + KV_[:, (seqlen_kv // num_split) * 
ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, nheads, block_N] acc_s += torch.einsum( - 'bhd,bkhd->bhk', Q_pe_, - K_pe_[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) + "bhd,bkhd->bhk", + Q_pe_, + K_pe_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) scores_max_prev = scores_max scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads] scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads] @@ -50,9 +52,10 @@ def flash_split_ref(Q, Q_pe, KV, K_pe): acc_s = torch.exp2(acc_s - scores_max[:, :, None]) acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N] acc_o += torch.einsum( - 'bhk,bkhd->bhd', acc_s_cast, - KV_[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) + "bhk,bkhd->bhd", + acc_s_cast, + KV_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) scores_sum = acc_s.sum(dim=-1, keepdim=False) logsum = logsum * scores_scale + scores_sum acc_o /= logsum[:, :, None] diff --git a/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py b/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py index daee39865ccbc9776be153632d417e26fee366d4..ea3f72c50581c98e585c136699f1c9efd1a550b0 100644 --- a/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py +++ b/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py @@ -14,21 +14,44 @@ from fla.ops.utils import prepare_token_indices from fla.utils import autocast_custom_fwd, contiguous -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -40,20 +63,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc NS = S - p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = 
tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -66,7 +87,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -87,7 +108,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -100,8 +120,7 @@ class ParallelNSAFunction(torch.autograd.Function): # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] token_indices = prepare_token_indices(offsets) if offsets is not None else None - o, lse = parallel_nsa_fwd( - q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) ctx.save_for_backward(q, k, v, o, lse) ctx.block_indices = block_indices ctx.block_size = block_size @@ -172,7 +191,6 @@ def parallel_nsa_fwd( @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -195,7 +213,8 @@ class ParallelNSAFunction(torch.autograd.Function): window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) ctx.block_indices = block_indices ctx.block_counts = block_counts @@ -207,18 +226,20 @@ class ParallelNSAFunction(torch.autograd.Function): return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -258,44 +279,44 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
""" if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o -def naive_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def naive_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -335,26 +356,24 @@ def naive_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
""" if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - raise RuntimeError( - "Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") dtype = q.dtype G = q.shape[2] // k.shape[2] BS = block_size S = block_indices.shape[-1] - k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) if isinstance(block_counts, torch.Tensor): - block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) q, k, v = map(lambda x: x.float(), (q, k, v)) @@ -364,14 +383,11 @@ def naive_nsa(q: torch.Tensor, if cu_seqlens is None: varlen = False B, T = q.shape[:2] - cu_seqlens = torch.cat( - [block_indices.new_tensor(range(0, B * T, T)), - block_indices.new_tensor([B * T])]) + cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B * T, T)), block_indices.new_tensor([B * T])]) for i in range(len(cu_seqlens) - 1): if not varlen: - q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[ - i], block_indices[i] + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], block_indices[i] if isinstance(block_counts, torch.Tensor): s_b = block_counts[i] else: @@ -379,10 +395,10 @@ def naive_nsa(q: torch.Tensor, else: T = cu_seqlens[i + 1] - cu_seqlens[i] q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map( - lambda x: x[0][cu_seqlens[i]:cu_seqlens[i + 1]], - (q, k, v, g_slc, g_swa, block_indices)) + lambda x: x[0][cu_seqlens[i] : cu_seqlens[i + 1]], (q, k, v, g_slc, g_swa, block_indices) + ) if isinstance(block_counts, torch.Tensor): - s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i + 1]] + s_b = block_counts[0][cu_seqlens[i] : cu_seqlens[i + 1]] else: s_b = block_counts @@ -404,71 +420,58 @@ def naive_nsa(q: torch.Tensor, else: s_i = s_b # [S*BS, HQ, -1] - k_i_slc, v_i_slc = map( - lambda x: x.gather( - 0, - i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) + k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) # [S*BS, HQ] - attn_slc = torch.einsum('h d, n h d -> n h', q_i, k_i_slc).masked_fill( - torch.logical_or(i_i < 0, i_i > i_q) | - (c >= s_i if block_counts is not None else False), float('-inf')).softmax(0) + attn_slc = ( + torch.einsum("h d, n h d -> n h", q_i, k_i_slc) + .masked_fill(torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), float("-inf")) + .softmax(0) + ) if not varlen: - o_slc[i, i_q] = torch.einsum('n h, n h v -> h v', attn_slc, - v_i_slc) * 
g_slc_i.unsqueeze(-1) + o_slc[i, i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) else: - o_slc[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_slc, - v_i_slc) * g_slc_i.unsqueeze(-1) + o_slc[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) if window_size > 0: - k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1):i_q + 1], - (k_b, v_b)) - attn_swa = torch.einsum('h d, n h d -> n h', q_i, k_i_swa).softmax(0) + k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1) : i_q + 1], (k_b, v_b)) + attn_swa = torch.einsum("h d, n h d -> n h", q_i, k_i_swa).softmax(0) if not varlen: - o_swa[i, i_q] = torch.einsum('n h, n h v -> h v', attn_swa, - v_i_swa) * g_swa_i.unsqueeze(-1) + o_swa[i, i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) else: - o_swa[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_swa, - v_i_swa) * g_swa_i.unsqueeze(-1) + o_swa[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) if head_first: - o_slc = rearrange(o_slc, 'b t h d -> b h t d') - o_swa = rearrange(o_swa, 'b t h d -> b h t d') + o_slc = rearrange(o_slc, "b t h d -> b h t d") + o_swa = rearrange(o_swa, "b t h d -> b h t d") return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype) def get_configs(): import itertools + iter_params = dict( block_T=[128, 256, 512], num_stages=[0, 1, 2, 4, 5], threads=[32, 64, 128, 256, 512], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] -@tilelang.autotune(configs=get_configs(),) +@tilelang.autotune( + configs=get_configs(), +) @tilelang.jit( pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) -def tilelang_sparse_attention(batch, - heads, - seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16, - block_T=128, - num_stages=2, - threads=32): + } +) +def tilelang_sparse_attention( + batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16, block_T=128, num_stages=2, threads=32 +): if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) else: scale = scale * 1.44269504 # log2(e) @@ -493,11 +496,11 @@ def tilelang_sparse_attention(batch, @T.prim_func def tilelang_sparse_attention( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([G, BK], dtype) @@ -520,7 +523,7 @@ def tilelang_sparse_attention(batch, i_b, i_h = i_bh // head_kv, i_bh % head_kv NS = S - T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared) + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -530,21 +533,15 @@ def 
tilelang_sparse_attention(batch, i_s = BlockIndices[i_b, i_t, i_h, i] * BS if i_s <= i_t and i_s >= 0: # [BS, BK] - T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) if is_causal: for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) # Softmax T.copy(scores_max, scores_max_prev) @@ -564,45 +561,33 @@ def tilelang_sparse_attention(batch, acc_o[i, j] *= scores_scale[i] # V * softmax(Q * K) - T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(G, BV): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV]) + T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) return tilelang_sparse_attention def generate_block_indices(batch, seq_len, heads, selected_blocks, block_size): """Generate random block indices for the benchmark.""" - block_indices = torch.full((batch, seq_len, heads, selected_blocks), - seq_len, - dtype=torch.long, - device='cuda') + block_indices = torch.full((batch, seq_len, heads, selected_blocks), seq_len, dtype=torch.long, device="cuda") for b in range(batch): for t in range(seq_len): for h in range(heads): i_i = torch.randperm(max(1, (t // block_size)))[:selected_blocks] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i return block_indices.sort(-1)[0] -def benchmark_nsa(batch_size, - seq_len, - heads, - head_query, - dim, - selected_blocks, - block_size, - dtype, - scale, - warmup=10, - iterations=100, - validate=False): +def benchmark_nsa( + batch_size, seq_len, heads, head_query, dim, selected_blocks, block_size, dtype, scale, warmup=10, iterations=100, validate=False +): """Benchmark the TileLang Sparse Attention implementation.""" # Set random seed for reproducibility @@ -628,14 +613,13 @@ def benchmark_nsa(batch_size, print(f"Profiler latency: {profiler_latency} ms") # Create input tensors - Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda') - K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda') - V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda') - out = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda') + Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + out = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") # Generate block indices - block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, - block_size).to(torch.int32) + block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size).to(torch.int32) # Warmup for _ in range(warmup): @@ -666,10 +650,9 @@ def benchmark_nsa(batch_size, # Validate result against reference if requested if validate: - g_slc = 
torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda') - g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda') - block_counts = torch.randint( - 1, selected_blocks + 1, (batch_size, seq_len, heads), device='cuda') + g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + block_counts = torch.randint(1, selected_blocks + 1, (batch_size, seq_len, heads), device="cuda") ref = naive_nsa( q=Q, @@ -700,22 +683,13 @@ def benchmark_nsa(batch_size, "head_query": head_query, "dim": dim, "selected_blocks": selected_blocks, - "block_size": block_size + "block_size": block_size, } -def benchmark_triton_nsa(batch_size, - seq_len, - heads, - head_query, - dim, - selected_blocks, - block_size, - dtype, - scale, - warmup=10, - iterations=100, - validate=False): +def benchmark_triton_nsa( + batch_size, seq_len, heads, head_query, dim, selected_blocks, block_size, dtype, scale, warmup=10, iterations=100, validate=False +): """Benchmark the Triton-based TileLang Sparse Attention implementation.""" # Set random seed for reproducibility @@ -723,18 +697,17 @@ def benchmark_triton_nsa(batch_size, torch.random.manual_seed(0) # Create input tensors - Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda') - K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda') - V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda') - g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda') - g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda') + Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") # Generate block indices block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size) - block_counts = torch.randint( - 1, selected_blocks + 1, (batch_size, seq_len, heads), device='cuda') - o_slc = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda') - lse_slc = torch.empty((batch_size, seq_len, head_query), dtype=torch.float, device='cuda') + block_counts = torch.randint(1, selected_blocks + 1, (batch_size, seq_len, heads), device="cuda") + o_slc = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + lse_slc = torch.empty((batch_size, seq_len, head_query), dtype=torch.float, device="cuda") # Warmup for _ in range(warmup): @@ -750,7 +723,8 @@ def benchmark_triton_nsa(batch_size, block_counts=block_counts, block_size=block_size, window_size=0, - scale=scale) + scale=scale, + ) # Synchronize before timing torch.cuda.synchronize() @@ -770,7 +744,8 @@ def benchmark_triton_nsa(batch_size, block_counts=block_counts, block_size=block_size, window_size=0, - scale=scale) + scale=scale, + ) torch.cuda.synchronize() end_time = time.time() @@ -815,54 +790,28 @@ def benchmark_triton_nsa(batch_size, "head_query": head_query, "dim": dim, "selected_blocks": selected_blocks, - "block_size": block_size + "block_size": block_size, } -def run_benchmark_suite(impl='all'): +def run_benchmark_suite(impl="all"): """Run a suite of 
benchmarks with different configurations.""" # Define configurations to benchmark configs = [ # Small model config - Note: head_query must be a multiple of heads*16 for Triton - { - "batch_size": 2, - "seq_len": 1024, - "heads": 8, - "head_query": 8 * 16, - "dim": 64, - "selected_blocks": 8, - "block_size": 32 - }, - + {"batch_size": 2, "seq_len": 1024, "heads": 8, "head_query": 8 * 16, "dim": 64, "selected_blocks": 8, "block_size": 32}, # Medium model config - { - "batch_size": 2, - "seq_len": 2048, - "heads": 16, - "head_query": 16 * 16, - "dim": 64, - "selected_blocks": 16, - "block_size": 64 - }, - + {"batch_size": 2, "seq_len": 2048, "heads": 16, "head_query": 16 * 16, "dim": 64, "selected_blocks": 16, "block_size": 64}, # Large model config - { - "batch_size": 1, - "seq_len": 4096, - "heads": 32, - "head_query": 32 * 16, - "dim": 128, - "selected_blocks": 32, - "block_size": 128 - }, + {"batch_size": 1, "seq_len": 4096, "heads": 32, "head_query": 32 * 16, "dim": 128, "selected_blocks": 32, "block_size": 128}, ] results = [] for config in configs: print(f"Running benchmark with config: {config}") - if impl in ['all', 'tilelang']: + if impl in ["all", "tilelang"]: print("Benchmarking TileLang implementation:") result = benchmark_nsa( batch_size=config["batch_size"], @@ -874,12 +823,13 @@ def run_benchmark_suite(impl='all'): block_size=config["block_size"], dtype=torch.float16, scale=0.1, - validate=False) + validate=False, + ) results.append({"impl": "tilelang", **result}) print(f"Average time: {result['avg_time_ms']:.2f} ms") print(f"Performance: {result['tflops']:.2f} TFLOPs") - if impl in ['all', 'triton']: + if impl in ["all", "triton"]: print("Benchmarking Triton implementation:") result = benchmark_triton_nsa( batch_size=config["batch_size"], @@ -891,19 +841,24 @@ def run_benchmark_suite(impl='all'): block_size=config["block_size"], dtype=torch.float16, scale=0.1, - validate=False) + validate=False, + ) results.append({"impl": "triton", **result}) print(f"Average time: {result['avg_time_ms']:.2f} ms") print(f"Performance: {result['tflops']:.2f} TFLOPs") - if impl in ['all']: + if impl in ["all"]: # Print comparison if both implementations were run tilelang_result = next( - r for r in results if r["impl"] == "tilelang" and - r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"]) + r + for r in results + if r["impl"] == "tilelang" and r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"] + ) triton_result = next( - r for r in results if r["impl"] == "triton" and - r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"]) + r + for r in results + if r["impl"] == "triton" and r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"] + ) speedup = tilelang_result["avg_time_ms"] / triton_result["avg_time_ms"] print(f"Speedup (Triton vs TileLang): {speedup:.2f}x") @@ -921,8 +876,7 @@ if __name__ == "__main__": parser.add_argument("--dim", type=int, default=128, help="Head dimension") parser.add_argument("--selected_blocks", type=int, default=16, help="Number of selected blocks") parser.add_argument("--block_size", type=int, default=32, help="Block size") - parser.add_argument( - "--dtype", type=str, default="float16", help="Data type (float16 or float32)") + parser.add_argument("--dtype", type=str, default="float16", help="Data type (float16 or float32)") parser.add_argument("--scale", type=float, default=0.1, help="Attention scale factor") parser.add_argument("--iterations", type=int, default=100, 
help="Number of iterations") parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations") @@ -933,7 +887,8 @@ if __name__ == "__main__": type=str, default="all", choices=["tilelang", "triton", "all"], - help="Implementation to benchmark (tilelang, triton, or all)") + help="Implementation to benchmark (tilelang, triton, or all)", + ) args = parser.parse_args() @@ -941,8 +896,7 @@ if __name__ == "__main__": if args.impl in ["triton", "all"] and args.head_query % (args.heads * 16) != 0: # Adjust head_query to nearest valid value args.head_query = ((args.head_query // (args.heads * 16)) + 1) * (args.heads * 16) - print( - f"Adjusted head_query to {args.head_query} to be compatible with Triton implementation") + print(f"Adjusted head_query to {args.head_query} to be compatible with Triton implementation") if args.suite: run_benchmark_suite(impl=args.impl) @@ -963,12 +917,14 @@ if __name__ == "__main__": scale=args.scale, warmup=args.warmup, iterations=args.iterations, - validate=args.validate) + validate=args.validate, + ) print("\nBenchmark Results (TileLang):") print( - f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + - f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + - f"block_size={args.block_size}") + f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + + f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + + f"block_size={args.block_size}" + ) print(f"Average time: {result['avg_time_ms']:.2f} ms") print(f"Performance: {result['tflops']:.2f} TFLOPs") @@ -986,11 +942,13 @@ if __name__ == "__main__": scale=args.scale, warmup=args.warmup, iterations=args.iterations, - validate=args.validate) + validate=args.validate, + ) print("\nBenchmark Results (Triton):") print( - f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + - f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + - f"block_size={args.block_size}") + f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + + f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + + f"block_size={args.block_size}" + ) print(f"Average time: {result['avg_time_ms']:.2f} ms") print(f"Performance: {result['tflops']:.2f} TFLOPs") diff --git a/examples/deepseek_nsa/example_tilelang_nsa_bwd.py b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py index 1d1b5ea3ba9c82ec6e824a9c5ee012859af530bb..56e98a95b05738a8c4aa4bca960909eba4f8fbba 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_bwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py @@ -7,6 +7,7 @@ import torch import triton import fla + if parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -22,7 +23,8 @@ import tilelang tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + } +) def tilelang_kernel_fwd( batch, heads, @@ -34,11 +36,10 @@ def tilelang_kernel_fwd( groups=1, selected_blocks=16, ): - from tilelang import language as T if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) else: scale = scale * 1.44269504 # log2(e) @@ -67,12 +68,12 @@ def tilelang_kernel_fwd( @T.prim_func def native_sparse_attention( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, 
dtype), - BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), - O_slc: T.Tensor(o_slc_shape, dtype), - LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + O_slc: T.Tensor(o_slc_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), ): with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([G, BK], dtype) @@ -93,7 +94,7 @@ def tilelang_kernel_fwd( i_b, i_h = i_bh // head_kv, i_bh % head_kv NS = S - T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared) + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -103,12 +104,11 @@ def tilelang_kernel_fwd( i_s = BlockIndices[i_b, i_t, i_h, i] * BS if i_s <= i_t and i_s >= 0: # [BS, BK] - T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) if is_causal: for k, j in T.Parallel(G, BS): - acc_s[k, j] = T.if_then_else(i_t >= (i_s + j), 0, - -T.infinity(acc_s.dtype)) + acc_s[k, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) @@ -138,7 +138,7 @@ def tilelang_kernel_fwd( acc_o[k, j] *= scores_scale[k] # V * softmax(Q * K) - T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(G, BV): @@ -146,18 +146,20 @@ def tilelang_kernel_fwd( T.copy(acc_o, O_shared) T.copy( O_shared, - O_slc[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV], + O_slc[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV], ) for i in T.Parallel(G): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, LSE_slc[i_b, i_t, i_h * G:(i_h + 1) * G]) + T.copy(logsum, LSE_slc[i_b, i_t, i_h * G : (i_h + 1) * G]) return native_sparse_attention -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def tilelang_kernel_bwd_dkv( batch, heads, @@ -172,7 +174,7 @@ def tilelang_kernel_bwd_dkv( accum_dtype="float", ): if scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 else: sm_scale = scale @@ -207,15 +209,15 @@ def tilelang_kernel_bwd_dkv( @T.prim_func def flash_bwd_dkv( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(k_shape, dtype), - V: T.Tensor(v_shape, dtype), - LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), - Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), - DO_slc: T.Tensor(do_slc_shape, dtype), - DK: T.Tensor(dk_shape, dtype), - DV: T.Tensor(dv_shape, dtype), - BlockMask: T.Tensor(block_mask_shape, "int32"), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(k_shape, dtype), + V: T.Tensor(v_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), + DO_slc: T.Tensor(do_slc_shape, dtype), + DK: T.Tensor(dk_shape, dtype), + DV: T.Tensor(dv_shape, dtype), + BlockMask: T.Tensor(block_mask_shape, "int32"), ): with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh): K_shared = T.alloc_shared([BS, BK], dtype) @@ -238,31 +240,33 @@ def tilelang_kernel_bwd_dkv( i_b, i_h = i_bh // H, i_bh % H - T.copy(K[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK], K_shared) - T.copy(V[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV], V_shared) + 
T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared) + T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared) # [BS, BK] T.clear(dk) # [BS, BV] T.clear(dv) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) loop_st = i_s * BS loop_ed = seq_len for i in T.Pipelined( - start=loop_st, - stop=loop_ed, - num_stages=0, + start=loop_st, + stop=loop_ed, + num_stages=0, ): b_m_slc = BlockMask[i_b, i, i_h, i_s] if b_m_slc != 0: # [G, BK] - T.copy(Q[i_b, i, i_h * G:(i_h + 1) * G, :BK], Q_shared) + T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared) T.clear(qkT) # [BS, BK] @ [G, BK] -> [BS, G] T.gemm( @@ -273,7 +277,7 @@ def tilelang_kernel_bwd_dkv( policy=T.GemmWarpPolicy.FullRow, ) # [G] - T.copy(LSE_slc[i_b, i, i_h * G:(i_h + 1) * G], lse_shared) + T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared) for _i, _j in T.Parallel(BS, G): qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j]) @@ -282,7 +286,7 @@ def tilelang_kernel_bwd_dkv( qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0) # [G, BV] - T.copy(DO_slc[i_b, i, i_h * G:(i_h + 1) * G, :BV], do) + T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do) T.clear(dsT) # [BS, BV] @ [G, BV] -> [BS, G] T.gemm( @@ -296,7 +300,7 @@ def tilelang_kernel_bwd_dkv( # [BS, G] @ [G, BV] -> [BS, BV] T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) # [G] - T.copy(Delta_slc[i_b, i, i_h * G:(i_h + 1) * G], delta) + T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta) for i, j in T.Parallel(BS, G): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -305,8 +309,8 @@ def tilelang_kernel_bwd_dkv( T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, DV[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV]) - T.copy(dk_shared, DK[i_v, i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK]) + T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV]) + T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK]) return flash_bwd_dkv @@ -321,9 +325,11 @@ def make_dq_layout(dQ): ) -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def tilelang_kernel_bwd_dqkv( batch, heads, @@ -338,7 +344,7 @@ def tilelang_kernel_bwd_dqkv( accum_dtype="float", ): if scale is None: - sm_scale = (1.0 / dim)**0.5 + sm_scale = (1.0 / dim) ** 0.5 else: sm_scale = scale @@ -373,16 +379,16 @@ def tilelang_kernel_bwd_dqkv( @T.prim_func def flash_bwd_dqkv( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(k_shape, dtype), - V: T.Tensor(v_shape, dtype), - LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), - Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), - DO_slc: T.Tensor(do_slc_shape, dtype), - DQ: T.Tensor(dq_shape, dtype), - DK: T.Tensor(dk_shape, dtype), - DV: T.Tensor(dv_shape, dtype), - BlockMask: T.Tensor(block_mask_shape, "int32"), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(k_shape, dtype), + V: T.Tensor(v_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), + DO_slc: T.Tensor(do_slc_shape, dtype), + DQ: T.Tensor(dq_shape, dtype), + DK: T.Tensor(dk_shape, 
dtype), + DV: T.Tensor(dv_shape, dtype), + BlockMask: T.Tensor(block_mask_shape, "int32"), ): with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh): K_shared = T.alloc_shared([BS, BK], dtype) @@ -406,31 +412,33 @@ def tilelang_kernel_bwd_dqkv( i_b, i_h = i_bh // H, i_bh % H - T.copy(K[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK], K_shared) - T.copy(V[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV], V_shared) + T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared) + T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared) # [BS, BK] T.clear(dk) # [BS, BV] T.clear(dv) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) loop_st = i_s * BS loop_ed = seq_len for i in T.Pipelined( - start=loop_st, - stop=loop_ed, - num_stages=0, + start=loop_st, + stop=loop_ed, + num_stages=0, ): b_m_slc = BlockMask[i_b, i, i_h, i_s] if b_m_slc != 0: # [G, BK] - T.copy(Q[i_b, i, i_h * G:(i_h + 1) * G, :BK], Q_shared) + T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared) T.clear(qkT) # [BS, BK] @ [G, BK] -> [BS, G] T.gemm( @@ -441,7 +449,7 @@ def tilelang_kernel_bwd_dqkv( policy=T.GemmWarpPolicy.FullRow, ) # [G] - T.copy(LSE_slc[i_b, i, i_h * G:(i_h + 1) * G], lse_shared) + T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared) for _i, _j in T.Parallel(BS, G): qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j]) @@ -450,7 +458,7 @@ def tilelang_kernel_bwd_dqkv( qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0) # [G, BV] - T.copy(DO_slc[i_b, i, i_h * G:(i_h + 1) * G, :BV], do) + T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do) T.clear(dsT) # [BS, BV] @ [G, BV] -> [BS, G] T.gemm( @@ -464,7 +472,7 @@ def tilelang_kernel_bwd_dqkv( # [BS, G] @ [G, BV] -> [BS, BV] T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) # [G] - T.copy(Delta_slc[i_b, i, i_h * G:(i_h + 1) * G], delta) + T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta) for _i, _j in T.Parallel(BS, G): dsT_cast[_i, _j] = qkT[_i, _j] * (dsT[_i, _j] - delta[_j]) * sm_scale @@ -480,16 +488,18 @@ def tilelang_kernel_bwd_dqkv( T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, DV[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV]) - T.copy(dk_shared, DK[i_v, i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK]) + T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV]) + T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK]) return flash_bwd_dqkv @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def tilelang_kernel_preprocess( batch, heads, @@ -505,9 +515,9 @@ def tilelang_kernel_preprocess( @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, seq_len, heads], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, seq_len, heads], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -516,20 +526,22 @@ def tilelang_kernel_preprocess( delta = 
T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, by * blk:(by + 1) * blk, bx]) + T.copy(delta, Delta[bz, by * blk : (by + 1) * blk, bx]) return flash_bwd_prep @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def tilelang_kernel_block_mask( batch, heads, @@ -551,9 +563,9 @@ def tilelang_kernel_block_mask( @T.prim_func def flash_bwd_block_mask( - BlockIndices: T.Tensor(block_indices_shape, dtype), # type: ignore - BlockCounts: T.Tensor(block_counts_shape, dtype), # type: ignore - BlockMask: T.Tensor(block_mask_shape, dtype), # type: ignore + BlockIndices: T.Tensor(block_indices_shape, dtype), # type: ignore + BlockCounts: T.Tensor(block_counts_shape, dtype), # type: ignore + BlockMask: T.Tensor(block_mask_shape, dtype), # type: ignore ): with T.Kernel(seq_len, batch, heads * S) as (bx, by, bz): i_t, i_b, i_hs = bx, by, bz @@ -603,9 +615,7 @@ def parallel_nsa_bwd( dk = torch.empty(NV, *k.shape, dtype=k.dtype, device=q.device) dv = torch.empty(v.shape, dtype=v.dtype, device=q.device) - block_mask = tilelang_kernel_block_mask(B, H, T, S, - BS)(block_indices.to(torch.int32), - block_counts.to(torch.int32)).to(torch.bool) + block_mask = tilelang_kernel_block_mask(B, H, T, S, BS)(block_indices.to(torch.int32), block_counts.to(torch.int32)).to(torch.bool) fused_qkv_bwd_kernel = tilelang_kernel_bwd_dqkv( batch=B, @@ -618,8 +628,7 @@ def parallel_nsa_bwd( selected_blocks=S, scale=scale, ) - fused_qkv_bwd_kernel(q, k, v, lse_slc, delta_slc, do_slc, dq, dk, dv, - block_mask.to(torch.int32)) + fused_qkv_bwd_kernel(q, k, v, lse_slc, delta_slc, do_slc, dq, dk, dv, block_mask.to(torch.int32)) dq = dq.sum(0) dk = dk.sum(0) @@ -628,7 +637,6 @@ def parallel_nsa_bwd( @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -773,23 +781,21 @@ def parallel_nsa( Outputs of shape `[B, SEQLEN, HQ, V]` if `head_first=False` else `[B, HQ, SEQLEN, V]`. 
""" if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), - (q, k, v, block_indices)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): block_counts = rearrange(block_counts, "b h t -> b t h") - assert (q.shape[2] % (k.shape[2] * 16) == 0), "Group size must be a multiple of 16 in NSA" + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: @@ -814,7 +820,7 @@ if __name__ == "__main__": for t in range(T): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") diff --git a/examples/deepseek_nsa/example_tilelang_nsa_decode.py b/examples/deepseek_nsa/example_tilelang_nsa_decode.py index 58f4355094ae914c3520d4bdd2877c2a08ef58fc..38fc51a9f0a5c8f407d03905d91a5ded5415d57a 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_decode.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_decode.py @@ -16,7 +16,8 @@ tilelang.testing.set_random_seed(42) tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def native_sparse_attention( batch, heads, @@ -25,10 +26,10 @@ def native_sparse_attention( scale=None, block_size=64, # Tile size for attention computation groups=1, # Grouped query attention (GQA) groups - selected_blocks=16 # Number of blocks to select per attention head + selected_blocks=16, # Number of blocks to select per attention head ): if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups # Modified shapes for inference (q has seq_len=1)a q_shape = [batch, 1, heads, dim] # Changed seq_len to 1 @@ -53,12 +54,11 @@ def native_sparse_attention( @T.prim_func def native_sparse_attention( - Q: T.Tensor(q_shape, dtype), # [batch, 1, heads, dim] - K: T.Tensor(kv_shape, dtype), # [batch, seq_len, head_kv, dim] - V: T.Tensor(kv_shape, dtype), # Same shape as K - BlockIndices: T.Tensor(block_indices_shape, - block_indices_dtype), # Selected block indices - Output: T.Tensor(q_shape, dtype), # Output attention tensor + Q: T.Tensor(q_shape, dtype), # [batch, 1, heads, dim] + K: T.Tensor(kv_shape, dtype), # [batch, seq_len, head_kv, dim] + V: T.Tensor(kv_shape, dtype), # Same shape as K + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), # Selected block indices + Output: T.Tensor(q_shape, dtype), # Output attention tensor ): with T.Kernel(1, NV, batch * head_kv, threads=threads) as (bx, by, bz): # Shared 
memory allocations for tile storage @@ -82,7 +82,7 @@ def native_sparse_attention( NS = S # Copy Q for the single position - T.copy(Q[i_b, 0, i_h * G:(i_h + 1) * G, :], Q_shared) # Changed i_t to 0 + T.copy(Q[i_b, 0, i_h * G : (i_h + 1) * G, :], Q_shared) # Changed i_t to 0 T.fill(acc_o, 0) T.fill(logsum, 0) @@ -93,16 +93,11 @@ def native_sparse_attention( i_s = BlockIndices[i_b, 0, i_h, i] * BS # Get block offset if i_s >= 0: # Skip invalid/padding blocks # Load current key block to shared memory - T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) # Compute QK^T attention scores T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) # Online softmax with numerical stability # 1. Compute max for scaling @@ -122,15 +117,14 @@ def native_sparse_attention( T.copy(acc_s, acc_s_cast) # Accumulate attention-weighted values - T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) # Final normalization and output for i, j in T.Parallel(G, BV): acc_o[i, j] /= logsum[i] # Normalize by logsum T.copy(acc_o, O_shared) - T.copy(O_shared, Output[i_b, 0, i_h * G:(i_h + 1) * G, - i_v * BV:(i_v + 1) * BV]) # Changed i_t to 0 + T.copy(O_shared, Output[i_b, 0, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) # Changed i_t to 0 return native_sparse_attention @@ -149,21 +143,21 @@ def main(): selected_blocks=S, ) - Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) - K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) - V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) + Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) - mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device='cuda') - DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda') + mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device="cuda") + DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda") - block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device='cuda') + block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device="cuda") for b in range(B): for t in range(SEQ_LEN_Q): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device='cuda') + block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device="cuda") out = kernel(Q, K, V, block_indices.to(torch.int32)) diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py index 0b71779b8eb3861699627048d377b023329f3790..a8dd26b63f9a6a3c184f2c24a10adcc5c551fa25 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py @@ -14,18 +14,11 @@ tilelang.testing.set_random_seed(0) tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, 
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) -def native_sparse_attention(batch, - heads, - seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16): + }, +) +def native_sparse_attention(batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16): if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) else: scale = scale * 1.44269504 # log2(e) @@ -52,11 +45,11 @@ def native_sparse_attention(batch, @T.prim_func def native_sparse_attention( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([G, BK], dtype) @@ -77,7 +70,7 @@ def native_sparse_attention(batch, i_b, i_h = i_bh // head_kv, i_bh % head_kv NS = S - T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared) + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -87,21 +80,15 @@ def native_sparse_attention(batch, i_s = BlockIndices[i_b, i_t, i_h, i] * BS if i_s <= i_t and i_s >= 0: # [BS, BK] - T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) if is_causal: for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) # Softmax T.copy(scores_max, scores_max_prev) @@ -121,13 +108,13 @@ def native_sparse_attention(batch, acc_o[i, j] *= scores_scale[i] # V * softmax(Q * K) - T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(G, BV): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV]) + T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) return native_sparse_attention @@ -148,20 +135,20 @@ def main(): ) print(kernel.get_kernel_source()) torch.random.manual_seed(0) - Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) - K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) - V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) - g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) - DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda') - - block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device='cuda') - block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device='cuda') + Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, 
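Both TileLang examples fold log2(e) into the softmax scale (`scale = (1.0 / dim) ** 0.5 * 1.44269504`) so the kernel can use base-2 exponentials instead of exp. A short numerical check of that identity; the variable names here are chosen for illustration only.

import torch

# exp(x) == 2 ** (x * log2(e)); folding log2(e) into the attention scale lets a
# kernel call exp2 while still matching a standard softmax with scale 1/sqrt(dim).
dim = 64
scores = torch.randn(8, 8, dtype=torch.float64)

ref = torch.softmax(scores * dim ** -0.5, dim=-1)

scale2 = (1.0 / dim) ** 0.5 * 1.44269504      # the kernels' scale, log2(e) folded in
z = scores * scale2
z = z - z.max(dim=-1, keepdim=True).values    # the kernels track a running max instead
p = torch.exp2(z)
out = p / p.sum(dim=-1, keepdim=True)

assert torch.allclose(ref, out)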
device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") + block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") for b in range(B): for t in range(SEQ_LEN): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() block_indices = block_indices.sort(-1)[0] diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py index d365e7a5f952377f2f326994e6d93ec48a8d0753..af87db8b2deb1a0639c02dc0c1398edb3cb1847a 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py @@ -8,6 +8,7 @@ from tilelang import language as T import tilelang.testing import fla + if parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -21,18 +22,11 @@ from einops import rearrange tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) -def native_sparse_attention_varlen(batch, - heads, - c_seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16): + } +) +def native_sparse_attention_varlen(batch, heads, c_seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16): if scale is None: - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [c_seq_len, heads, dim] kv_shape = [c_seq_len, head_kv, dim] @@ -66,14 +60,14 @@ def native_sparse_attention_varlen(batch, @T.prim_func def native_sparse_attention_varlen( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - O_slc: T.Tensor(o_slc_shape, dtype), - BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), - BlockCounts: T.Tensor(block_counts_shape, block_counts_dtype), - Offsets: T.Tensor(offsets_shape, offsets_dtype), - TokenIndices: T.Tensor(token_indices_shape, token_indices_dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + O_slc: T.Tensor(o_slc_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + BlockCounts: T.Tensor(block_counts_shape, block_counts_dtype), + Offsets: T.Tensor(offsets_shape, offsets_dtype), + TokenIndices: T.Tensor(token_indices_shape, token_indices_dtype), ): with T.Kernel(c_seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([G, BK], dtype) @@ -100,7 +94,7 @@ def native_sparse_attention_varlen(batch, current_seq_len = eos - bos NS = BlockCounts[i_t, i_h] - T.copy(Q[bos + i_t, i_h * G:(i_h + 1) * G, :BK], Q_shared) + T.copy(Q[bos + i_t, i_h * G : (i_h + 1) * G, :BK], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -112,21 +106,15 @@ def native_sparse_attention_varlen(batch, 
# [BS, BK] # Lei: may have some padding issues # we should learn from mha varlen templates to handle this - T.copy(K[bos + i_s:bos + i_s + BS, i_h, :BK], K_shared) + T.copy(K[bos + i_s : bos + i_s + BS, i_h, :BK], K_shared) if is_causal: for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) # Softmax T.copy(scores_max, scores_max_prev) @@ -146,13 +134,13 @@ def native_sparse_attention_varlen(batch, acc_o[i, j] *= scores_scale[i] # V * softmax(Q * K) - T.copy(V[bos + i_s:bos + i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + T.copy(V[bos + i_s : bos + i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(G, BV): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, O_slc[bos + i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV]) + T.copy(O_shared, O_slc[bos + i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) return native_sparse_attention_varlen @@ -190,17 +178,20 @@ def parallel_nsa_fwd( o_slc = torch.empty(B, C_SEQ_LEN, HQ, V, dtype=v.dtype, device=q.device) kernel( - q.view(C_SEQ_LEN, HQ, D), k.view(C_SEQ_LEN, H, D), v.view(C_SEQ_LEN, H, D), + q.view(C_SEQ_LEN, HQ, D), + k.view(C_SEQ_LEN, H, D), + v.view(C_SEQ_LEN, H, D), o_slc.view(C_SEQ_LEN, HQ, V), block_indices.to(torch.int32).view(C_SEQ_LEN, H, S), - block_counts.to(torch.int32).view(C_SEQ_LEN, H), offsets.to(torch.int32), - token_indices.to(torch.int32)) + block_counts.to(torch.int32).view(C_SEQ_LEN, H), + offsets.to(torch.int32), + token_indices.to(torch.int32), + ) return o_slc @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets): ctx.dtype = q.dtype @@ -221,22 +212,25 @@ class ParallelNSAFunction(torch.autograd.Function): window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) return o_slc.to(q.dtype) -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -276,29 +270,27 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
""" if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, - scale, cu_seqlens) + o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: assert False, "Window size is not supported yet" else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o @@ -306,41 +298,57 @@ if __name__ == "__main__": N, C_SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16 torch.manual_seed(42) # randomly split the sequence into N segments - offsets = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, C_SEQ_LEN)[torch.randperm(C_SEQ_LEN - 1)[:N - 1]], - torch.tensor([C_SEQ_LEN], dtype=torch.long) - ], 0).cuda().sort()[0] + offsets = ( + torch.cat( + [ + torch.tensor([0], dtype=torch.long), + torch.arange(16, C_SEQ_LEN)[torch.randperm(C_SEQ_LEN - 1)[: N - 1]], + torch.tensor([C_SEQ_LEN], dtype=torch.long), + ], + 0, + ) + .cuda() + .sort()[0] + ) # seq-first required for inputs with variable lengths - perm_q = torch.randperm(C_SEQ_LEN, device='cuda') - perm_k = torch.randperm(C_SEQ_LEN, device='cuda') - perm_v = torch.randperm(C_SEQ_LEN, device='cuda') - q = torch.linspace( - 0, 1, steps=C_SEQ_LEN, dtype=dtype, - device='cuda')[perm_q].view(1, C_SEQ_LEN, 1, 1).expand(1, C_SEQ_LEN, HQ, - D).clone().requires_grad_(True) - k = torch.linspace( - 0, 1, steps=C_SEQ_LEN, dtype=dtype, - device='cuda')[perm_k].view(1, C_SEQ_LEN, 1, 1).expand(1, C_SEQ_LEN, H, - D).clone().requires_grad_(True) - v = torch.linspace( - 0, 1, steps=C_SEQ_LEN, dtype=dtype, - device='cuda')[perm_v].view(1, C_SEQ_LEN, 1, 1).expand(1, C_SEQ_LEN, H, - D).clone().requires_grad_(True) - g_slc = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((1, C_SEQ_LEN, HQ, D), dtype=dtype, device='cuda') + perm_q = torch.randperm(C_SEQ_LEN, device="cuda") + perm_k = torch.randperm(C_SEQ_LEN, device="cuda") + perm_v = torch.randperm(C_SEQ_LEN, device="cuda") + q = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_q] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, HQ, D) + .clone() + .requires_grad_(True) + ) + k = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_k] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, H, D) + .clone() + .requires_grad_(True) + ) + v = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, 
device="cuda")[perm_v] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, H, D) + .clone() + .requires_grad_(True) + ) + g_slc = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((1, C_SEQ_LEN, HQ, D), dtype=dtype, device="cuda") token_indices = prepare_token_indices(offsets).tolist() - block_indices = torch.full((1, C_SEQ_LEN, H, S), C_SEQ_LEN, dtype=torch.long, device='cuda') + block_indices = torch.full((1, C_SEQ_LEN, H, S), C_SEQ_LEN, dtype=torch.long, device="cuda") for i in range(C_SEQ_LEN): _, t = token_indices[i] for h in range(H): i_i = torch.randperm(max(1, tilelang.cdiv(t, block_size)))[:S] - block_indices[0, i, h, :len(i_i)] = i_i + block_indices[0, i, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (1, C_SEQ_LEN, H), device='cuda') + block_counts = torch.randint(1, S + 1, (1, C_SEQ_LEN, H), device="cuda") ref = naive_nsa( q=q, @@ -351,7 +359,8 @@ if __name__ == "__main__": block_indices=block_indices, block_counts=block_counts, block_size=block_size, - cu_seqlens=offsets) + cu_seqlens=offsets, + ) tri = parallel_nsa( q=q, @@ -362,7 +371,8 @@ if __name__ == "__main__": block_indices=block_indices, block_counts=block_counts, block_size=block_size, - cu_seqlens=offsets) + cu_seqlens=offsets, + ) print("tri", tri) print("ref", ref) diff --git a/examples/deepseek_nsa/example_triton_nsa_bwd.py b/examples/deepseek_nsa/example_triton_nsa_bwd.py index e912794a458a340b9111550298eeeaac19bddc44..af05bfa701654e3ec2dd53ffb2c0b50c61514801 100644 --- a/examples/deepseek_nsa/example_triton_nsa_bwd.py +++ b/examples/deepseek_nsa/example_triton_nsa_bwd.py @@ -8,6 +8,7 @@ import triton import triton.language as tl import fla + if parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -17,21 +18,44 @@ from reference import naive_nsa from einops import rearrange -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -46,20 +70,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc # else: NS = S - p_q = 
tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -72,7 +94,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -92,7 +114,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -105,8 +126,7 @@ class ParallelNSAFunction(torch.autograd.Function): # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] token_indices = prepare_token_indices(offsets) if offsets is not None else None - o, lse = parallel_nsa_fwd( - q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) ctx.save_for_backward(q, k, v, o, lse) ctx.block_indices = block_indices ctx.block_size = block_size @@ -134,7 +154,8 @@ class ParallelNSAFunction(torch.autograd.Function): window_size=ctx.window_size, scale=ctx.scale, offsets=ctx.offsets, - token_indices=ctx.token_indices) + token_indices=ctx.token_indices, + ) return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None @@ -199,37 +220,56 @@ def parallel_nsa_fwd( return o_slc, lse_slc, o_swa, lse_swa -@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None}) +@triton.heuristics({"USE_OFFSETS": lambda args: args["offsets"] is not None}) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) -@triton.jit(do_not_specialize=['T']) -def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, do_slc, do_swa, dk, - dv, block_mask, offsets, chunk_indices, scale, T, B: tl.constexpr, - H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, - V: tl.constexpr, M: tl.constexpr, BS: tl.constexpr, - WS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr): +@triton.jit(do_not_specialize=["T"]) +def parallel_nsa_bwd_kernel_dkv( + q, + k, + v, + lse_slc, + lse_swa, + delta_slc, + delta_swa, + do_slc, + do_swa, + dk, + dv, + block_mask, + offsets, + chunk_indices, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + 
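The running quantities in the forward kernel (`b_m_slc`, `b_acc_slc`, `b_o_slc`) follow the usual online-softmax recurrence: a running max, a running normalizer, and a running weighted value sum, all rescaled whenever a new block raises the max. A per-query PyTorch sketch of that pattern, in base e rather than the kernels' base 2:

import torch

def streaming_softmax_attention(q, k, v, block_size):
    """One query row attending to K/V processed block by block, keeping only a
    running max, running normalizer and running weighted sum (a sketch of the
    online-softmax pattern, not the kernel's exact code)."""
    T, _ = k.shape
    m = torch.tensor(float("-inf"))          # running max of scores
    acc = torch.tensor(0.0)                  # running sum of exp(scores - m)
    o = torch.zeros(v.shape[-1])             # running weighted value sum
    for s in range(0, T, block_size):
        scores = q @ k[s:s + block_size].T   # [BS]
        m_new = torch.maximum(m, scores.max())
        r = torch.exp(m - m_new)             # rescale previous accumulators
        p = torch.exp(scores - m_new)
        acc = acc * r + p.sum()
        o = o * r + p @ v[s:s + block_size]
        m = m_new
    return o / acc

torch.manual_seed(0)
q = torch.randn(16)
k = torch.randn(64, 16)
v = torch.randn(64, 8)
ref = torch.softmax(q @ k.T, dim=-1) @ v
assert torch.allclose(streaming_softmax_attention(q, k, v, 32), ref, atol=1e-5)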
M: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, +): i_v, i_s, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: - i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + - 1).to(tl.int32) + i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: bos, eos = i_b * T, i_b * T + T - p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), - (1, 0)) - p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), - (BS, BV), (1, 0)) - p_dk = tl.make_block_ptr(dk + (i_v * B * T * H + bos * H + i_h) * K, (T, K), (H * K, 1), - (i_s * BS, 0), (BS, BK), (1, 0)) - p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), - (BS, BV), (1, 0)) + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_v * B * T * H + bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) # [BS, BK] b_k = tl.load(p_k, boundary_check=(0, 1)) @@ -241,14 +281,12 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, for i in range(i_s * BS, T): b_m_slc = tl.load(block_mask + (bos + i) * H * M + i_h * M + i_s) if b_m_slc: - p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_do_slc = tl.make_block_ptr(do_slc + (bos + i) * HQ * V, (HQ, V), (V, 1), - (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_do_slc = tl.make_block_ptr(do_slc + (bos + i) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G) p_delta_slc = delta_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G) # [G, BV] @@ -272,14 +310,12 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, if WS > 0: o_s = i_s * BS + tl.arange(0, BS) if max(i_s * BS, i - WS + 1) < min((i_s + 1) * BS, i + 1): - p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), - (G, BK), (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_do_swa = tl.make_block_ptr(do_swa + (bos + i) * HQ * V, (HQ, V), (V, 1), - (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_do_swa = tl.make_block_ptr(do_swa + (bos + i) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_swa = lse_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G) p_delta_swa = delta_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G) # [G, BV] @@ -304,12 +340,19 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) -@triton.heuristics( - {'USE_BLOCK_COUNTS': lambda args: 
isinstance(args['block_counts'], torch.Tensor)}) +@triton.heuristics({"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor)}) @triton.jit -def parallel_nsa_kernel_mask(block_indices, block_counts, block_mask, T: tl.constexpr, - H: tl.constexpr, S: tl.constexpr, BS: tl.constexpr, NS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_kernel_mask( + block_indices, + block_counts, + block_mask, + T: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + NS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_b, i_hs = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_h, i_s = i_hs // S, i_hs % S @@ -320,31 +363,56 @@ def parallel_nsa_kernel_mask(block_indices, block_counts, block_mask, T: tl.cons b_m = b_i * BS <= i_t if b_i < NS and b_i >= 0: - tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, - b_m.to(block_mask.dtype.element_ty)) + tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, b_m.to(block_mask.dtype.element_ty)) -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor) -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) -@triton.jit(do_not_specialize=['T']) -def parallel_nsa_bwd_kernel_dq(q, k, v, lse_slc, delta_slc, do_slc, lse_swa, delta_swa, do_swa, dq, - scale, block_indices, block_counts, offsets, token_indices, T, - B: tl.constexpr, H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, - K: tl.constexpr, V: tl.constexpr, S: tl.constexpr, BS: tl.constexpr, - WS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, USE_BLOCK_COUNTS: tl.constexpr): +@triton.jit(do_not_specialize=["T"]) +def parallel_nsa_bwd_kernel_dq( + q, + k, + v, + lse_slc, + delta_slc, + do_slc, + lse_swa, + delta_swa, + do_swa, + dq, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: - i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + - 1).to(tl.int32) + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: @@ -449,27 +517,49 @@ def parallel_nsa_bwd_kernel_dq(q, k, v, lse_slc, delta_slc, do_slc, lse_swa, del tl.store(p_dq, (b_dq_slc + b_dq_swa).to(p_dq.dtype.element_ty), boundary_check=(0, 1)) -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, 
num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: - i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + - 1).to(tl.int32) + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: @@ -484,20 +574,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc else: NS = S - p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -510,7 +598,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -529,13 +617,12 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) if WS > 0: - p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), - (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_swa = tl.zeros([G, BV], dtype=tl.float32) - b_m_swa = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_swa = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_swa = tl.zeros([G], 
dtype=tl.float32) for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS): p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) @@ -546,7 +633,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1)) # [G, BS] b_s_swa = tl.dot(b_q, b_k_swa) - b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float('-inf')) + b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float("-inf")) # [G] b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa @@ -593,14 +680,8 @@ def parallel_nsa_block_mask( block_mask = torch.zeros(B, T, H, NS, dtype=torch.bool, device=block_indices.device) parallel_nsa_kernel_mask[(T, B, H * S)]( - block_indices=block_indices, - block_counts=block_counts, - block_mask=block_mask, - T=T, - H=H, - S=S, - BS=BS, - NS=NS) + block_indices=block_indices, block_counts=block_counts, block_mask=block_mask, T=T, H=H, S=S, BS=BS, NS=NS + ) return block_mask @@ -676,7 +757,8 @@ def parallel_nsa_bwd( BS=BS, WS=WS, BK=BK, - BV=BV) + BV=BV, + ) dq = dq.sum(0) if offsets is not None: @@ -719,14 +801,14 @@ def parallel_nsa_bwd( BS=BS, WS=WS, BK=BK, - BV=BV) + BV=BV, + ) dk = dk.sum(0) return dq, dk, dv @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -749,7 +831,8 @@ class ParallelNSAFunction(torch.autograd.Function): window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) ctx.block_indices = block_indices ctx.block_counts = block_counts @@ -781,22 +864,25 @@ class ParallelNSAFunction(torch.autograd.Function): window_size=ctx.window_size, scale=ctx.scale, offsets=ctx.offsets, - token_indices=ctx.token_indices) + token_indices=ctx.token_indices, + ) return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -836,51 +922,49 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
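The `parallel_nsa_kernel_mask` launch scatters the per-position block selections into a dense boolean mask that the dK/dV backward kernel can scan per key block. A dense equivalent in plain PyTorch, with the caveat that the NS computation and the `block_counts` handling are assumptions inferred from the hunks above, not copied from them:

import torch

def block_mask_reference(block_indices, block_counts, block_size, seq_len):
    """Assumed dense restatement: mask[b, t, h, blk] is True iff blk is one of t's
    selected blocks, its selection slot is within block_counts, and the block
    starts at or before t."""
    B, T, H, S = block_indices.shape
    NS = (seq_len + block_size - 1) // block_size   # number of key blocks
    mask = torch.zeros(B, T, H, NS, dtype=torch.bool)
    for b in range(B):
        for t in range(T):
            for h in range(H):
                n = int(block_counts[b, t, h]) if torch.is_tensor(block_counts) else block_counts
                for s in range(min(n, S)):
                    blk = int(block_indices[b, t, h, s])
                    if 0 <= blk < NS and blk * block_size <= t:
                        mask[b, t, h, blk] = True
    return mask

mask = block_mask_reference(torch.zeros(1, 8, 1, 1, dtype=torch.long), 1, 4, 8)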
""" if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o if __name__ == "__main__": B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16 torch.random.manual_seed(0) - q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) - k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) - v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) - g_slc = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda') - - block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda') + q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda") for b in range(B): for t in range(T): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (B, T, H), device='cuda') + block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") ref = naive_nsa( q=q, diff --git a/examples/deepseek_nsa/example_triton_nsa_fwd.py b/examples/deepseek_nsa/example_triton_nsa_fwd.py index 2c740013a7ca7a8f5d6ccc3a71357d2d63c1fe6e..c9ab28daaf931ebc7565130343f8e4c15a1570d2 100644 --- a/examples/deepseek_nsa/example_triton_nsa_fwd.py +++ b/examples/deepseek_nsa/example_triton_nsa_fwd.py @@ -8,6 +8,7 @@ import triton import triton.language as tl import fla + if parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -17,21 +18,44 @@ 
from reference import naive_nsa from einops import rearrange -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -46,20 +70,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc # else: NS = S - p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -72,7 +94,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -92,7 +114,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -105,8 +126,7 @@ class ParallelNSAFunction(torch.autograd.Function): # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] token_indices = prepare_token_indices(offsets) if offsets is not None else None - o, lse = parallel_nsa_fwd( - q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, 
block_size=block_size, scale=scale) ctx.save_for_backward(q, k, v, o, lse) ctx.block_indices = block_indices ctx.block_size = block_size @@ -177,7 +197,6 @@ def parallel_nsa_fwd( @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -200,7 +219,8 @@ class ParallelNSAFunction(torch.autograd.Function): window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) ctx.block_indices = block_indices ctx.block_counts = block_counts @@ -212,18 +232,20 @@ class ParallelNSAFunction(torch.autograd.Function): return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -263,51 +285,49 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o if __name__ == "__main__": B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16 torch.random.manual_seed(0) - q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) - k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) - v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) - g_slc = torch.ones((B, T, HQ), dtype=dtype, 
device='cuda').requires_grad_(True) - g_swa = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda') - - block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda') + q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda") for b in range(B): for t in range(T): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, :len(i_i)] = i_i + block_indices[b, t, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (B, T, H), device='cuda') + block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") ref = naive_nsa( q=q, diff --git a/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py b/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py index 9ccbff6a4f12cd871201489d8ed75e2168fc350c..cb4eb6d7ba6119a0ebf16700d65b55b1fd1a237b 100644 --- a/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py +++ b/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py @@ -8,6 +8,7 @@ import triton import triton.language as tl import fla + if parse(fla.__version__) < parse("0.2.1"): from fla.ops.common.utils import prepare_token_indices else: @@ -17,27 +18,49 @@ from reference import naive_nsa from einops import rearrange -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, - 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), -}) +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) @triton.autotune( configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=['BS', 'BK', 'BV'], + key=["BS", "BK", "BV"], ) @triton.jit -def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices, - block_counts, offsets, token_indices, T, H: tl.constexpr, - HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, - S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr, - BV: tl.constexpr, USE_OFFSETS: tl.constexpr, - USE_BLOCK_COUNTS: tl.constexpr): +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: - i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + - 1).to(tl.int32) + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos 
else: @@ -52,20 +75,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc else: NS = S - p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), - (1, 0)) + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) - p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), - (G, BV), (1, 0)) + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_slc = tl.zeros([G, BV], dtype=tl.float32) - b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_slc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS @@ -78,7 +99,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) # [G, BS] b_s_slc = tl.dot(b_q, b_k_slc) - b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf')) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) # [G] b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc @@ -97,13 +118,12 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) if WS > 0: - p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), - (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o_swa = tl.zeros([G, BV], dtype=tl.float32) - b_m_swa = tl.full([G], float('-inf'), dtype=tl.float32) + b_m_swa = tl.full([G], float("-inf"), dtype=tl.float32) b_acc_swa = tl.zeros([G], dtype=tl.float32) for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS): p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) @@ -114,7 +134,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1)) # [G, BS] b_s_swa = tl.dot(b_q, b_k_swa) - b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float('-inf')) + b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float("-inf")) # [G] b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa @@ -196,7 +216,6 @@ def parallel_nsa_fwd( @torch.compile class ParallelNSAFunction(torch.autograd.Function): - @staticmethod @contiguous @autocast_custom_fwd @@ -219,7 +238,8 @@ class ParallelNSAFunction(torch.autograd.Function): window_size=window_size, scale=scale, offsets=offsets, - token_indices=token_indices) + token_indices=token_indices, + ) ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) ctx.block_indices = block_indices ctx.block_counts = block_counts @@ -231,18 +251,20 @@ class ParallelNSAFunction(torch.autograd.Function): return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa -def parallel_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: 
torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -282,29 +304,27 @@ def parallel_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" if isinstance(block_counts, int): block_indices = block_indices[:, :, :, :block_counts] block_counts = None - o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, - window_size, scale, cu_seqlens) + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) if window_size > 0: o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) else: o = o_slc * g_slc.unsqueeze(-1) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, "b t h d -> b h t d") return o @@ -312,38 +332,35 @@ if __name__ == "__main__": N, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16 torch.manual_seed(42) # randomly split the sequence into N segments - offsets = torch.cat([ - torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 1)[:N - 1]], - torch.tensor([T], dtype=torch.long) - ], 0).cuda().sort()[0] + offsets = ( + torch.cat( + [torch.tensor([0], dtype=torch.long), torch.arange(16, T)[torch.randperm(T - 1)[: N - 1]], torch.tensor([T], dtype=torch.long)], + 0, + ) + .cuda() + .sort()[0] + ) # offsets.shape is [N+1] # seq-first required for inputs with variable lengths - perm_q = torch.randperm(T, device='cuda') - perm_k = torch.randperm(T, device='cuda') - perm_v = torch.randperm(T, device='cuda') - q = torch.linspace( - 0, 1, steps=T, dtype=dtype, - device='cuda')[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True) - k = torch.linspace( - 0, 1, steps=T, dtype=dtype, - device='cuda')[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) - v = torch.linspace( - 0, 1, steps=T, dtype=dtype, - device='cuda')[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) - g_slc = torch.rand((1, T, HQ), dtype=dtype, device='cuda').requires_grad_(True) - g_swa = torch.rand((1, T, HQ), 
dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((1, T, HQ, D), dtype=dtype, device='cuda') + perm_q = torch.randperm(T, device="cuda") + perm_k = torch.randperm(T, device="cuda") + perm_v = torch.randperm(T, device="cuda") + q = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True) + k = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) + v = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) + g_slc = torch.rand((1, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.rand((1, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((1, T, HQ, D), dtype=dtype, device="cuda") token_indices = prepare_token_indices(offsets).tolist() - block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device='cuda') + block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device="cuda") for i in range(T): _, t = token_indices[i] for h in range(H): i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S] - block_indices[0, i, h, :len(i_i)] = i_i + block_indices[0, i, h, : len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (1, T, H), device='cuda') + block_counts = torch.randint(1, S + 1, (1, T, H), device="cuda") ref = naive_nsa( q=q, @@ -354,7 +371,8 @@ if __name__ == "__main__": block_indices=block_indices, block_counts=block_counts, block_size=block_size, - cu_seqlens=offsets) + cu_seqlens=offsets, + ) tri = parallel_nsa( q=q, @@ -365,7 +383,8 @@ if __name__ == "__main__": block_indices=block_indices, block_counts=block_counts, block_size=block_size, - cu_seqlens=offsets) + cu_seqlens=offsets, + ) print("tri", tri) print("ref", ref) diff --git a/examples/deepseek_nsa/reference.py b/examples/deepseek_nsa/reference.py index 958d0c19ee6798b234dd644b8169d27636def779..58083108eb30e871fba15b60a9f36bacee9c3949 100644 --- a/examples/deepseek_nsa/reference.py +++ b/examples/deepseek_nsa/reference.py @@ -6,18 +6,20 @@ from typing import Union from einops import rearrange, repeat -def naive_nsa(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False) -> torch.Tensor: +def naive_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: r""" Args: q (torch.Tensor): @@ -57,26 +59,24 @@ def naive_nsa(q: torch.Tensor, Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
""" if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: - raise RuntimeError( - "Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), - (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') + block_counts = rearrange(block_counts, "b h t -> b t h") dtype = q.dtype G = q.shape[2] // k.shape[2] BS = block_size S = block_indices.shape[-1] - k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) if isinstance(block_counts, torch.Tensor): - block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) q, k, v = map(lambda x: x.float(), (q, k, v)) @@ -86,14 +86,11 @@ def naive_nsa(q: torch.Tensor, if cu_seqlens is None: varlen = False B, T = q.shape[:2] - cu_seqlens = torch.cat( - [block_indices.new_tensor(range(0, B * T, T)), - block_indices.new_tensor([B * T])]) + cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B * T, T)), block_indices.new_tensor([B * T])]) for i in range(len(cu_seqlens) - 1): if not varlen: - q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[ - i], block_indices[i] + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], block_indices[i] if isinstance(block_counts, torch.Tensor): s_b = block_counts[i] else: @@ -101,10 +98,10 @@ def naive_nsa(q: torch.Tensor, else: T = cu_seqlens[i + 1] - cu_seqlens[i] q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map( - lambda x: x[0][cu_seqlens[i]:cu_seqlens[i + 1]], - (q, k, v, g_slc, g_swa, block_indices)) + lambda x: x[0][cu_seqlens[i] : cu_seqlens[i + 1]], (q, k, v, g_slc, g_swa, block_indices) + ) if isinstance(block_counts, torch.Tensor): - s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i + 1]] + s_b = block_counts[0][cu_seqlens[i] : cu_seqlens[i + 1]] else: s_b = block_counts @@ -126,34 +123,28 @@ def naive_nsa(q: torch.Tensor, else: s_i = s_b # [S*BS, HQ, -1] - k_i_slc, v_i_slc = map( - lambda x: x.gather( - 0, - i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) + k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) # [S*BS, HQ] - attn_slc = torch.einsum('h d, n h d -> n h', q_i, k_i_slc).masked_fill( - torch.logical_or(i_i < 0, i_i > i_q) | - (c >= s_i if block_counts is not None else False), float('-inf')).softmax(0) + attn_slc = ( + torch.einsum("h d, n h d -> n h", q_i, k_i_slc) + .masked_fill(torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), float("-inf")) + .softmax(0) + ) if not varlen: - o_slc[i, i_q] = torch.einsum('n h, n h v -> h v', attn_slc, - v_i_slc) * 
g_slc_i.unsqueeze(-1) + o_slc[i, i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) else: - o_slc[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_slc, - v_i_slc) * g_slc_i.unsqueeze(-1) + o_slc[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) if window_size > 0: - k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1):i_q + 1], - (k_b, v_b)) - attn_swa = torch.einsum('h d, n h d -> n h', q_i, k_i_swa).softmax(0) + k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1) : i_q + 1], (k_b, v_b)) + attn_swa = torch.einsum("h d, n h d -> n h", q_i, k_i_swa).softmax(0) if not varlen: - o_swa[i, i_q] = torch.einsum('n h, n h v -> h v', attn_swa, - v_i_swa) * g_swa_i.unsqueeze(-1) + o_swa[i, i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) else: - o_swa[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_swa, - v_i_swa) * g_swa_i.unsqueeze(-1) + o_swa[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) if head_first: - o_slc = rearrange(o_slc, 'b t h d -> b h t d') - o_swa = rearrange(o_swa, 'b t h d -> b h t d') + o_slc = rearrange(o_slc, "b t h d -> b h t d") + o_swa = rearrange(o_swa, "b t h d -> b h t d") return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype) @@ -187,7 +178,7 @@ def naive_nsa_simple( o (torch.Tensor): Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 dtype = q.dtype HQ = q.shape[2] @@ -197,8 +188,8 @@ def naive_nsa_simple( BS = block_size S = block_indices.shape[-1] SELECTED_BLOCKS_SIZE = S * BS - k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) - block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) q, k, v = map(lambda x: x.float(), (q, k, v)) o = torch.zeros_like(v) @@ -228,10 +219,10 @@ def naive_nsa_simple( v_i[t, h] = v_b[selected_block_index, h, :] # [S*BS, HQ] - attn = torch.einsum('h d, n h d -> n h', q_i, k_i) - attn = attn.masked_fill((i_i > i_q) | (c >= s_i), float('-inf')) + attn = torch.einsum("h d, n h d -> n h", q_i, k_i) + attn = attn.masked_fill((i_i > i_q) | (c >= s_i), float("-inf")) attn = torch.softmax(attn, dim=0) - o[i, i_q] = torch.einsum('n h, n h v -> h v', attn, v_i) + o[i, i_q] = torch.einsum("n h, n h v -> h v", attn, v_i) return o.to(dtype) @@ -265,7 +256,7 @@ def naive_nsa_simple_inference( o (torch.Tensor): Outputs of shape `[B, 1, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
""" - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 dtype = q.dtype HQ = q.shape[2] @@ -275,8 +266,8 @@ def naive_nsa_simple_inference( BS = block_size S = block_indices.shape[-1] SELECTED_BLOCKS_SIZE = S * BS - k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) - block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) q, k, v = map(lambda x: x.float(), (q, k, v)) o = torch.zeros_like(q) @@ -306,9 +297,9 @@ def naive_nsa_simple_inference( v_i[t, h] = v_b[selected_block_index, h, :] # [S*BS, HQ] - attn = torch.einsum('h d, n h d -> n h', q_i, k_i) - attn = attn.masked_fill((c >= s_i), float('-inf')) + attn = torch.einsum("h d, n h d -> n h", q_i, k_i) + attn = attn.masked_fill((c >= s_i), float("-inf")) attn = torch.softmax(attn, dim=0) - o[i, 0] = torch.einsum('n h, n h v -> h v', attn, v_i) + o[i, 0] = torch.einsum("n h, n h v -> h v", attn, v_i) return o.to(dtype) diff --git a/examples/deepseek_v32/fp8_lighting_indexer.py b/examples/deepseek_v32/fp8_lighting_indexer.py index dd940648b66cd0834583861a39a0176be25e5261..305e2afc489678699411c7751261e4a0ebc6ff91 100644 --- a/examples/deepseek_v32/fp8_lighting_indexer.py +++ b/examples/deepseek_v32/fp8_lighting_indexer.py @@ -28,11 +28,11 @@ def validate_tensor_match(a, b, tolerance=1e-8, tensor_name="tensor", should_rai if should_raise: assert False if not torch.isclose( - a.masked_fill(a_finite, 0), - b.masked_fill(b_finite, 0), - rtol=0, - atol=0, - equal_nan=True, + a.masked_fill(a_finite, 0), + b.masked_fill(b_finite, 0), + rtol=0, + atol=0, + equal_nan=True, ).all(): display_error_message(f"{tensor_name} Error: nonfinite value mismatch") if should_raise: @@ -55,13 +55,10 @@ def get_configs(): threads=[128, 256], block_Q=[1, 2, 4], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] class SupplyProg: - def __init__(self): self.tensors_dict = {} @@ -88,7 +85,8 @@ supply_prog = SupplyProg() @tilelang.jit( pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - },) + }, +) def mqa_attn_return_logits( heads, index_dim, @@ -113,16 +111,15 @@ def mqa_attn_return_logits( @T.prim_func def mqa_attn_return_logits_kernel( - IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore - IndexK: T.Tensor(index_k_shape, dtype), # type: ignore - IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore - Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore - Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore - CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore - CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore + IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore + IndexK: T.Tensor(index_k_shape, dtype), # type: ignore + IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore + Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore + Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as 
bx: - index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype) index_k_shared = T.alloc_shared([block_N, index_dim], dtype) index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype) @@ -140,17 +137,14 @@ def mqa_attn_return_logits( cu_k_e_max[0] = -2147483648 for bq_i in T.serial(block_Q): - cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], - seq_len_kv)) + cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv)) for bq_i in T.serial(block_Q): - cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], - seq_len_kv)) + cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv)) T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) T.copy(Weights[seq_len_i, 0], weights) - for nbn_i in T.Pipelined( - T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages): + for nbn_i in T.Pipelined(T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages): T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared) T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment) @@ -164,15 +158,14 @@ def mqa_attn_return_logits( ) for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads): - s_reshaped[bn_i, bq_i, - h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) * - weights[bq_i, h_i]) * index_k_scale_fragment[bn_i] + s_reshaped[bn_i, bq_i, h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) * weights[bq_i, h_i]) * index_k_scale_fragment[ + bn_i + ] T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) for bq_i, bn_i in T.Parallel(block_Q, block_N): - Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = ( - logits[bn_i, bq_i]) + Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = logits[bn_i, bq_i] return mqa_attn_return_logits_kernel @@ -190,9 +183,9 @@ def clean_logits_( @T.prim_func def clean_logits_kernel( - Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore - CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore - CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore + Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore ): with T.Kernel(seq_len, threads=threads) as bx: tx = T.thread_binding(0, threads, thread="threadIdx.x") @@ -210,13 +203,7 @@ def clean_logits_( return clean_logits_kernel -def mqa_attn_return_logits_interface(q, - kv, - kv_scales, - weights, - cu_seqlen_ks, - cu_seqlen_ke, - clean_logits=True): +def mqa_attn_return_logits_interface(q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True): seq_len, heads, index_dim = q.shape seq_len_kv = kv.shape[0] @@ -238,20 +225,19 @@ def mqa_attn_return_logits_interface(q, return logits -def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor): +def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor): k = kv q = q.float() k = k.float() seq_len_kv = kv.shape[0] - mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None] - mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None] + mask_lo = torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] + mask_hi = torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] mask = mask_lo 
& mask_hi - score = torch.einsum('mhd,nd->hmn', q, k) + score = torch.einsum("mhd,nd->hmn", q, k) logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) - logits = logits.masked_fill(~mask, float('-inf')) + logits = logits.masked_fill(~mask, float("-inf")) cost = mask.sum() return logits, cost @@ -265,32 +251,22 @@ def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1): weights = torch.randn(S, H, device="cuda", dtype=torch.float32) p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1) - ks, ke = generate_random_cu_seqlens( - per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048) + ks, ke = generate_random_cu_seqlens(per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048) - logits_ref, cost_ref = ref_fp8_mqa_logits( - q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + logits_ref, cost_ref = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) q_fp8 = q.to(torch.float8_e4m3fn) kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False) - logits_tl = mqa_attn_return_logits_interface( - q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) - diff = validate_tensor_match( - logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False) + logits_tl = mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + diff = validate_tensor_match(logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False) print(f"diff: {diff}") from tilelang.profiler import do_bench def logits_fn(): - return mqa_attn_return_logits_interface( - q=q_fp8, - kv=kv_fp8, - kv_scales=kv_scales, - weights=weights, - cu_seqlen_ks=ks, - cu_seqlen_ke=ke) + return mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: logits_fn() diff --git a/examples/deepseek_v32/sparse_mla_bwd.py b/examples/deepseek_v32/sparse_mla_bwd.py index 4ff3b8194563ce4bd36928e44be6bc84b2a30317..1266e70edf9d43dfa527ed0d4a159a1d4564059d 100644 --- a/examples/deepseek_v32/sparse_mla_bwd.py +++ b/examples/deepseek_v32/sparse_mla_bwd.py @@ -22,9 +22,9 @@ def preprocess( @T.prim_func def preprocess_kernel( - O: T.Tensor(shape, dtype), - dO: T.Tensor(shape, dtype), - Delta: T.Tensor([B, S, H], accum_dtype), + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + Delta: T.Tensor([B, S, H], accum_dtype), ): with T.Kernel(H, T.ceildiv(S, block_ND), B) as (bx, by, bz): o = T.alloc_fragment([block_ND, block_ND], accum_dtype) @@ -33,16 +33,12 @@ def preprocess( acc = T.alloc_fragment([block_ND, block_ND], accum_dtype) T.clear(acc) for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): - T.copy( - O[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], - o) - T.copy( - dO[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], - do) + T.copy(O[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o) + T.copy(dO[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do) for i, j in T.Parallel(block_ND, block_ND): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, by * block_ND:(by + 1) * block_ND, bx]) + T.copy(delta, Delta[bz, by * block_ND : (by + 1) * 
block_ND, bx]) return preprocess_kernel @@ -65,13 +61,13 @@ def postprocess( @T.prim_func def postprocess_kernel( - dKV: T.Tensor(dkv_shape, accum_dtype), - dKV_out: T.Tensor(dkv_shape, dtype), + dKV: T.Tensor(dkv_shape, accum_dtype), + dKV_out: T.Tensor(dkv_shape, dtype), ): with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, B, threads=threads) as (bx, by, bz): T.copy( - dKV[bz, bx * block_N:(bx + 1) * block_N, by, :], - dKV_out[bz, bx * block_N:(bx + 1) * block_N, by, :], + dKV[bz, bx * block_N : (bx + 1) * block_N, by, :], + dKV_out[bz, bx * block_N : (bx + 1) * block_N, by, :], ) return postprocess_kernel @@ -83,7 +79,8 @@ def postprocess( tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True, - }) + }, +) def bwd( B, S, @@ -102,14 +99,14 @@ def bwd( dtype="bfloat16", accum_dtype="float", ): - assert is_causal == True, 'non-casual is not supported now' - assert topk % block_size == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded' + assert is_causal == True, "non-casual is not supported now" + assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" assert dtype == "bfloat16" assert accum_dtype == "float" assert indices_dtype == "int32" if sm_scale is None: - sm_scale = (D + D_tail)**(-0.5) + sm_scale = (D + D_tail) ** (-0.5) sm_scale_mul_reciprocal_log2 = sm_scale * 1.44269504 # log2(e) H_kv = H // kv_group @@ -132,14 +129,14 @@ def bwd( @T.prim_func def sparse_mla_bwd_kernel( - Q: T.Tensor(q_shape, dtype), - KV: T.Tensor(k_shape, dtype), - dO: T.Tensor(o_shape, dtype), - Indices: T.Tensor(indices_shape, indices_dtype), - Lse: T.Tensor(lse_shape, accum_dtype), - Delta: T.Tensor(delta_shape, accum_dtype), - dQ: T.Tensor(q_shape, dtype), - dKV: T.Tensor(k_shape, accum_dtype), + Q: T.Tensor(q_shape, dtype), + KV: T.Tensor(k_shape, dtype), + dO: T.Tensor(o_shape, dtype), + Indices: T.Tensor(indices_shape, indices_dtype), + Lse: T.Tensor(lse_shape, accum_dtype), + Delta: T.Tensor(delta_shape, accum_dtype), + dQ: T.Tensor(q_shape, dtype), + dKV: T.Tensor(k_shape, accum_dtype), ): with T.Kernel(S, B, kv_group, threads=threads) as (s_i, by, bz): Q_shared = T.alloc_shared([padded_H, D], dtype) @@ -165,17 +162,19 @@ def bwd( max_kv_i = s_i - T.copy(Q[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D], Q_shared) - T.copy(Q[by, s_i, bz * padded_H:(bz + 1) * padded_H, D:], Q_tail_shared) - T.copy(dO[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D], dO_shared) + T.copy(Q[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D], Q_shared) + T.copy(Q[by, s_i, bz * padded_H : (bz + 1) * padded_H, D:], Q_tail_shared) + T.copy(dO[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D], dO_shared) T.clear(acc_dq) T.clear(acc_dq_tail) - T.annotate_layout({ - dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared), - dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared), - }) + T.annotate_layout( + { + dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared), + dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared), + } + ) # Process each block of indices for i_i in T.Pipelined(NS, num_stages=num_stages): @@ -191,62 +190,31 @@ def bwd( for bi_i, d_i in T.Parallel(BS, D): KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, d_i] - T.gemm( - Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, 
policy=T.GemmWarpPolicy.FullCol) for bi_i, d_i in T.Parallel(BS, D_tail): - KV_tail_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, - D + d_i] - T.gemm( - Q_tail_shared, - KV_tail_shared, - acc_p, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + KV_tail_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, D + d_i] + T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) for h_i, bi_i in T.Parallel(padded_H, BS): - acc_p[h_i, bi_i] = T.exp2(acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 - - Lse[by, s_i, bz * padded_H + h_i]) + acc_p[h_i, bi_i] = T.exp2(acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 - Lse[by, s_i, bz * padded_H + h_i]) T.copy(acc_p, P_shared_cast) - T.gemm( - dO_shared, - KV_shared, - acc_dp, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol, - clear_accum=True) + T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) for h_i, bi_i in T.Parallel(padded_H, BS): - acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * ( - acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale + acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale T.copy(acc_dp, dP_shared_cast) T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol) T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - dP_shared_cast, - Q_shared, - acc_dkv, - transpose_A=True, - policy=T.GemmWarpPolicy.FullCol, - clear_accum=True) - T.gemm( - P_shared_cast, - dO_shared, - acc_dkv, - transpose_A=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) T.clear(acc_dkv_tail) - T.gemm( - dP_shared_cast, - Q_tail_shared, - acc_dkv_tail, - transpose_A=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) for s in range(split_store): for bi_i, d_i in T.Parallel(BS, D): @@ -255,41 +223,32 @@ def bwd( for bi_i, d_i in T.Parallel(BS, D_tail): if bi_i < BS // split_store: - acc_dkv_tail_shared[bi_i, - d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), - d_i] + acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i] for bi_i, d_i in T.Parallel(BS // split_store, D // 4): T.atomic_addx4( - dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], - bz, d_i * 4], acc_dkv_shared[bi_i, d_i * 4]) + dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4], + acc_dkv_shared[bi_i, d_i * 4], + ) # Atomically update dKV, dKV_tail tensors for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): T.atomic_addx4( - dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], - bz, D + d_i * 4], acc_dkv_tail_shared[bi_i, d_i * 4]) + dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4], + acc_dkv_tail_shared[bi_i, d_i * 4], + ) # Store the accumulated dQ T.copy(acc_dq, dQ_shared) T.copy(acc_dq_tail, dQ_tail_shared) - T.copy(dQ_shared, dQ[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D]) - T.copy(dQ_tail_shared, dQ[by, s_i, bz * padded_H:(bz + 1) * padded_H, D:]) + T.copy(dQ_shared, dQ[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D]) + T.copy(dQ_tail_shared, dQ[by, s_i, bz * padded_H : (bz + 1) * 
padded_H, D:]) return sparse_mla_bwd_kernel -def sparse_mla_bwd(q, - kv, - o, - do, - indices, - lse, - sm_scale=None, - is_casual=True, - return_kernel=False, - delta=None): +def sparse_mla_bwd(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True, return_kernel=False, delta=None): assert q.is_contiguous() assert kv.is_contiguous() assert indices.is_contiguous() @@ -322,6 +281,7 @@ def sparse_mla_bwd(q, def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True): from sparse_mla_fwd import ref_sparse_mla_fwd_interface + q = q.detach().clone() kv = kv.detach().clone() q.requires_grad = True @@ -331,30 +291,22 @@ def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_c return q.grad, kv.grad -def test_sparse_mla_bwd(B=1, - S=4096, - SKV=8192, - H=64, - HKV=1, - DQKV=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - check_correctness=True): +def test_sparse_mla_bwd(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16, check_correctness=True): # Prepare data - q = torch.randn((B, S, H, DQKV), dtype=dtype, device='cuda').requires_grad_(True) - kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((B, S, H, DV), dtype=dtype, device='cuda') + q = torch.randn((B, S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, S, H, DV), dtype=dtype, device="cuda") - indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device='cuda') + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") for b in range(B): for t in range(S): for h in range(HKV): i_i = torch.randperm(max(1, t))[:topk] - indices[b, t, h, :len(i_i)] = i_i + indices[b, t, h, : len(i_i)] = i_i # Forward from sparse_mla_fwd import sparse_mla_fwd_interface + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices) tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse) @@ -365,13 +317,15 @@ def test_sparse_mla_bwd(B=1, assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv") print("assert_tensors_similar passed") - per_token_flop = 2 * sum([ - H * DV * topk, - H * DQKV * topk, - H * DQKV * topk, - H * DQKV * topk, - H * DV * topk, - ]) + per_token_flop = 2 * sum( + [ + H * DV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DV * topk, + ] + ) from tilelang.profiler import do_bench def fn(): @@ -379,20 +333,9 @@ def test_sparse_mla_bwd(B=1, ms = do_bench(fn, rep=100, warmup=250) print(f"Average time: {ms:.3f} ms") - print(f'bwd io bandwidth = ', - (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) - print(f'bwd tflops = ', per_token_flop * S / (ms * 1e-3) / 1e12) + print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) + print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12) if __name__ == "__main__": - test_sparse_mla_bwd( - B=1, - S=4096, - SKV=8192, - H=64, - HKV=1, - DQKV=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - check_correctness=True) + test_sparse_mla_bwd(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16, check_correctness=True) diff --git a/examples/deepseek_v32/sparse_mla_fwd.py b/examples/deepseek_v32/sparse_mla_fwd.py index e65b890176105031bafbf38c48f77a7f99a33083..3b963c751e0b2a8ccba93387cdf0b80dec4db3cc 100644 --- a/examples/deepseek_v32/sparse_mla_fwd.py +++ 
b/examples/deepseek_v32/sparse_mla_fwd.py @@ -25,15 +25,12 @@ def sparse_mla_fwd( num_stages=2, threads=256, ): - assert dim == tilelang.math.next_power_of_2( - dim), f"haven't check padding correctness yet, dim={dim}" - assert tail_dim == tilelang.math.next_power_of_2( - tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" assert is_causal == True, "non-casual is not supported" - assert (topk % - block_I == 0), "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" if sm_scale is None: - sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) else: sm_scale = sm_scale * 1.44269504 # log2(e) @@ -55,9 +52,9 @@ def sparse_mla_fwd( H = head_kv padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) if padded_H != H: - assert ( - kv_group == 1 - ), "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) BI = block_I NI = tilelang.cdiv(topk, block_I) D = dim @@ -73,18 +70,17 @@ def sparse_mla_fwd( @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), # type: ignore - KV: T.Tensor(kv_shape, dtype), # type: ignore - Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore - Output: T.Tensor(o_shape, dtype), # type: ignore - Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore ): - with T.Kernel( - seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( - bx, - by, - bz, - ): + with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( + bx, + by, + bz, + ): Q_shared = T.alloc_shared([H_per_block, D], dtype) Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) KV_shared = T.alloc_shared([BI, D], dtype) @@ -118,16 +114,13 @@ def sparse_mla_fwd( T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) for i_i in T.Pipelined(NI, num_stages=num_stages): - for bi_i in T.Parallel(BI): mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i for bi_i, d_i in T.Parallel(BI, D): - KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, - d_i] + KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i] for bi_i, d_i in T.Parallel(BI, D_tail): - K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, - D + d_i] + K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] for h_i, bi_i in T.Parallel(H_per_block, BI): acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) @@ -176,15 +169,7 @@ def sparse_mla_fwd( return main -def sparse_mla_fwd_interface(q, - kv, - indices, - sm_scale=None, - 
return_p_sum: bool = False, - d_v=512, - block_I=64, - num_stages=2, - threads=256): +def sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=64, num_stages=2, threads=256): is_casual = True assert return_p_sum == False, "This kernel file is for fwd only" assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() @@ -201,16 +186,8 @@ def sparse_mla_fwd_interface(q, assert indices.shape == (batch, seq_len, kv_group, topk) kernel = sparse_mla_fwd( - heads, - dim, - tail_dim, - topk, - kv_group, - sm_scale, - is_casual, - block_I=block_I, - num_stages=num_stages, - threads=threads) + heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual, block_I=block_I, num_stages=num_stages, threads=threads + ) out, lse = kernel(q, kv, indices) return out, lse @@ -230,14 +207,14 @@ def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True): b, _, _, dim_v = v.shape g_index = g h_index = h // g - compressed_casual_mask = torch.arange( - 0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( - 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda").view(1, -1) + compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( + 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda" + ).view(1, -1) mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) mask = mask[..., :-1] mask = mask & compressed_casual_mask.view(1, 1, sq, sk) - mask[:, :, :1 - 1, 0] = True + mask[:, :, : 1 - 1, 0] = True mask = mask.view(b, g_index, 1, sq, sk) q = q.view(b, sq, g, -1, dim_q) @@ -252,19 +229,21 @@ def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True): return o.to(torch.bfloat16) -def test_sparse_mla_fwd(B=1, - S=4096, - SKV=8192, - H=128, - HKV=1, - DQK=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - check_correctness=True, - block_I=64, - num_stages=2, - threads=256): +def test_sparse_mla_fwd( + B=1, + S=4096, + SKV=8192, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, +): torch.random.manual_seed(0) q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) @@ -274,10 +253,9 @@ def test_sparse_mla_fwd(B=1, for t in range(S): for h in range(HKV): i_i = torch.randperm(max(1, t))[:topk] - indices[b, t, h, :len(i_i)] = i_i + indices[b, t, h, : len(i_i)] = i_i - tl_out, tl_lse = sparse_mla_fwd_interface( - q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) if check_correctness: # otherwise may cause out of memory @@ -286,8 +264,7 @@ def test_sparse_mla_fwd(B=1, print("assert_tensors_similar passed") def fn(): - return sparse_mla_fwd_interface( - q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) + return sparse_mla_fwd_interface(q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) from tilelang.profiler import do_bench @@ -315,4 +292,5 @@ if __name__ == "__main__": check_correctness=True, block_I=64, num_stages=2, - threads=256) + threads=256, + ) diff --git a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py index 
1621d85ba97962447781df9ae93fef316c7346ca..972160c99a9f1faf9c777443c925fdcbb59abc60 100644 --- a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py +++ b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py @@ -9,10 +9,16 @@ import argparse @tilelang.jit( out_idx=[-2, -1], compile_flags=[ - "-O3", "-Wno-deprecated-declarations", "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", - "--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG" + "-O3", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--ptxas-options=-v,--register-usage-level=10", + "-DNDEBUG", ], ) def sparse_mla_fwd( @@ -32,14 +38,12 @@ def sparse_mla_fwd( num_stages=0, threads=384, ): - assert dim == tilelang.math.next_power_of_2( - dim), f"haven't check padding correctness yet, dim={dim}" - assert tail_dim == tilelang.math.next_power_of_2( - tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" - assert is_causal == True, 'non-casual is not supported' - assert topk % block_I == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded' + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" if sm_scale is None: - sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) else: sm_scale = sm_scale * 1.44269504 # log2(e) @@ -57,15 +61,17 @@ def sparse_mla_fwd( H = head_kv padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) if padded_H != H: - assert kv_group == 1, 'here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)' + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) BI = block_I NI = tilelang.cdiv(topk, block_I) - assert NI % 2 == 0, 'NI should be a multiple of 2' + assert NI % 2 == 0, "NI should be a multiple of 2" D = dim D_tail = tail_dim KV_stride = kv_stride if head_kv > 64: - assert head_kv % 64 == 0, 'head_kv should be a multiple of 64' + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" REPLICATE_H = head_kv // 64 else: REPLICATE_H = 1 @@ -74,18 +80,14 @@ def sparse_mla_fwd( @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), # type: ignore - KV: T.Tensor(kv_shape, dtype), # type: ignore - Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore - q_start_index_s: T.Tensor(1, indices_dtype), - Output: T.Tensor(o_shape, dtype), # type: ignore - Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + q_start_index_s: T.Tensor(1, indices_dtype), + Output: 
T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore ): - with T.Kernel( - (seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H, - batch, - kv_group, - threads=threads) as (bx, by, bz): + with T.Kernel((seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H, batch, kv_group, threads=threads) as (bx, by, bz): Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype) Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype) Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) @@ -122,8 +124,7 @@ def sparse_mla_fwd( bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256) b_i, g_i = by, bz - s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else ( - bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0)) + s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else (bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0)) q_i = q_start_index_s[0] + s_i max_kv_i = (q_i + 1 - KV_stride) // KV_stride @@ -132,26 +133,24 @@ def sparse_mla_fwd( tx = T.get_thread_binding() - T.copy(Q[b_i, s_i, H0:H1, 0:D // 2], Q_shared_l) - T.copy(Q[b_i, s_i, H0:H1, D // 2:D], Q_shared_r) + T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l) + T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r) T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) T.barrier_arrive(bar_q) if tx < 128: T.set_max_nreg(240, 1) T.fill(sumexp, 0) - T.fill(m_i, -2**30) # avoid -inf - inf to cause nan + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan T.fill(acc_o_l, 0) T.barrier_wait(bar_q, 0) for i_i in T.serial(T.ceildiv(NI, 2)): - # Buffer 0 T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, - -T.infinity(acc_s.dtype)) + acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1) @@ -187,8 +186,7 @@ def sparse_mla_fwd( T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, - -T.infinity(acc_s.dtype)) + acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1) @@ -227,7 +225,7 @@ def sparse_mla_fwd( for h_i in T.Parallel(H_per_block): sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale T.copy(acc_o_l, O_shared_l) - T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0:D // 2]) + T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2]) elif tx >= 128 and tx < 256: T.set_max_nreg(168, 1) @@ -257,7 +255,7 @@ def sparse_mla_fwd( acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] T.copy(acc_o_r, O_shared_r) - T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2:D]) + T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D]) elif tx >= 256: # producer T.set_max_nreg(80, 0) @@ -265,70 +263,58 @@ def sparse_mla_fwd( # Buffer 0 T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - indices_local[0] = Indices[b_i, s_i, g_i, - (i_i * 2) * BI + r * 16 + (tx - 256) // 8] + indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8] is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i 
if is_kv_valid[r * 16 + (tx - 256) // 8]: with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_0_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_0_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, D // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - D + (tx - 256) % 8 * 8 + v] + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, D + (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - indices_local[0] = Indices[b_i, s_i, g_i, - (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8] + indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8] is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i if is_kv_valid[r * 16 + (tx - 256) // 8]: with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_1_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_1_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, D // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - D + (tx - 256) % 8 * 8 + v] + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, D + (tx - 256) % 8 * 8 + v + ] T.cp_async_barrier_noinc(bar_k_1_ready[0]) return main -def sparse_mla_fwd_interface(q, - kv, - indices, - q_start_index_s, - kv_stride, - sm_scale=None, - is_casual=True, - return_kernel=False, - print_kernel=False): +def sparse_mla_fwd_interface( + q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_casual=True, return_kernel=False, print_kernel=False +): assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() batch, seq_len, heads, dim_plus_tail_dim = q.shape _, seq_len_kv, kv_group, _ = kv.shape - assert dim_plus_tail_dim == 576, 'you should assign dim otherwise' + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" dim = 512 assert kv.shape[-1] == dim_plus_tail_dim @@ -338,29 +324,23 @@ def sparse_mla_fwd_interface(q, assert indices.shape == (batch, seq_len, kv_group, topk) if q_start_index_s != 0: - assert q_start_index_s > kv_stride, "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < 
KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)" + assert q_start_index_s > kv_stride, ( + "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)" + ) CP0 = q_start_index_s == 0 - kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, - kv_group, sm_scale, is_casual, CP0) + kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_casual, CP0) if print_kernel: print(kernel.get_kernel_source()) - out, lse = kernel(q, kv, indices, - torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda")) + out, lse = kernel(q, kv, indices, torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda")) if return_kernel: return kernel if q_start_index_s == 0 and kv_stride > 1: - out[:, :kv_stride - 1, :, :] = 0 + out[:, : kv_stride - 1, :, :] = 0 return out, lse -def ref_sparse_mla_fwd_interface(q, - kv, - indices, - q_start_index_s, - kv_stride=4, - sm_scale=None, - is_casual=True): +def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=4, sm_scale=None, is_casual=True): q = q.float() kv = kv.float() indices = indices.transpose(1, 2) @@ -369,7 +349,7 @@ def ref_sparse_mla_fwd_interface(q, if q_start_index_s is None: q_start_index_s = sk * kv_stride - sq - assert kv.shape[-1] == 576, 'you should assign dim otherwise' + assert kv.shape[-1] == 576, "you should assign dim otherwise" dim = 512 k = kv v = kv[..., :dim] @@ -378,15 +358,14 @@ def ref_sparse_mla_fwd_interface(q, num_kv_per_index = 1 g_index = g h_index = h // g - compressed_casual_mask = torch.arange( - q_start_index_s, sq + q_start_index_s, dtype=torch.int32, - device="cuda").view(-1, 1) >= torch.arange( - kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1) + compressed_casual_mask = torch.arange(q_start_index_s, sq + q_start_index_s, dtype=torch.int32, device="cuda").view( + -1, 1 + ) >= torch.arange(kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1) mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) mask = mask[..., :-1] mask = mask & compressed_casual_mask.view(1, 1, sq, sk) - mask[:, :, :kv_stride - 1, 0] = True + mask[:, :, : kv_stride - 1, 0] = True mask = mask.view(b, g_index, 1, sq, sk) q = q.view(b, sq, g, -1, dim_q) @@ -401,41 +380,32 @@ def ref_sparse_mla_fwd_interface(q, return o.to(torch.bfloat16) -def test_sparse_mla_fwd_pipelined(B=1, - S=4096, - SKV=8192, - H=128, - HKV=1, - DQK=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - q_start_s_index=1024, - check_correctness=True): +def test_sparse_mla_fwd_pipelined( + B=1, S=4096, SKV=8192, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16, q_start_s_index=1024, check_correctness=True +): KV_stride = 1 torch.random.manual_seed(0) - q = torch.randn((B, S, H, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10 - kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10 + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, 
device="cuda").requires_grad_(True) / 10 q_start_s_index_t = torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda") q.clamp_(-10, 10) kv.clamp_(-10, 10) - indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device='cuda') + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") for b in range(B): for t in range(S): for h in range(HKV): i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk] - indices[b, t, h, :len(i_i)] = i_i + indices[b, t, h, : len(i_i)] = i_i - kernel = sparse_mla_fwd_interface( - q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True) + kernel = sparse_mla_fwd_interface(q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True) def fn(): out, lse = kernel(q, kv, indices, q_start_s_index_t) if q_start_s_index == 0 and KV_stride > 1: - out[:, :KV_stride - 1, :, :] = 0 + out[:, : KV_stride - 1, :, :] = 0 return out, lse tl_out, tl_lse = fn() @@ -446,14 +416,15 @@ def test_sparse_mla_fwd_pipelined(B=1, torch.testing.assert_close(tl_out, ref_out, rtol=1e-3, atol=1e-3) from tilelang.profiler import do_bench + ms = do_bench( fn, rep=10, warmup=10, ) print(f"Average time: {ms:.3f} ms") - print(f'fwd io bandwidth = ', (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) - print(f'fwd tflops = ', (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + print(f"fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) + print(f"fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) if __name__ == "__main__": @@ -464,5 +435,4 @@ if __name__ == "__main__": B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 else: B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 - test_sparse_mla_fwd_pipelined( - B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=args.test_correctness) + test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=args.test_correctness) diff --git a/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py b/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py index 2dd27048eef9b2303c65ce42b8e5c66e1106cd5c..6b7e879ba6e4fc1155660ad6be10da917c7a5ad5 100644 --- a/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py +++ b/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py @@ -21,23 +21,20 @@ def test_example_fp8_lighting_indexer(): @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_sparse_mla_fwd(): # small shapes for testing - sparse_mla_fwd.test_sparse_mla_fwd( - S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) + sparse_mla_fwd.test_sparse_mla_fwd(S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_sparse_mla_fwd_pipelined(): # small shapes for testing - sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined( - S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) + sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_sparse_mla_bwd(): - sparse_mla_bwd.test_sparse_mla_bwd( - S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False) + 
sparse_mla_bwd.test_sparse_mla_bwd(S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False) if __name__ == "__main__": diff --git a/examples/deepseek_v32/topk_selector.py b/examples/deepseek_v32/topk_selector.py index 4a4b432775885a1632c430b79d1a7395ab28973f..cf87f526d5c7acfc5ffa8ef08df03d00c41c926b 100644 --- a/examples/deepseek_v32/topk_selector.py +++ b/examples/deepseek_v32/topk_selector.py @@ -127,9 +127,9 @@ def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"): l_num_input = s_num_input[r_idx] for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): if s * BLOCK_SIZE + tx < l_num_input: - l_bin_id32 = T.Cast("int32", (( - convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> - (24 - round * 8)) & 0xFF)) + l_bin_id32 = T.Cast( + "int32", ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF) + ) T.atomic_add(s_histogram[l_bin_id32], 1) T.sync_threads() # cumsum @@ -156,23 +156,20 @@ def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"): for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): T.sync_threads() if s * BLOCK_SIZE + tx < l_num_input: - l_bin_id32 = T.Cast("int32", (( - convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> - (24 - round * 8)) & 0xFF)) + l_bin_id32 = T.Cast( + "int32", ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF) + ) if l_bin_id32 > l_threshold_bin_id: - pos = T.atomic_add( - s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos + pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos index[bx, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: if round == 3: - l_out_pos = T.atomic_add( - s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos + l_out_pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos if l_out_pos < topk: index[bx, l_out_pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] else: pos = T.atomic_add(s_num_input[r_idx ^ 1], 1, return_prev=True) - s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx, - s * BLOCK_SIZE + tx] + s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] return tl_topk_kernel @@ -186,7 +183,6 @@ def tl_topk(input, starts, ends, topk): def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048): - batch = 64 seq_len = 32 * 1024 topk = 2048 @@ -212,8 +208,7 @@ def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048): set_ref = set(ref_np) set_trt = set(trt_np) intersection = set_ref & set_trt - print("selected/all:", len(intersection), "/", len(set_ref), "=", - len(intersection) / len(set_ref)) + print("selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref)) # Performance test with CUDA events diff --git a/examples/deepseek_v32/utils.py b/examples/deepseek_v32/utils.py index 2ea34b14a465c2079d2726d2b2d108397a10c685..d7252e171108aa13396f6d3e91d84d04de1d3c17 100644 --- a/examples/deepseek_v32/utils.py +++ b/examples/deepseek_v32/utils.py @@ -23,8 +23,7 @@ def _is_equal(a, b): if isinstance(a, torch.Tensor): return a is b # Whitelist of types that are safe to compare by value for caching. 
- if isinstance(a, (int, float, str, bool, type(None))) and isinstance( - b, (int, float, str, bool, type(None))): + if isinstance(a, (int, float, str, bool, type(None))) and isinstance(b, (int, float, str, bool, type(None))): return a == b # For other types, we cannot guarantee a cheap and safe comparison, so we fail the cache check. return False @@ -58,9 +57,11 @@ def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor] if len(args) == len(last_args) and len(kwargs) == len(last_kwargs): # For Tensors, check for object identity. For other types, check for equality. # Python caches small integers, so `is` works for them but not for large integers like 4096. - if all(_is_equal(a, b) for a, b in zip(args, last_args)) and \ - set(kwargs.keys()) == set(last_kwargs.keys()) and \ - all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items()): + if ( + all(_is_equal(a, b) for a, b in zip(args, last_args)) + and set(kwargs.keys()) == set(last_kwargs.keys()) + and all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items()) + ): return last_result result = fn(*args, **kwargs) @@ -79,73 +80,68 @@ def cal_seq_idx_from_cu_seqlens(cu_seqlens: torch.LongTensor, seq_len: int): @tensor_cache -def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, - seq_len: int) -> torch.IntTensor: - seq_idx_for_q = torch.full((seq_len,), - len(cu_seqlens_qs), - dtype=torch.int32, - device=cu_seqlens_qs.device) +def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, seq_len: int) -> torch.IntTensor: + seq_idx_for_q = torch.full((seq_len,), len(cu_seqlens_qs), dtype=torch.int32, device=cu_seqlens_qs.device) for i in range(len(cu_seqlens_qs)): - seq_idx_for_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = i + seq_idx_for_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = i return seq_idx_for_q @tensor_cache -def cal_cu_seqlen_ks_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, - cu_seqlens_ks: torch.LongTensor, seq_len: int) -> torch.IntTensor: +def cal_cu_seqlen_ks_for_q( + cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, cu_seqlens_ks: torch.LongTensor, seq_len: int +) -> torch.IntTensor: cu_seqlen_ks_for_each_q = torch.gather( - input=torch.cat([ - cu_seqlens_ks, - torch.full((1,), - torch.iinfo(torch.int32).max, - dtype=torch.int32, - device=cu_seqlens_qs.device) - ]), + input=torch.cat([cu_seqlens_ks, torch.full((1,), torch.iinfo(torch.int32).max, dtype=torch.int32, device=cu_seqlens_qs.device)]), dim=0, - index=cal_seq_idx_for_q( - cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long()) + index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(), + ) return cu_seqlen_ks_for_each_q.int() @tensor_cache -def cal_cu_seqlen_ke_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, - cu_seqlens_ks: torch.LongTensor, cu_seqlens_ke: torch.LongTensor, - q_start_idxs: torch.LongTensor, seq_len: int, - kv_stride: int) -> torch.IntTensor: +def cal_cu_seqlen_ke_for_q( + cu_seqlens_qs: torch.LongTensor, + cu_seqlens_qe: torch.LongTensor, + cu_seqlens_ks: torch.LongTensor, + cu_seqlens_ke: torch.LongTensor, + q_start_idxs: torch.LongTensor, + seq_len: int, + kv_stride: int, +) -> torch.IntTensor: cu_seqlen_ke_for_each_q = torch.gather( - input=torch.cat( - [cu_seqlens_ke, - torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]), + input=torch.cat([cu_seqlens_ke, torch.zeros(1, dtype=torch.int32, 
device=cu_seqlens_qs.device)]), dim=0, - index=cal_seq_idx_for_q( - cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long()) - casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), - dtype=torch.int32, - device=cu_seqlens_qs.device) + index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(), + ) + casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), dtype=torch.int32, device=cu_seqlens_qs.device) for i in range(len(cu_seqlens_qs)): - casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = (torch.arange( - q_start_idxs[i], - q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], - dtype=torch.int32, - device=cu_seqlens_qs.device) + 1) // kv_stride + cu_seqlens_ks[i] + casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = ( + torch.arange( + q_start_idxs[i], q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], dtype=torch.int32, device=cu_seqlens_qs.device + ) + + 1 + ) // kv_stride + cu_seqlens_ks[i] cu_seqlen_ke_for_each_q = torch.minimum(casual_cu_seqlen_ke_for_each_q, cu_seqlen_ke_for_each_q) return cu_seqlen_ke_for_each_q.int() @tensor_cache -def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor, - cu_seqlens_k: torch.LongTensor = None, - offs_q: torch.LongTensor = None, - *, - seq_len: int, - kv_stride: int = 1, - cp_rank: int = 0, - cp_size: int = 1, - balanced_cp=False): - ''' +def cal_ks_ke_from_cu_seqlen_qk( + cu_seqlens_q: torch.LongTensor, + cu_seqlens_k: torch.LongTensor = None, + offs_q: torch.LongTensor = None, + *, + seq_len: int, + kv_stride: int = 1, + cp_rank: int = 0, + cp_size: int = 1, + balanced_cp=False, +): + """ seq_len: seq len per cp rank balanced cp slice assignment: 0 1 2 3 3 2 1 0 - ''' + """ n_seq = len(cu_seqlens_q) - 1 assert n_seq > 0 assert cu_seqlens_q.shape == (n_seq + 1,) @@ -170,10 +166,12 @@ def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor, def f(x: torch.Tensor): chunks = x.chunk(cp_size * 2) - return torch.cat([ - chunks[cp_rank], - chunks[cp_size - cp_rank - 1], - ]) + return torch.cat( + [ + chunks[cp_rank], + chunks[cp_size - cp_rank - 1], + ] + ) ks = f(ks) ke = f(ke) @@ -189,8 +187,7 @@ def ceil_to_ue8m0(x: torch.Tensor): return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) -def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], - use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) sf = x_amax / 448.0 @@ -239,14 +236,18 @@ def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1, total_seqlen - (cp_rank + 1) * per_chunk_seqlen, total_seqlen - cp_rank * per_chunk_seqlen, ) - ks = torch.cat([ - cu_seqlens_ks_for_each_q[slice_short], - cu_seqlens_ks_for_each_q[slice_long], - ]) - ke = torch.cat([ - cu_seqlens_ke_for_each_q[slice_short], - cu_seqlens_ke_for_each_q[slice_long], - ]) + ks = torch.cat( + [ + cu_seqlens_ks_for_each_q[slice_short], + cu_seqlens_ks_for_each_q[slice_long], + ] + ) + ke = torch.cat( + [ + cu_seqlens_ke_for_each_q[slice_short], + cu_seqlens_ke_for_each_q[slice_long], + ] + ) assert len(ks) == len(ke) == per_cp_seqlen return ks, ke @@ -302,11 +303,9 @@ def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): raise_assert: Whether to raise assertion error on failure 
""" sim = calculate_tensor_similarity(x, y, name) - diff = 1. - sim + diff = 1.0 - sim if not (0 <= diff <= eps): - print( - f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m" - ) + print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m") if raise_assert: assert False # noqa: B011 @@ -316,11 +315,8 @@ if __name__ == "__main__": cu_seqlens = torch.randint(128, 4096, (1000,), dtype=torch.int32, device="cuda") last_idx = torch.where(cu_seqlens.cumsum(dim=0) >= seq_len)[0][0] cu_seqlens_cumsum = cu_seqlens[:last_idx].cumsum(dim=0) - cu_seqlens_qs = torch.cat( - [torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum]) - cu_seqlens_qe = torch.cat( - [cu_seqlens_cumsum, - torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len]) + cu_seqlens_qs = torch.cat([torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum]) + cu_seqlens_qe = torch.cat([cu_seqlens_cumsum, torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len]) from tilelang.profiler import do_bench diff --git a/examples/dequantize_gemm/dequantize_utils.py b/examples/dequantize_gemm/dequantize_utils.py index b14c0aee687636e7bb8e85b3ffbdaaeec191bb58..90a6265ffa4bf22c3d583e74e53066161c80a37a 100644 --- a/examples/dequantize_gemm/dequantize_utils.py +++ b/examples/dequantize_gemm/dequantize_utils.py @@ -39,12 +39,10 @@ def torch_convert_bit_twiddling(tensor): res0 = val_concat_expanded & mask res1 = (val_concat_expanded << 3) & mask res2 = (val_concat_expanded << 6) & mask - res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | ( - (val_concat_expanded >> 7) & mask3) + res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | ((val_concat_expanded >> 7) & mask3) # Select the correct result based on position - bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1, - torch.where(pos == 2, res2, res3))) + bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1, torch.where(pos == 2, res2, res3))) # Convert to uint16 for .view(torch.bfloat16) bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16) @@ -110,7 +108,7 @@ def print_bit(name, val): val (torch.Tensor): A scalar PyTorch tensor (numeric) whose 32-bit binary representation will be shown. 
""" val_cpu = val.cpu().item() - binary_repr = f'{val_cpu:032b}' + binary_repr = f"{val_cpu:032b}" print(name, binary_repr) @@ -122,7 +120,7 @@ def calc_sim(x, y, name="tensor"): x, y = x.data.double(), y.data.double() denominator = (x * x + y * y).sum() if denominator == 0: - print_red_warning(f'{name} all zero') + print_red_warning(f"{name} all zero") return 1 sim = 2 * (x * y).sum() / denominator return sim @@ -132,21 +130,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True): x_mask = torch.isfinite(x) y_mask = torch.isfinite(y) if not torch.all(x_mask == y_mask): - print_red_warning(f'{name} Error: isfinite mask mismatch') + print_red_warning(f"{name} Error: isfinite mask mismatch") if raise_assert: raise AssertionError - if not torch.isclose( - x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, - equal_nan=True).all(): - print_red_warning(f'{name} Error: nonfinite value mismatch') + if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all(): + print_red_warning(f"{name} Error: nonfinite value mismatch") if raise_assert: raise AssertionError x = x.masked_fill(~x_mask, 0) y = y.masked_fill(~y_mask, 0) sim = calc_sim(x, y, name) - diff = (1. - sim).item() - print(f'{diff=}') + diff = (1.0 - sim).item() + print(f"{diff=}") if not (0 <= diff <= eps): - print_red_warning(f'{name} Error: {diff=}') + print_red_warning(f"{name} Error: {diff=}") if raise_assert: raise AssertionError diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py index e30845b8d777194e91a4cf3af8bbe0939a7dc56d..ba3e0b4a7773aa0225c5787128c0ce185d119c96 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py @@ -24,6 +24,7 @@ def get_configs(): the parameter name to its chosen value. """ import itertools + iter_params = dict( block_M=[64, 128, 256], block_N=[64, 128, 256], @@ -32,63 +33,62 @@ def get_configs(): threads=[128, 256, 512], split=[1, 2], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] -@tilelang.autotune(configs=get_configs(),) +@tilelang.autotune( + configs=get_configs(), +) @tilelang.jit( out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) -def matmul(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - source_format='uint', - num_bits=4, - fast_dequant=True, - block_M=256, - block_N=128, - block_K=128, - num_stages=2, - threads=256, - split=1): +def matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + source_format="uint", + num_bits=4, + fast_dequant=True, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, +): + """ + Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T. + + This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts: + - A: dense input of shape (M, K) with dtype `in_dtype`. 
+ - B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`. + - C: output of shape (M, N) with dtype `out_dtype`. + + The generated kernel supports two dequantization paths: + - fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group. + - simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element. + + Important behavior and requirements: + - num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits. + - QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes. + - Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid. + - When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group. + - The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages. + + Parameters that alter kernel layout/behavior (brief): + - block_M, block_N, block_K: tile sizes for M, N, and K dimensions. + - num_stages: number of software pipeline stages for the K-loop. + - threads: number of threads used per kernel block. + - split: extra K-splitting factor; K must be divisible by block_K * split. + - source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics. + + Returns: + A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel. """ - Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T. - - This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts: - - A: dense input of shape (M, K) with dtype `in_dtype`. - - B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`. - - C: output of shape (M, N) with dtype `out_dtype`. - - The generated kernel supports two dequantization paths: - - fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group. - - simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element. - - Important behavior and requirements: - - num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits. - - QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes. - - Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid. - - When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group. - - The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages. - - Parameters that alter kernel layout/behavior (brief): - - block_M, block_N, block_K: tile sizes for M, N, and K dimensions. - - num_stages: number of software pipeline stages for the K-loop. - - threads: number of threads used per kernel block. 
- - split: extra K-splitting factor; K must be divisible by block_K * split. - - source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics. - - Returns: - A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel. - """ num_elems_per_byte = 8 // num_bits storage_dtype = "uint8" @@ -189,8 +189,7 @@ def matmul(M, # Finally, store the dequantized data to shared memory. for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v - B_dequantize_shared[index // block_K, - index % block_K] = B_dequantize_local_thread[v] + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] return fast_dequant_bf16_fp4_twiddling @@ -215,30 +214,29 @@ def matmul(M, assert in_dtype in ["fp4"] assert out_dtype in ["bfloat16"] - def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, - scale: tir.PrimExpr, dtype: str): + def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): """ - Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value. - - This helper extracts the 4-bit field located at the bit position `pos` within the - byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an - exponent `scale` offset to align it with bfloat16 exponent bias, clamps the - resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern. - - Parameters: - nbit (int): Number of bits in the packed element; must be 4. - val (tir.PrimExpr): A uint8 value containing packed FP4 elements. - pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract. - scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16. - dtype (str): Target dtype string; must be "bfloat16". - - Returns: - tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value. - - Notes: - - The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8". - - The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16 - bit fields and clamps the computed exponent to fit into 8 bits. + Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value. + + This helper extracts the 4-bit field located at the bit position `pos` within the + byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an + exponent `scale` offset to align it with bfloat16 exponent bias, clamps the + resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern. + + Parameters: + nbit (int): Number of bits in the packed element; must be 4. + val (tir.PrimExpr): A uint8 value containing packed FP4 elements. + pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract. + scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16. + dtype (str): Target dtype string; must be "bfloat16". + + Returns: + tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value. + + Notes: + - The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8". + - The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16 + bit fields and clamps the computed exponent to fit into 8 bits. 
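A quick way to sanity-check the bit packing described above is to assemble the same 16-bit pattern on the host and reinterpret it as bfloat16. This is only an illustrative sketch; the helper name and sample values are made up and are not part of the kernel:

import torch


def assemble_bf16(s: int, e_bf16: int, m_f4: int) -> torch.Tensor:
    # Same packing as the kernel: sign -> bit 15, 8-bit exponent -> bits 14..7,
    # the single FP4 mantissa bit -> bit 6 (the top mantissa bit of bfloat16).
    bits = (((s << 8) | e_bf16) << 7) | (m_f4 << 6)
    # Carry the raw bits in int16 so the reinterpret also works on torch builds without uint16.
    signed = bits - 0x10000 if bits >= 0x8000 else bits
    return torch.tensor([signed], dtype=torch.int16).view(torch.bfloat16)


# Exponent field 127 encodes 2**0 and the mantissa bit contributes 0.5, so this prints 1.5.
print(assemble_bf16(0, 127, 1))  # tensor([1.5000], dtype=torch.bfloat16)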
""" assert nbit == 4 assert dtype == "bfloat16" @@ -254,8 +252,9 @@ def matmul(M, e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) m_f4 = f4 & tir.const(1, "uint16") val_bf16 = tir.reinterpret( - "bfloat16", ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) - | (m_f4 << tir.const(6, "uint16"))).astype("uint16")) + "bfloat16", + ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"), + ) return val_bf16 @T.macro @@ -292,32 +291,32 @@ def matmul(M, @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), ): """ - Kernel entry for the tiled, pipelined matmul used by the generated prim_func. - - This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it: - - Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile. - - Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout. - - Pipelines over K in chunks of `block_K` for `num_stages` stages: - - Loads A and packed B tiles into shared memory. - - Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine. - - Performs a GEMM accumulating into C_local with B transposed. - - Stores the accumulated block from C_local back to the global output C via C_shared. - - Parameters: - - A: input tile of shape (M, K) with dtype `in_dtype`. - - B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing). - - C: output tensor of shape (M, N) with dtype `out_dtype`. - - Side effects: - - Writes the computed output block into the global tensor `C`. - - Uses and updates shared memory buffers and per-thread accumulators. - - No value is returned. + Kernel entry for the tiled, pipelined matmul used by the generated prim_func. + + This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it: + - Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile. + - Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout. + - Pipelines over K in chunks of `block_K` for `num_stages` stages: + - Loads A and packed B tiles into shared memory. + - Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine. + - Performs a GEMM accumulating into C_local with B transposed. + - Stores the accumulated block from C_local back to the global output C via C_shared. + + Parameters: + - A: input tile of shape (M, K) with dtype `in_dtype`. + - B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing). + - C: output tensor of shape (M, N) with dtype `out_dtype`. + + Side effects: + - Writes the computed output block into the global tensor `C`. + - Uses and updates shared memory buffers and per-thread accumulators. + + No value is returned. 
""" with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -327,9 +326,11 @@ def matmul(M, C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) - T.annotate_layout({ - C_shared: tilelang.layout.make_swizzled_layout(C_shared), - }) + T.annotate_layout( + { + C_shared: tilelang.layout.make_swizzled_layout(C_shared), + } + ) T.clear(C_local) for k in T.Pipelined(K // block_K, num_stages=num_stages): @@ -344,7 +345,7 @@ def matmul(M, T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N]) + T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) return main @@ -409,8 +410,7 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False): """ total_flops = 2 * m * n * k if tune: - kernel = matmul( - m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, fast_dequant=fast_dequant) + kernel = matmul(m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, fast_dequant=fast_dequant) else: kernel = matmul( m, @@ -426,7 +426,8 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False): block_K=128, num_stages=2, threads=256, - split=1) + split=1, + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) if fast_dequant: profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01) diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py index ac1417aebc87c3842eec2600d04a4790677ed352..1091306c60cd3b4def0bc983812e04bfc10a3e70 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py @@ -7,29 +7,28 @@ import torch from dequantize_utils import torch_convert_bit_twiddling, torch_convert -def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, - dtype: str): +def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): """ - Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. + Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. - This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its - bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by - `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. + This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its + bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by + `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. - Parameters: - nbit (int): Number of bits in the packed field (must be 4). - val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. - pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). - scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). - dtype (str): Destination dtype string (must be "bfloat16"). + Parameters: + nbit (int): Number of bits in the packed field (must be 4). 
+ val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. + pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). + scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). + dtype (str): Destination dtype string (must be "bfloat16"). - Returns: - tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. + Returns: + tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. - Notes: - - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". - - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. - """ + Notes: + - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". + - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. + """ assert nbit == 4 assert dtype == "bfloat16" assert val.dtype == "uint8" @@ -43,9 +42,10 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale # To handle the overflow, we may use the min function to limit the exponential part to 8 bits # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) m_f4 = f4 & tir.const(1, "uint16") - val_bf16 = tir.reinterpret("bfloat16", - ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) - | (m_f4 << tir.const(6, "uint16"))).astype("uint16")) + val_bf16 = tir.reinterpret( + "bfloat16", + ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"), + ) return val_bf16 @@ -65,6 +65,7 @@ def get_configs(): List[dict]: A list of configuration dictionaries covering all combinations. """ import itertools + iter_params = dict( block_M=[64, 128, 256], block_N=[64, 128, 256], @@ -73,67 +74,71 @@ def get_configs(): threads=[128, 256, 512], split=[1, 2], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] - - -@tilelang.autotune(configs=get_configs(),) -@tilelang.jit(out_idx=[-1],) -def matmul(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - source_format='uint', - num_bits=4, - scale_size=32, - fast_dequant=True, - with_bias=False, - block_M=256, - block_N=128, - block_K=128, - num_stages=2, - threads=256, - split=1): + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit( + out_idx=[-1], +) +def matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + source_format="uint", + num_bits=4, + scale_size=32, + fast_dequant=True, + with_bias=False, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, +): """ - Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. - - The generated kernel accepts: - - A: dense matrix with element type `in_dtype`. - - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). - - Scale: per-block scale/exponent information used to dequantize B. - The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: - - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. 
- - fast_dequant (False): uses a simple elementwise dequantization helper. - - Parameters: - M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). - in_dtype (str): element type of A (e.g., "fp4" in this file). - out_dtype (str): output tensor element type (e.g., "bfloat16"). - accum_dtype (str): accumulation type used for the inner GEMM. - source_format (str, optional): format string passed to intrinsic selector (default "uint"). - num_bits (int, optional): number of bits per quantized element in B (default 4). - scale_size (int, optional): number of elements grouped per scale entry (default 32). - fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). - block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). - num_stages (int, optional): pipelining stages for K loop (default 2). - threads (int, optional): threads per block used by the kernel (default 256). - split (int, optional): split factor along K used by the scheduler (default 1). - with_bias (bool, optional): whether to add Bias to the output (default False). - - Returns: - A T.prim_func implementing the tiled, pipelined GEMM that: - - loads tiled blocks of A and packed B to shared memory, - - dequantizes B via the chosen path into a shared dequantized tile, - - performs a tiled GEMM accumulating into local fragments, - - writes the final MxN block to the global output tensor. + Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. - Notes: - - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. - - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. - - An assertion enforces that K % (block_K * split) == 0. + The generated kernel accepts: + - A: dense matrix with element type `in_dtype`. + - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). + - Scale: per-block scale/exponent information used to dequantize B. + The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: + - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. + - fast_dequant (False): uses a simple elementwise dequantization helper. + + Parameters: + M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). + in_dtype (str): element type of A (e.g., "fp4" in this file). + out_dtype (str): output tensor element type (e.g., "bfloat16"). + accum_dtype (str): accumulation type used for the inner GEMM. + source_format (str, optional): format string passed to intrinsic selector (default "uint"). + num_bits (int, optional): number of bits per quantized element in B (default 4). + scale_size (int, optional): number of elements grouped per scale entry (default 32). + fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). + block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). + num_stages (int, optional): pipelining stages for K loop (default 2). + threads (int, optional): threads per block used by the kernel (default 256). 
+ split (int, optional): split factor along K used by the scheduler (default 1). + with_bias (bool, optional): whether to add Bias to the output (default False). + + Returns: + A T.prim_func implementing the tiled, pipelined GEMM that: + - loads tiled blocks of A and packed B to shared memory, + - dequantizes B via the chosen path into a shared dequantized tile, + - performs a tiled GEMM accumulating into local fragments, + - writes the final MxN block to the global output tensor. + + Notes: + - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. + - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. + - An assertion enforces that K % (block_K * split) == 0. """ num_elems_per_byte = 8 // num_bits storage_dtype = "uint8" @@ -150,6 +155,7 @@ def matmul(M, assert K % (block_K * split) == 0 from tilelang.quantize import get_mxfp_intrin_group + # fast_dequant_bf16_fp4_twiddling mxfp_intrin_info = get_mxfp_intrin_group( out_dtype=in_dtype, @@ -252,8 +258,7 @@ def matmul(M, for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v - B_dequantize_shared[index // block_K, - index % block_K] = B_dequantize_local_thread[v] + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] return fast_dequant_bf16_fp4_twiddling @@ -301,33 +306,32 @@ def matmul(M, B_local[i, j // num_elems_per_byte], j % num_elems_per_byte, Scale[ - bx * block_N + i, k * block_K // scale_size + j // - scale_size], # Scale is the exponential part, within the representation of uint8 + bx * block_N + i, k * block_K // scale_size + j // scale_size + ], # Scale is the exponential part, within the representation of uint8 dtype=out_dtype, - ) * T.shift_left( - 1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size])) + ) * T.shift_left(1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size])) T.copy(B_dequantize_local, B_dequantize_shared) return simple_dequant_bf16_fp4 @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Scale: T.Tensor(Scale_shape, storage_dtype), - Bias: T.Tensor(Bias_shape, out_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Scale: T.Tensor(Scale_shape, storage_dtype), + Bias: T.Tensor(Bias_shape, out_dtype), + C: T.Tensor((M, N), out_dtype), ): """ - Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. + Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. - This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. 
+ This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. - Parameters are self-descriptive in the signature; notable behaviors: - - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. - - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. - - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). - - The function writes results in-place into C. + Parameters are self-descriptive in the signature; notable behaviors: + - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. + - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. + - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). + - The function writes results in-place into C. """ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -337,23 +341,26 @@ def matmul(M, C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - C_shared: tilelang.layout.make_swizzled_layout(C_shared), - }) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + C_shared: tilelang.layout.make_swizzled_layout(C_shared), + } + ) if with_bias: - T.annotate_layout({ - Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), - }) + T.annotate_layout( + { + Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), + } + ) if threads == 512: T.disable_warp_group_reg_alloc() if with_bias: - T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N], - Bias_shared) + T.copy(Bias[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], Bias_shared) T.copy(Bias_shared, C_local) else: T.clear(C_local) @@ -368,7 +375,7 @@ def matmul(M, T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N]) + T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) return main @@ -389,7 +396,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None): """ dtypeC = "bfloat16" B = torch_convert_bit_twiddling(qB) - B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) return C @@ -412,7 +419,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias): """ dtypeC = "bfloat16" B = torch_convert_bit_twiddling(qB) - B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) 
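In `ref_program_twiddling` above, every group of 32 consecutive columns of the dequantized B shares one exponent from `Scale`, applied as a power of two; the fancy indexing is just a broadcast. A small self-contained sketch of that equivalence, with made-up shapes and `scale_size=32` as in these examples:

import torch

N, K, scale_size = 4, 128, 32
B = torch.randn(N, K, dtype=torch.float32)
Scale = torch.randint(0, 4, (N, K // scale_size))

# Column j of row i is scaled by 2 ** Scale[i, j // scale_size].
B_vec = B * 2.0 ** Scale[:, torch.arange(K) // scale_size]

# The same scaling written as an explicit loop, for comparison.
B_loop = B.clone()
for i in range(N):
    for j in range(K):
        B_loop[i, j] *= 2.0 ** Scale[i, j // scale_size].item()

assert torch.equal(B_vec, B_loop)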
+ B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = C.to(torch.__getattribute__(dtypeC)) return C @@ -436,7 +443,7 @@ def ref_program_simple(A, qB, Scale, Bias=None): """ dtypeC = "bfloat16" B = torch_convert(qB) - B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) return C @@ -464,7 +471,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias): """ dtypeC = "bfloat16" B = torch_convert(qB) - B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = C.to(torch.__getattribute__(dtypeC)) return C @@ -491,16 +498,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, if tune: kernel = matmul( - m, - n, - k, - "bfloat16", - "bfloat16", - "float32", - num_bits=4, - scale_size=scale_size, - fast_dequant=fast_dequant, - with_bias=with_bias) + m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias + ) else: kernel = matmul( m, @@ -518,7 +517,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, threads=256, split=1, fast_dequant=fast_dequant, - with_bias=with_bias) + with_bias=with_bias, + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py index 7dad795971f7a9707bd4a39c339b98d1ae0a15ac..12395df0ac9ecd9169ab0b12d5c3e0b2461334bf 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py @@ -7,29 +7,28 @@ import torch from dequantize_utils import torch_convert_bit_twiddling, torch_convert -def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, - dtype: str): +def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): """ - Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. + Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. - This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its - bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by - `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. + This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its + bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by + `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. - Parameters: - nbit (int): Number of bits in the packed field (must be 4). - val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. - pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). - scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). - dtype (str): Destination dtype string (must be "bfloat16"). 
+ Parameters: + nbit (int): Number of bits in the packed field (must be 4). + val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. + pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). + scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). + dtype (str): Destination dtype string (must be "bfloat16"). - Returns: - tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. + Returns: + tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. - Notes: - - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". - - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. - """ + Notes: + - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". + - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. + """ assert nbit == 4 assert dtype == "bfloat16" assert val.dtype == "uint8" @@ -43,9 +42,10 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale # To handle the overflow, we may use the min function to limit the exponential part to 8 bits # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) m_f4 = f4 & tir.const(1, "uint16") - val_bf16 = tir.reinterpret("bfloat16", - ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) - | (m_f4 << tir.const(6, "uint16"))).astype("uint16")) + val_bf16 = tir.reinterpret( + "bfloat16", + ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"), + ) return val_bf16 @@ -65,6 +65,7 @@ def get_configs(): List[dict]: A list of configuration dictionaries covering all combinations. """ import itertools + iter_params = dict( block_M=[64, 128, 256], block_N=[64, 128, 256], @@ -73,67 +74,71 @@ def get_configs(): threads=[128, 256, 512], split=[1, 2], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] - - -@tilelang.autotune(configs=get_configs(),) -@tilelang.jit(out_idx=[-1],) -def matmul(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - source_format='uint', - num_bits=4, - scale_size=32, - fast_dequant=True, - with_bias=False, - block_M=256, - block_N=128, - block_K=128, - num_stages=2, - threads=256, - split=1): + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit( + out_idx=[-1], +) +def matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + source_format="uint", + num_bits=4, + scale_size=32, + fast_dequant=True, + with_bias=False, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, +): """ - Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. - - The generated kernel accepts: - - A: dense matrix with element type `in_dtype`. - - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). - - Scale: per-block scale/exponent information used to dequantize B. 
- The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: - - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. - - fast_dequant (False): uses a simple elementwise dequantization helper. - - Parameters: - M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). - in_dtype (str): element type of A (e.g., "fp4" in this file). - out_dtype (str): output tensor element type (e.g., "bfloat16"). - accum_dtype (str): accumulation type used for the inner GEMM. - source_format (str, optional): format string passed to intrinsic selector (default "uint"). - num_bits (int, optional): number of bits per quantized element in B (default 4). - scale_size (int, optional): number of elements grouped per scale entry (default 32). - fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). - block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). - num_stages (int, optional): pipelining stages for K loop (default 2). - threads (int, optional): threads per block used by the kernel (default 256). - split (int, optional): split factor along K used by the scheduler (default 1). - with_bias (bool, optional): whether to add Bias to the output (default False). - - Returns: - A T.prim_func implementing the tiled, pipelined GEMM that: - - loads tiled blocks of A and packed B to shared memory, - - dequantizes B via the chosen path into a shared dequantized tile, - - performs a tiled GEMM accumulating into local fragments, - - writes the final MxN block to the global output tensor. + Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. - Notes: - - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. - - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. - - An assertion enforces that K % (block_K * split) == 0. + The generated kernel accepts: + - A: dense matrix with element type `in_dtype`. + - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). + - Scale: per-block scale/exponent information used to dequantize B. + The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: + - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. + - fast_dequant (False): uses a simple elementwise dequantization helper. + + Parameters: + M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). + in_dtype (str): element type of A (e.g., "fp4" in this file). + out_dtype (str): output tensor element type (e.g., "bfloat16"). + accum_dtype (str): accumulation type used for the inner GEMM. + source_format (str, optional): format string passed to intrinsic selector (default "uint"). + num_bits (int, optional): number of bits per quantized element in B (default 4). + scale_size (int, optional): number of elements grouped per scale entry (default 32). + fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). 
+ block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). + num_stages (int, optional): pipelining stages for K loop (default 2). + threads (int, optional): threads per block used by the kernel (default 256). + split (int, optional): split factor along K used by the scheduler (default 1). + with_bias (bool, optional): whether to add Bias to the output (default False). + + Returns: + A T.prim_func implementing the tiled, pipelined GEMM that: + - loads tiled blocks of A and packed B to shared memory, + - dequantizes B via the chosen path into a shared dequantized tile, + - performs a tiled GEMM accumulating into local fragments, + - writes the final MxN block to the global output tensor. + + Notes: + - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. + - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. + - An assertion enforces that K % (block_K * split) == 0. """ num_elems_per_byte = 8 // num_bits storage_dtype = "uint8" @@ -150,6 +155,7 @@ def matmul(M, assert K % (block_K * split) == 0 from tilelang.quantize import get_mxfp_intrin_group + # fast_dequant_bf16_fp4_twiddling mxfp_intrin_info = get_mxfp_intrin_group( out_dtype=in_dtype, @@ -252,8 +258,7 @@ def matmul(M, for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v - B_dequantize_shared[index // block_K, - index % block_K] = B_dequantize_local_thread[v] + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] return fast_dequant_bf16_fp4_twiddling @@ -301,8 +306,8 @@ def matmul(M, B_local[i, j // num_elems_per_byte], j % num_elems_per_byte, Scale_shared[ - i, k * block_K // scale_size + j // - scale_size], # Scale is the exponential part, within the representation of uint8 + i, k * block_K // scale_size + j // scale_size + ], # Scale is the exponential part, within the representation of uint8 dtype=out_dtype, ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size])) T.copy(B_dequantize_local, B_dequantize_shared) @@ -311,22 +316,22 @@ def matmul(M, @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Scale: T.Tensor(Scale_shape, storage_dtype), - Bias: T.Tensor(Bias_shape, out_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Scale: T.Tensor(Scale_shape, storage_dtype), + Bias: T.Tensor(Bias_shape, out_dtype), + C: T.Tensor((M, N), out_dtype), ): """ - Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. + Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. - This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. 
+ This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. - Parameters are self-descriptive in the signature; notable behaviors: - - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. - - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. - - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). - - The function writes results in-place into C. + Parameters are self-descriptive in the signature; notable behaviors: + - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. + - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. + - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). + - The function writes results in-place into C. """ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -339,16 +344,20 @@ def matmul(M, # May use much more shared memory than necessary Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - C_shared: tilelang.layout.make_swizzled_layout(C_shared), - }) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + C_shared: tilelang.layout.make_swizzled_layout(C_shared), + } + ) if with_bias: - T.annotate_layout({ - Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), - }) + T.annotate_layout( + { + Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), + } + ) if threads == 512: T.disable_warp_group_reg_alloc() @@ -357,26 +366,24 @@ def matmul(M, # T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N], # Bias_shared) # T.copy(Bias_shared, C_local) - T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N], - C_local) + T.copy(Bias[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], C_local) else: T.clear(C_local) # Use 1D TMA to load Scale - T.copy(Scale[bx * block_N:(bx + 1) * block_N, :], Scale_shared) + T.copy(Scale[bx * block_N : (bx + 1) * block_N, :], Scale_shared) for k in T.Pipelined(K // block_K, num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) if fast_dequant: - get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, - k) + get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k) else: get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k) T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N]) + 
T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) return main @@ -399,7 +406,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None): B = torch_convert_bit_twiddling(qB) for i in range(B.shape[0]): for j in range(B.shape[1]): - B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) return C @@ -424,7 +431,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias): B = torch_convert_bit_twiddling(qB) for i in range(B.shape[0]): for j in range(B.shape[1]): - B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = C.to(torch.__getattribute__(dtypeC)) return C @@ -450,7 +457,7 @@ def ref_program_simple(A, qB, Scale, Bias=None): B = torch_convert(qB) for i in range(B.shape[0]): for j in range(B.shape[1]): - B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) return C @@ -480,7 +487,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias): B = torch_convert(qB) for i in range(B.shape[0]): for j in range(B.shape[1]): - B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = C.to(torch.__getattribute__(dtypeC)) return C @@ -507,16 +514,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, if tune: kernel = matmul( - m, - n, - k, - "bfloat16", - "bfloat16", - "float32", - num_bits=4, - scale_size=scale_size, - fast_dequant=fast_dequant, - with_bias=with_bias) + m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias + ) else: kernel = matmul( m, @@ -534,7 +533,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, threads=256, split=1, fast_dequant=fast_dequant, - with_bias=with_bias) + with_bias=with_bias, + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) diff --git a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py index 727d6d3b6ff5790d02bd6afe7832fd69aa76f8b3..c2b972a09350a8607876a46c9a20b1f5a54bd76e 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py @@ -24,6 +24,7 @@ def matmul( num_bits=4, ): from tilelang.quantize import _tir_packed_to_unsigned_convert + num_elems_per_byte = 8 // num_bits storage_dtype = "int8" storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) @@ -39,9 +40,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -58,21 +59,19 @@ def matmul( T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) - for i in T.serial(block_N * block_K // num_elems_per_byte // - (threads * local_size_compressed)): + for i in T.serial(block_N * 
block_K // num_elems_per_byte // (threads * local_size_compressed)): for v in T.vectorized(0, local_size_compressed): index = i * threads * local_size_compressed + tx * local_size_compressed + v vi = index // (block_K // num_elems_per_byte) vj = index % (block_K // num_elems_per_byte) B_local[v] = B_shared[vi, vj] for v in T.serial(0, local_size): - B_dequantize_local[v] = _tir_packed_to_unsigned_convert( - storage_type, storage_nbit)( - num_bits, - B_local[v // num_elems_per_byte], - v % num_elems_per_byte, - dtype=in_dtype, - ) + B_dequantize_local[v] = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + num_bits, + B_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v vi = index // block_K @@ -121,9 +120,7 @@ def run_gemm( def ref_program(A, qB): import torch - B = ( - torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, - dtype=torch.half).to(torch.half).to(A.device)) + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) for i in range(B.shape[0]): for j in range(B.shape[1]): B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) @@ -146,9 +143,11 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( ): from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout as make_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitterWithLadderTransform,) + TensorCoreIntrinEmitterWithLadderTransform, + ) from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 + assert in_dtype in [ "float16", "int8", @@ -192,8 +191,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( pad_factor = 8 A_shape = (M, K) - B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, - micro_size_k // num_elems_per_byte) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k // num_elems_per_byte) A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) B_shared_shape = ( block_N // micro_size_y, @@ -228,7 +226,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( chunk=chunk, reduce_k=reduce_k, transform_kind_b=transform_b, - num_elems_per_byte=num_elems_per_byte) + num_elems_per_byte=num_elems_per_byte, + ) vec_load_qb = 16 if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb: @@ -236,14 +235,11 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, - prelude=decode_i4_to_f16) as (bx, by): - + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, prelude=decode_i4_to_f16) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -255,40 +251,36 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( thread_binding = T.get_thread_binding(0) rk = T.get_thread_binding(1) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + } + ) 
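The hunks above only reformat the 4-bit unpacking path (`_tir_packed_to_unsigned_convert`) without changing its arithmetic. As a reading aid, here is a small standalone PyTorch sketch, not part of the patch, of the same two-nibbles-per-int8 unpacking that `ref_program` performs elementwise below; the helper name `unpack_uint4_rows` is invented for illustration.

import torch

def unpack_uint4_rows(qB: torch.Tensor) -> torch.Tensor:
    # qB has shape (N, K // 2) in int8; each byte stores two unsigned 4-bit values.
    # Returns (N, K) float16 with the low nibble at even j and the high nibble at odd j.
    lo = (qB & 0xF).to(torch.half)         # j % 2 == 0
    hi = ((qB >> 4) & 0xF).to(torch.half)  # j % 2 == 1
    return torch.stack((lo, hi), dim=-1).reshape(qB.shape[0], -1)

if __name__ == "__main__":
    qB = torch.randint(0, 127, (4, 8), dtype=torch.int8)
    B = unpack_uint4_rows(qB)
    for i in range(B.shape[0]):
        for j in range(B.shape[1]):
            # elementwise form used by ref_program: shift by 4 * (j % 2), mask with 0xF
            assert B[i, j] == ((qB[i, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)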
T.use_swizzle(panel_size=10) T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, (block_K // reduce_k)): vk = rk * (block_K // reduce_k) + k A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk] # TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load - for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte // - (threads * vec_load_qb)): + for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte // (threads * vec_load_qb)): for v in T.vectorized(0, vec_load_qb): t = thread_binding idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v vkk = idx % (micro_size_k // num_elems_per_byte) vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y - vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % ( - block_K // micro_size_k) - vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // - (block_K // micro_size_k)) % ( - block_N // micro_size_y) - B_shared[vj, vk, vjj, - vkk] = B[bx * (block_N // micro_size_y) + vj, - ko * (block_K // micro_size_k) + vk, vjj, vkk] + vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (block_K // micro_size_k) + vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // (block_K // micro_size_k)) % ( + block_N // micro_size_y + ) + B_shared[vj, vk, vjj, vkk] = B[bx * (block_N // micro_size_y) + vj, ko * (block_K // micro_size_k) + vk, vjj, vkk] for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -307,9 +299,13 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( for j in T.serial(warp_cols): local_size_b = mma_emitter.local_size_b - T.call_extern('handle', 'decode_i4u_to_f16', - T.address_of(B_local[j * local_size_b // num_elems_per_byte]), - T.address_of(B_dequantize_local[j * local_size_b]), 8) + T.call_extern( + "handle", + "decode_i4u_to_f16", + T.address_of(B_local[j * local_size_b // num_elems_per_byte]), + T.address_of(B_dequantize_local[j * local_size_b]), + 8, + ) mma_emitter.mma(A_local, B_dequantize_local, C_local) @@ -328,7 +324,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( reduced_accum_res[0], rk, dtype="handle", - )) + ) + ) if rk == 0: C_local[n] = reduced_accum_res[0] @@ -340,9 +337,9 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( for i, j in T.Parallel(block_M, (block_N // reduce_k)): vj = rk * (block_N // reduce_k) + j - C[by * block_M + i, - bx * block_N + vj] = C_shared[i // micro_size_x, vj // micro_size_y, - i % micro_size_x, vj % micro_size_y] + C[by * block_M + i, bx * block_N + vj] = C_shared[ + i // micro_size_x, vj // micro_size_y, i % micro_size_x, vj % micro_size_y + ] return main @@ -357,8 +354,8 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct transform_b, ): import bitblas - matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( - M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) + + matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) kernel = tilelang.compile(matmul, out_idx=[2]) src_code = kernel.get_kernel_source() @@ -371,8 +368,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct storage_dtype = "int8" A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - qB = 
torch.randint( - 0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype)) + qB = torch.randint(0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) ladder_permutate_config = bitblas.ops.LadderPermutateConfig( @@ -407,9 +403,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct # Ensure that the latency is not None assert latency is not None - B = ( - torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, - dtype=torch.half).to(torch.half).to(A.device)) + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) for i in range(B.shape[0]): for j in range(B.shape[1]): B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) @@ -429,8 +423,7 @@ def test_run_dequantize_gemm(): @tilelang.testing.requires_package("bitblas") def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(): - assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness( - 256, 1024, 512, "float16", "float16", "float16", 3) + assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(256, 1024, 512, "float16", "float16", "float16", 3) def main(): diff --git a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py index c5588d516cffcdd9a29f7cd804d2ab0b5cd79fec..352637de55cb655e4bd548f5c1874d8277567420 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py @@ -21,18 +21,17 @@ def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: e_f16 = e_f4 + tir.const(14, "uint16") m_f4 = f4 & tir.const(1, "uint16") m_f16 = m_f4 - val_f16 = tir.reinterpret("float16", - ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16") - | m_f16 << tir.const(9, "uint16")).astype("uint16")) + val_f16 = tir.reinterpret( + "float16", ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16") | m_f16 << tir.const(9, "uint16")).astype("uint16") + ) # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) return val_f16 def torch_convert(tensor): - def print_bit(name, val): val_cpu = val.cpu().item() - binary_repr = f'{val_cpu:032b}' + binary_repr = f"{val_cpu:032b}" print(name, binary_repr) def _convert(val, pos): @@ -68,8 +67,8 @@ def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): @T.prim_func def main( - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((N, K), in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((N, K), in_dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared(B_shared_shape, storage_dtype) @@ -118,19 +117,11 @@ def get_configs(): splits = [1] _configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits)) - configs = [{ - 'block_M': c[0], - 'block_N': c[1], - 'block_K': c[2], - 'num_stages': c[3], - 'threads': c[4], - 'split': c[5] - } for c in _configs] + configs = [{"block_M": c[0], "block_N": c[1], "block_K": c[2], "num_stages": c[3], "threads": c[4], "split": c[5]} for c in _configs] return configs def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): - @tilelang.jit(out_idx=[2]) def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1): num_elems_per_byte = 8 // num_bits @@ -145,17 
+136,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): @T.prim_func def main_split( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Ct: T.Tensor((N, M), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), ): - SplitC = T.alloc_buffer([ - split, (N + block_N - 1) // block_N * block_N, - (M + block_M - 1) // block_M * block_M - ], out_dtype) - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, - threads=threads) as (bx, by, bz): + SplitC = T.alloc_buffer([split, (N + block_N - 1) // block_N * block_N, (M + block_M - 1) // block_M * block_M], out_dtype) + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, threads=threads) as (bx, by, bz): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype) @@ -164,10 +150,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) - T.annotate_layout({ - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), + } + ) T.clear(Ct_local) for k in T.Pipelined(K // (block_K * split), num_stages=num_stages): @@ -183,8 +171,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): ) T.copy(B_dequantize_local, B_dequantize_prev_local) T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) - T.copy(Ct_local, SplitC[bz, bx * block_N:(bx + 1) * block_N, - by * block_M:(by + 1) * block_M]) + T.copy(Ct_local, SplitC[bz, bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by): acc = T.alloc_fragment((block_N, block_M), out_dtype) T.clear(acc) @@ -195,12 +182,11 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Ct: T.Tensor((N, M), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype) @@ -209,10 +195,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) - T.annotate_layout({ - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), + } + ) T.clear(Ct_local) for k in T.Pipelined(K // block_K, num_stages=num_stages): @@ -229,8 +217,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): 
T.copy(B_dequantize_local, B_dequantize_prev_local) T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) T.copy(Ct_local, Ct_shared) - T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N, - by * block_M:(by + 1) * block_M]) + T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) if split == 1: return main @@ -241,12 +228,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[2]) - def kernel(block_M=None, - block_N=None, - block_K=None, - num_stages=None, - threads=None, - split=None): + def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None, split=None): return kernel_func(block_M, block_N, block_K, num_stages, threads, split).prim_func return kernel() @@ -269,10 +251,10 @@ def ref_program(A, qB): def main(m=256, n=256, k=256, tune=False): total_flops = 2 * m * n * k - if (not tune): - kernel = matmul( - m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)( - block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1) + if not tune: + kernel = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)( + block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1 + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("All checks pass.") @@ -293,10 +275,10 @@ def main(m=256, n=256, k=256, tune=False): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--m', type=int, default=256, help='M') - parser.add_argument('--n', type=int, default=256, help='N') - parser.add_argument('--k', type=int, default=256, help='K') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--m", type=int, default=256, help="M") + parser.add_argument("--n", type=int, default=256, help="N") + parser.add_argument("--k", type=int, default=256, help="K") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() M, N, K = args.m, args.n, args.k main(M, N, K, args.tune) diff --git a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py index 52ee8216f51208b98335416f20c0782b7a5c3f2d..3ff726738375851e57d9d02e667ba2f77135548a 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py +++ b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py @@ -42,8 +42,8 @@ def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): @T.prim_func def main( - B: T.Tensor(B_shape, storage_dtype), - C: T.Tensor((N, K), in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((N, K), in_dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared(B_shared_shape, storage_dtype) @@ -66,13 +66,12 @@ def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): def torch_convert(tensor): - def _convert(val, pos): assert val.dtype == torch.uint8 val = val.view(torch.int8) mask = (1 << 4) - 1 - i4_shifted = ((val >> (pos * 4)) & mask) - i4 = ((i4_shifted << 4) >> 4) + i4_shifted = (val >> (pos * 4)) & mask + i4 = (i4_shifted << 4) >> 4 return i4.view(torch.int8) @@ -94,7 +93,6 @@ def ref_program(A, qB): def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): - @tilelang.jit(out_idx=[2]) def kernel_func(block_M, block_N, block_K, num_stages, 
threads): num_elems_per_byte = 8 // num_bits @@ -109,12 +107,11 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, storage_dtype), - Ct: T.Tensor((N, M), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype) @@ -123,10 +120,12 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) - T.annotate_layout({ - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), + } + ) T.clear(Ct_local) for k in T.Pipelined(K // block_K, num_stages=num_stages): @@ -143,8 +142,7 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune T.copy(B_dequantize_local, B_dequantize_prev_local) T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) T.copy(Ct_local, Ct_shared) - T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N, - by * block_M:(by + 1) * block_M]) + T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) return main @@ -167,10 +165,10 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune def main(m=128, n=256, k=256, tune=False): total_flops = 2 * m * n * k - if (not tune): - kernel = matmul_int8xint4( - m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)( - block_M=32, block_N=32, block_K=128, num_stages=1, threads=128) + if not tune: + kernel = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)( + block_M=32, block_N=32, block_K=128, num_stages=1, threads=128 + ) profiler = kernel.get_profiler() profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2) print("All checks pass.") diff --git a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py index d3e90ec9323052ee631f693e1f8228cc1f8695b7..3f1214670c5ae6f99266983fa83fa25de8117c9e 100644 --- a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py +++ b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py @@ -4,7 +4,8 @@ from typing import Optional, Callable, Any import torch from tilelang import DataType from tilelang.quantize import ( - _tir_packed_int_to_int_convert,) + _tir_packed_int_to_int_convert, +) @tilelang.jit @@ -26,11 +27,10 @@ def dequantize_gemv( group_size: int = -1, with_scaling: bool = False, ) -> Callable[..., Any]: - assert n_partition is not None, "n_partition must be provided" assert reduce_thread is not None, ( - "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV" - "sch_outer_reduction_with_config is not implemented") + "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented" + ) assert trans_A is False, "Dequantize only implement for 
trans_A=False currently" assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" @@ -81,12 +81,12 @@ def dequantize_gemv( C: T.Tensor[C_shape, out_dtype], ): with T.Kernel( - T.ceildiv(N, n_partition), - M, - threads=(reduce_thread, n_partition), + T.ceildiv(N, n_partition), + M, + threads=(reduce_thread, n_partition), ) as ( - bx, - by, + bx, + by, ): A_local = T.alloc_local((micro_size_k,), in_dtype) B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype) @@ -107,8 +107,7 @@ def dequantize_gemv( for v in T.vectorized(micro_size_k_compressed): B_quant_local[v] = B[ bx * n_partition + ni, - ko * (reduce_thread * micro_size_k_compressed) + - kr * micro_size_k_compressed + v, + ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v, ] if fast_decoding: @@ -120,10 +119,9 @@ def dequantize_gemv( ) else: for ki in T.serial(micro_size_k): - B_dequantize_local[ki] = _tir_packed_int_to_int_convert( - storage_type, - storage_nbit)(num_bits, B_quant_local[ki // num_elems_per_byte], - ki % num_elems_per_byte, in_dtype) + B_dequantize_local[ki] = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + num_bits, B_quant_local[ki // num_elems_per_byte], ki % num_elems_per_byte, in_dtype + ) if use_dp4a: for ki in T.serial(micro_size_k // dp4a_size): @@ -137,9 +135,9 @@ def dequantize_gemv( accum_res[0] += A_local[ki] * B_dequantize_local[ki] with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -149,7 +147,8 @@ def dequantize_gemv( reduced_accum_res[0], kr, dtype="handle", - )) + ) + ) if kr == 0: C[by, bx * n_partition + ni] = reduced_accum_res[0] @@ -174,26 +173,39 @@ def main() -> None: group_size = -1 with_scaling = False - kernel = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype, - source_format, n_partition, reduce_thread, fast_decoding, trans_A, - trans_B, group_size, with_scaling) + kernel = dequantize_gemv( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + num_bits, + storage_dtype, + source_format, + n_partition, + reduce_thread, + fast_decoding, + trans_A, + trans_B, + group_size, + with_scaling, + ) storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) num_elems_per_byte = storage_nbit // num_bits A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda() - qB = torch.randint( - 0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda() + qB = torch.randint(0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda() C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda() if fast_decoding: from tilelang.quantize.utils import interleave_weight + qB = interleave_weight(qB, num_bits, in_dtype) kernel(A, qB, C) # int4 reference - B = ( - torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, - dtype=torch.half).to(torch.half).to(A.device)) + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) for j in range(B.shape[1]): B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) diff --git a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py index 
c4cf5fb505ddbb31ec7a655bab3d73897a99eaa4..098f814c27646d01251e692e645996579ce94fbd 100644 --- a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py @@ -25,6 +25,7 @@ def get_configs(): List[dict]: A list of configuration dictionaries covering all combinations. """ import itertools + iter_params = dict( block_M=[128], block_N=[64, 128, 256], @@ -33,33 +34,33 @@ def get_configs(): threads=[128, 256, 512], split=[1], ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] @tilelang.autotune(configs=get_configs()) @tilelang.jit(out_idx=[-1]) -def matmul(M, - N, - K, - topk, - E, - padding_M, - in_dtype, - out_dtype, - accum_dtype, - source_format='uint', - num_bits=4, - scale_size=32, - fast_dequant=True, - with_bias=False, - block_M=128, - block_N=256, - block_K=128, - num_stages=2, - threads=256, - split=1): +def matmul( + M, + N, + K, + topk, + E, + padding_M, + in_dtype, + out_dtype, + accum_dtype, + source_format="uint", + num_bits=4, + scale_size=32, + fast_dequant=True, + with_bias=False, + block_M=128, + block_N=256, + block_K=128, + num_stages=2, + threads=256, + split=1, +): """ Construct and return a grouped (Mixture-of-Experts) matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized, expert-grouped B (shape ExNxQK) and writes an output of shape (M, topk, N) in out_dtype. @@ -115,11 +116,12 @@ def matmul(M, Block_QK = block_K // num_elems_per_byte A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, Block_QK) - Bias_shared_shape = (block_N) + Bias_shared_shape = block_N B_dequantize_shared_shape = (block_N, block_K) assert K % (block_K * split) == 0 from tilelang.quantize import get_mxfp_intrin_group + # fast_dequant_bf16_fp4_twiddling mxfp_intrin_info = get_mxfp_intrin_group( out_dtype=in_dtype, @@ -221,19 +223,16 @@ def matmul(M, for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v - B_dequantize_shared[index // block_K, - index % block_K] = B_dequantize_local_thread[v] + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] return fast_dequant_bf16_fp4_twiddling def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): - assert in_dtype in ["fp4"] assert out_dtype in ["bfloat16"] @T.macro def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): - B_local = T.alloc_fragment(B_shared_shape, storage_dtype) B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype) @@ -244,8 +243,8 @@ def matmul(M, B_local[i, j // num_elems_per_byte], j % num_elems_per_byte, Scale_shared[ - i, k * block_K // scale_size + j // - scale_size], # Scale is the exponential part, within the representation of uint8 + i, k * block_K // scale_size + j // scale_size + ], # Scale is the exponential part, within the representation of uint8 dtype=out_dtype, ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size])) T.copy(B_dequantize_local, B_dequantize_shared) @@ -254,19 +253,17 @@ def matmul(M, @T.prim_func def main( - A: T.Tensor((M, K), in_dtype), - B: T.Tensor((E, N, QK), storage_dtype), - Scale: T.Tensor((E, N, K // scale_size), storage_dtype), - Bias: T.Tensor((E, N), out_dtype), - # Add fusedmoe tensors - topk_weights: T.Tensor((M * topk), out_dtype), - sorted_token_ids: 
T.Tensor((padding_M), "int32"), - expert_ids: T.Tensor((padding_M // block_M), "int32"), - C: T.Tensor((M, topk, N), out_dtype), + A: T.Tensor((M, K), in_dtype), + B: T.Tensor((E, N, QK), storage_dtype), + Scale: T.Tensor((E, N, K // scale_size), storage_dtype), + Bias: T.Tensor((E, N), out_dtype), + # Add fusedmoe tensors + topk_weights: T.Tensor((M * topk), out_dtype), + sorted_token_ids: T.Tensor((padding_M), "int32"), + expert_ids: T.Tensor((padding_M // block_M), "int32"), + C: T.Tensor((M, topk, N), out_dtype), ): - - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) @@ -280,17 +277,19 @@ def matmul(M, # May use much more shared memory than necessary Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - C_shared: tilelang.layout.make_swizzled_layout(C_shared), - }) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + C_shared: tilelang.layout.make_swizzled_layout(C_shared), + } + ) T.use_swizzle(10) if threads == 512: T.disable_warp_group_reg_alloc() - T.copy(sorted_token_ids[by * block_M:(by + 1) * block_M], sorted_token_ids_shared) + T.copy(sorted_token_ids[by * block_M : (by + 1) * block_M], sorted_token_ids_shared) expert_id[0] = expert_ids[by] # Get the topk weights of each token in the current block @@ -300,11 +299,11 @@ def matmul(M, # Get bias and scale based on the expert id if with_bias: - T.copy(Bias[expert_id[0], bx * block_N:(bx + 1) * block_N], Bias_shared) + T.copy(Bias[expert_id[0], bx * block_N : (bx + 1) * block_N], Bias_shared) else: T.clear(Bias_shared) - T.copy(Scale[expert_id[0], bx * block_N:(bx + 1) * block_N, :], Scale_shared) + T.copy(Scale[expert_id[0], bx * block_N : (bx + 1) * block_N, :], Scale_shared) for i, j in T.Parallel(block_M, block_N): C_local[i, j] = Bias_shared[j] @@ -317,14 +316,13 @@ def matmul(M, base = copy_i * threads * 16 + tx * 16 if sorted_token_ids_shared[base // block_K] != -1: for copy_j in T.vectorized(16): - A_shared[base // block_K, base % block_K + - copy_j] = A[sorted_token_ids_shared[base // block_K] // topk, - k * block_K + base % block_K + copy_j] + A_shared[base // block_K, base % block_K + copy_j] = A[ + sorted_token_ids_shared[base // block_K] // topk, k * block_K + base % block_K + copy_j + ] T.copy(B[expert_id[0], bx * block_N, k * block_K // num_elems_per_byte], B_shared) if fast_dequant: - get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, - k) + get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k) else: get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k) @@ -338,10 +336,11 @@ def matmul(M, base = copy_i * threads * 16 + tx * 16 if sorted_token_ids_shared[base // block_N] != -1: for copy_j in T.vectorized(16): - C[sorted_token_ids_shared[base // block_N] // topk, - sorted_token_ids_shared[base // block_N] % topk, bx * block_N + - base % block_N + copy_j] = C_shared[base // block_N, - base % block_N + copy_j] + C[ + sorted_token_ids_shared[base // block_N] // 
topk, + sorted_token_ids_shared[base // block_N] % topk, + bx * block_N + base % block_N + copy_j, + ] = C_shared[base // block_N, base % block_N + copy_j] return main @@ -355,7 +354,7 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc assert scale_size == 32 # MXFP4 # Initialize output tensor - C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device='cuda') + C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device="cuda") # Iterate over sorted_token_ids for idx in range(len(sorted_token_ids)): # padding_M @@ -370,14 +369,11 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc # Dequantize the expert weights B = torch_convert_bit_twiddling(qB[expert_id]) # shape: (N, K) - B *= 2**( - Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to( - torch.bfloat16)) + B *= 2 ** (Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(torch.bfloat16)) # Compute the output for this token-expert pair # token_embedding @ B.T + bias - output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to( - torch.bfloat16)) + Bias[expert_id] + output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(torch.bfloat16)) + Bias[expert_id] output = output.to(torch.__getattribute__(dtypeC)) # Apply the topk weight @@ -391,14 +387,12 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc def get_data(m, n, k, qk, scale_size, topk, E, block_M): - A = torch.empty(m, k, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1) - qB = torch.randint( - 0, 256, (E, n, qk), dtype=torch.uint8, - device='cuda') # Quantized weight tensor for E experts. - Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device='cuda') - Bias = torch.empty(E, n, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1) - - weights = torch.empty(m, E, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1) + A = torch.empty(m, k, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) + qB = torch.randint(0, 256, (E, n, qk), dtype=torch.uint8, device="cuda") # Quantized weight tensor for E experts. + Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device="cuda") + Bias = torch.empty(E, n, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) + + weights = torch.empty(m, E, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) # topk_weights: Router weights for the top-k experts for each token. # Shape: (m, topk) # tokens_experts: A flattened tensor of expert assignments for each token. @@ -420,10 +414,7 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M): pad_len = ((cnt + block_M - 1) // block_M) * block_M - cnt if pad_len > 0: # -1 for padding (`M` instead in vLLM moe_align_block_size()) - group_token_ids = torch.cat([ - group_token_ids, - torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device='cuda') - ]) + group_token_ids = torch.cat([group_token_ids, torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device="cuda")]) padded_token_ids.append(group_token_ids) expert_ids.extend([eid] * ((cnt + block_M - 1) // block_M)) start = end @@ -431,21 +422,13 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M): # sorted_token_ids: The final flattened and padded tensor of token indices. sorted_token_ids = torch.cat(padded_token_ids, dim=0).to(torch.int32) # (padding_M,) # expert_ids: The final tensor of expert IDs corresponding to `sorted_token_ids`. 
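The comments in `get_data` above describe how token assignments are grouped by expert and padded with -1 so that every `block_M`-sized block of `sorted_token_ids` maps to a single entry of `expert_ids`. A simplified, hedged sketch of that alignment step follows (CPU-only; the helper name `align_block_size` is invented for illustration, and the real code additionally works on flattened token*topk positions and carries router weights).

import torch

def align_block_size(tokens_experts: torch.Tensor, E: int, block_M: int):
    padded_token_ids, expert_ids = [], []
    for eid in range(E):
        # positions assigned to this expert
        ids = torch.nonzero(tokens_experts == eid, as_tuple=False).flatten().to(torch.int32)
        cnt = ids.numel()
        if cnt == 0:
            continue
        # pad each expert group with -1 up to a multiple of block_M
        pad_len = ((cnt + block_M - 1) // block_M) * block_M - cnt
        if pad_len > 0:
            ids = torch.cat([ids, torch.full((pad_len,), -1, dtype=ids.dtype)])
        padded_token_ids.append(ids)
        # one expert id per block_M-sized block
        expert_ids.extend([eid] * ((cnt + block_M - 1) // block_M))
    sorted_token_ids = torch.cat(padded_token_ids)
    return sorted_token_ids, torch.tensor(expert_ids, dtype=torch.int32)

if __name__ == "__main__":
    # 6 flattened token-expert assignments, 2 experts, block_M = 4
    tokens_experts = torch.tensor([0, 1, 0, 0, 1, 0])
    sorted_token_ids, expert_ids = align_block_size(tokens_experts, E=2, block_M=4)
    print(sorted_token_ids.tolist())  # [0, 2, 3, 5, 1, 4, -1, -1]
    print(expert_ids.tolist())        # [0, 1]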
- expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda') # (padding_M,) + expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda") # (padding_M,) padding_M = sorted_token_ids.shape[0] # padding_M: token number after padding return A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M -def main(m=256, - n=256, - k=256, - scale_size=32, - topk=4, - E=32, - fast_dequant=True, - with_bias=False, - tune=False): +def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, with_bias=False, tune=False): # Tunable parameters block_M, block_N, block_K = 128, 256, 128 # noqa: F841 num_stages = 1 # noqa: F841 @@ -456,8 +439,7 @@ def main(m=256, num_bits = 4 num_elems_per_byte = 8 // num_bits qk = k // num_elems_per_byte - A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data( - m, n, k, qk, scale_size, topk, E, block_M) + A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(m, n, k, qk, scale_size, topk, E, block_M) if tune: with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]): @@ -510,14 +492,11 @@ def main(m=256, expert_ids, ) - print('Tilelang kernel run finished.') + print("Tilelang kernel run finished.") - ref_output = ref_moe( - A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, - block_M=block_M) # Maybe a little bit slow... + ref_output = ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=block_M) # Maybe a little bit slow... - latency = tilelang.profiler.do_bench( - lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100) + latency = tilelang.profiler.do_bench(lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100) print("Tilelang: {:.2f} ms".format(latency)) print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) @@ -525,32 +504,19 @@ def main(m=256, max_val = diff.max() max_idx = diff.argmax() print(f"max abs diff: {max_val} at index: {max_idx}") - assert_similar( - output, ref_output, name="output", - eps=2e-5) # We care about the similarity rather than abs. difference + assert_similar(output, ref_output, name="output", eps=2e-5) # We care about the similarity rather than abs. difference print("All checks pass. 
✅") if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - "--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm + parser.add_argument("--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm parser.add_argument("--N", type=int, default=5760, help="N") parser.add_argument("--K", type=int, default=2944, help="K") parser.add_argument("--scale_size", type=int, default=32, help="scale size") - parser.add_argument( - "--topk", type=int, default=4, help="topk") # experts activated for each token + parser.add_argument("--topk", type=int, default=4, help="topk") # experts activated for each token parser.add_argument("--E", type=int, default=32, help="E") # number of experts parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - main( - args.M, - args.N, - args.K, - args.scale_size, - topk=args.topk, - E=args.E, - fast_dequant=True, - with_bias=True, - tune=args.tune) + main(args.M, args.N, args.K, args.scale_size, topk=args.topk, E=args.E, fast_dequant=True, with_bias=True, tune=args.tune) diff --git a/examples/dsa_sparse_finetune/dsa.py b/examples/dsa_sparse_finetune/dsa.py index 1ca282411ab8c6319b74a8ced60d4a948c7968d3..9fae8e5e3d698c9d7763b707fa2b2fd7506257c2 100644 --- a/examples/dsa_sparse_finetune/dsa.py +++ b/examples/dsa_sparse_finetune/dsa.py @@ -11,7 +11,6 @@ from utils import get_abs_err, get_err_ratio class RegsiterLossFunction(torch.autograd.Function): - @staticmethod def forward(ctx, x, loss): ctx.save_for_backward(loss) @@ -38,49 +37,43 @@ def ref_deepseek_sparse_attention_innner( index_sm_scale: Optional[float] = None, ): dtype = q.dtype - q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32), - (q, kv, index_q, index_k, weights)) + q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32), (q, kv, index_q, index_k, weights)) - index_sm_scale = index_q.shape[-1]**-0.5 + index_sm_scale = index_q.shape[-1] ** -0.5 b, s = index_q.shape[:2] # tl_topk_indices = tl_topk_indices.to(torch.int64) # tl_topk_indices[tl_topk_indices == -1] = s casual_mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) - index_logits = einsum(index_q, index_k, 'b s1 h k, b s2 k -> b s1 h s2') + index_logits = einsum(index_q, index_k, "b s1 h k, b s2 k -> b s1 h s2") index_logits = F.relu(index_logits) - index_logits = (index_logits * weights.unsqueeze(-1)).sum( - dim=-2, dtype=torch.float32) * index_sm_scale - index_logits = torch.where(casual_mask, index_logits, float('-inf')) + index_logits = (index_logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * index_sm_scale + index_logits = torch.where(casual_mask, index_logits, float("-inf")) topk_indices = torch.topk(index_logits, k=topk, dim=-1).indices - topk_logits = torch.gather( - F.pad(index_logits, (0, 1), value=float('-inf')), dim=-1, index=topk_indices) + topk_logits = torch.gather(F.pad(index_logits, (0, 1), value=float("-inf")), dim=-1, index=topk_indices) topk_score = F.log_softmax(topk_logits, dim=-1, dtype=torch.float32) index_topk_score = topk_score if sm_scale is None: - sm_scale = kv.shape[-1]**-0.5 + sm_scale = kv.shape[-1] ** -0.5 h = q.shape[-2] - index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device="cuda")\ - .scatter_(dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool))[:, :, :-1] - mask = repeat(casual_mask & index_mask, 'b s1 s2 -> b s1 h s2', h=h) + index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, 
device="cuda").scatter_( + dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool) + )[:, :, :-1] + mask = repeat(casual_mask & index_mask, "b s1 s2 -> b s1 h s2", h=h) k, v = kv, kv[..., :dim_v] - logits = einsum(q, k, 'b s1 h d, b s2 d -> b s1 h s2') * sm_scale - logits = torch.where(mask, logits, float('-inf')) + logits = einsum(q, k, "b s1 h d, b s2 d -> b s1 h s2") * sm_scale + logits = torch.where(mask, logits, float("-inf")) attn_score = F.softmax(logits, dim=-1, dtype=torch.float32) - o = einsum(attn_score, v, 'b s1 h s2, b s2 d -> b s1 h d') + o = einsum(attn_score, v, "b s1 h s2, b s2 d -> b s1 h d") attn_score = attn_score.sum(dim=-2) # [b, s1, s2] attn_topk_score = torch.gather(F.pad(attn_score, (0, 1)), dim=-1, index=topk_indices) attn_topk_score = attn_topk_score / attn_topk_score.sum(dim=-1, keepdim=True) - loss = F.kl_div( - index_topk_score.clip(-100, 0), - attn_topk_score.detach().log().clip(-100, 0), - log_target=True, - reduction="sum") + loss = F.kl_div(index_topk_score.clip(-100, 0), attn_topk_score.detach().log().clip(-100, 0), log_target=True, reduction="sum") o = register_loss(o, loss) return o.to(dtype), topk_indices @@ -101,11 +94,11 @@ def ref_deepseek_sparse_attention( all_o, all_topk_indices = [], [] for i in range(offsets.shape[0] - 1): o, topk_indices = ref_deepseek_sparse_attention_innner( - q[None, offsets[i]:offsets[i + 1]], - kv[None, offsets[i]:offsets[i + 1]], - index_q[None, offsets[i]:offsets[i + 1]], - index_k[None, offsets[i]:offsets[i + 1]], - weights[None, offsets[i]:offsets[i + 1]], + q[None, offsets[i] : offsets[i + 1]], + kv[None, offsets[i] : offsets[i + 1]], + index_q[None, offsets[i] : offsets[i + 1]], + index_k[None, offsets[i] : offsets[i + 1]], + weights[None, offsets[i] : offsets[i + 1]], topk, dim_v, sm_scale, @@ -119,7 +112,6 @@ def ref_deepseek_sparse_attention( class DSAFunction(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -134,12 +126,9 @@ class DSAFunction(torch.autograd.Function): sm_scale: Optional[float] = None, ): # topk_indices, index_score = ref_index_score(index_q, weights, index_k, topk) - topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k, - topk, offsets) - o, lse = sparse_mla_fwd_interface( - q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v) - ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, - offsets) + topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k, topk, offsets) + o, lse = sparse_mla_fwd_interface(q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v) + ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets) ctx.topk = topk ctx.dim_v = dim_v ctx.sm_scale = sm_scale @@ -153,19 +142,10 @@ class DSAFunction(torch.autograd.Function): ): q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets = ctx.saved_tensors attn_score = sparse_mla_topk_reducesum_interface( - q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets, - dim_v=ctx.dim_v).squeeze(-2) - dq, dkv = sparse_mla_bwd( - q, - kv.unsqueeze(-2), - o, - do, - topk_indices.unsqueeze(-2), - lse, - offsets, - sm_scale=ctx.sm_scale) - dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score, - index_score, topk_indices, offsets) + q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets, dim_v=ctx.dim_v + ).squeeze(-2) + dq, dkv = 
sparse_mla_bwd(q, kv.unsqueeze(-2), o, do, topk_indices.unsqueeze(-2), lse, offsets, sm_scale=ctx.sm_scale) + dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score, index_score, topk_indices, offsets) return dq, dkv.squeeze(-2), dindex_q, dindex_k, dweights, None, None, None, None @@ -209,8 +189,7 @@ def test_kernel( index_k_grad, index_k.grad = index_k.grad, None weights_grad, weights.grad = weights.grad, None - ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights, - offsets, topk, D) + ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D) ref_o.backward(do) ref_q_grad, q.grad = q.grad, None ref_kv_grad, kv.grad = kv.grad, None @@ -219,28 +198,20 @@ def test_kernel( ref_weights_grad, weights.grad = weights.grad, None print(f"o err: {get_abs_err(o, ref_o):.6f} ratio: {get_err_ratio(o, ref_o):.6f}") - print( - f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}" - ) - print( - f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}" - ) + print(f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}") + print(f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}") print( f"index_q.grad err: {get_abs_err(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f} ratio: {get_err_ratio(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f}" ) - print( - f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}" - ) - print( - f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}" - ) + print(f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}") + print(f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}") intersections = [] for j in range(S): ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy() trt_np = topk_indices[j].cpu().to(torch.int32).numpy() - mask = (trt_np != -1) + mask = trt_np != -1 set_ref = set(ref_np[mask]) set_trt = set(trt_np[mask]) diff --git a/examples/dsa_sparse_finetune/index.py b/examples/dsa_sparse_finetune/index.py index 92ce687f97d569ce0d9c522d3093399eaaa55234..5e4800411004e5890faba0578cf83f09e27f2dc9 100644 --- a/examples/dsa_sparse_finetune/index.py +++ b/examples/dsa_sparse_finetune/index.py @@ -5,7 +5,9 @@ import functools from typing import Callable, Any -def tensor_cache(fn: Callable[..., torch.Tensor],) -> Callable[..., torch.Tensor]: +def tensor_cache( + fn: Callable[..., torch.Tensor], +) -> Callable[..., torch.Tensor]: """ A decorator that caches the most recent result of a function with tensor inputs. 
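The `tensor_cache` docstring above describes a single-entry cache keyed by argument identity, and the reformatted wrapper in the next hunk spells out the check. For reference, here is a self-contained sketch of the same idea (plain PyTorch, no tilelang; the usage with `prepare_lens_from_cu_seqlens` mirrors the decorated helpers in this file).

import functools
from typing import Any, Callable
import torch

def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    # Cache only the most recent call; reuse it when every positional and keyword
    # argument is the very same object as last time (identity, not equality).
    last_args, last_kwargs, last_result = None, None, None

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal last_args, last_kwargs, last_result
        if (
            last_args is not None
            and len(args) == len(last_args)
            and len(kwargs) == len(last_kwargs)
            and all(a is b for a, b in zip(args, last_args))
            and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items())
        ):
            return last_result
        last_args, last_kwargs, last_result = args, kwargs, fn(*args, **kwargs)
        return last_result

    return wrapper

@tensor_cache
def prepare_lens_from_cu_seqlens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    return torch.diff(cu_seqlens)

if __name__ == "__main__":
    cu = torch.tensor([0, 3, 7, 12])
    a = prepare_lens_from_cu_seqlens(cu)
    b = prepare_lens_from_cu_seqlens(cu)  # same tensor object -> cached result is returned
    assert a is b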
@@ -29,10 +31,12 @@ def tensor_cache(fn: Callable[..., torch.Tensor],) -> Callable[..., torch.Tensor def wrapper(*args: Any, **kwargs: Any) -> Any: nonlocal last_args, last_kwargs, last_result - if (last_args is not None and last_kwargs is not None) and \ - (len(args) == len(last_args) and len(kwargs) == len(last_kwargs)) and \ - all(a is b for a, b in zip(args, last_args, strict=False)) and \ - all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()): + if ( + (last_args is not None and last_kwargs is not None) + and (len(args) == len(last_args) and len(kwargs) == len(last_kwargs)) + and all(a is b for a, b in zip(args, last_args, strict=False)) + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) + ): return last_result result = fn(*args, **kwargs) @@ -56,16 +60,15 @@ def prepare_cu_seqlens_from_lens( @tensor_cache -def prepare_lens_from_cu_seqlens(cu_seqlens: torch.LongTensor,) -> torch.LongTensor: +def prepare_lens_from_cu_seqlens( + cu_seqlens: torch.LongTensor, +) -> torch.LongTensor: return torch.diff(cu_seqlens) @tensor_cache def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: - return torch.cat([ - torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) - for n in prepare_lens(cu_seqlens).unbind() - ]) + return torch.cat([torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) for n in prepare_lens(cu_seqlens).unbind()]) @tensor_cache diff --git a/examples/dsa_sparse_finetune/indexer_bwd.py b/examples/dsa_sparse_finetune/indexer_bwd.py index 5430c1c0039d52a747421613c1d04fc302cebd34..5d8132d9b83214db69ccfe2014997d9f058e56cb 100644 --- a/examples/dsa_sparse_finetune/indexer_bwd.py +++ b/examples/dsa_sparse_finetune/indexer_bwd.py @@ -49,17 +49,17 @@ def tl_indexer_bwd_impl( @T.prim_func def tl_indexer_bwd_kernel( - IndexQ: T.Tensor(index_q_shape, dtype), - Weights: T.Tensor(weights_shape, dtype), - IndexK: T.Tensor(index_k_shape, dtype), - dIndexQ: T.Tensor(index_q_shape, dtype), - dWeights: T.Tensor(weights_shape, dtype), - dIndexK: T.Tensor(index_k_shape, dtype), - AttnScore: T.Tensor(shape_p, FP32), - IndexScore: T.Tensor(shape_p, FP32), - TopkIndices: T.Tensor(topk_indices_shape, INT32), - Offsets: T.Tensor(offsets_shape, INT32), - TokenIndices: T.Tensor(token_indices_shape, INT32), + IndexQ: T.Tensor(index_q_shape, dtype), + Weights: T.Tensor(weights_shape, dtype), + IndexK: T.Tensor(index_k_shape, dtype), + dIndexQ: T.Tensor(index_q_shape, dtype), + dWeights: T.Tensor(weights_shape, dtype), + dIndexK: T.Tensor(index_k_shape, dtype), + AttnScore: T.Tensor(shape_p, FP32), + IndexScore: T.Tensor(shape_p, FP32), + TopkIndices: T.Tensor(topk_indices_shape, INT32), + Offsets: T.Tensor(offsets_shape, INT32), + TokenIndices: T.Tensor(token_indices_shape, INT32), ): with T.Kernel(seq_len, threads=num_threads) as (bx): i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1] @@ -81,7 +81,6 @@ def tl_indexer_bwd_impl( index_q_shared[i, j] = index_q_shared[i, j] * sm_scale for bi_i in T.Pipelined(num_blocks, num_stages=num_stages): - i_st = bi_i * block_I i_ed = (bi_i + 1) * block_I @@ -91,8 +90,7 @@ def tl_indexer_bwd_impl( index_k_shared = T.alloc_shared([block_I, dim], dtype=dtype) for i, j in T.Parallel(block_I, dim): pos = indices_shared[i] - index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t), - IndexK[bos + pos, j], 0) + index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t), IndexK[bos + pos, j], 0) attn_score_shared = T.alloc_shared([block_I], dtype=accum_dtype) 
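The per-element gradients just below (`d_weights_i`, `d_logits_qk`) are all scaled by `(index_score - attn_score)`, which is the standard softmax/KL identity that the reference `ref_indexer_bwd` recovers through autograd. A quick numerical check of that identity, as a sketch independent of the kernel and assuming the target distribution sums to 1:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
z = torch.randn(8, requires_grad=True)   # top-k indexer logits for one token
p = F.softmax(torch.randn(8), dim=-1)    # attention-derived target scores (sums to 1)

# For q = softmax(z) and fixed target p, d/dz of KL(p || q) with reduction="sum" is q - p.
log_q = F.log_softmax(z, dim=-1)
loss = F.kl_div(log_q, p.log(), log_target=True, reduction="sum")
loss.backward()

q = F.softmax(z.detach(), dim=-1)
assert torch.allclose(z.grad, q - p, atol=1e-6)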
index_score_shared = T.alloc_shared([block_I], dtype=accum_dtype) @@ -115,8 +113,7 @@ def tl_indexer_bwd_impl( # dw d_weights_i = T.alloc_fragment((block_I, heads), accum_dtype) for i, j in T.Parallel(block_I, heads): - d_weights_i[i, - j] = (index_score_shared[i] - attn_score_shared[i]) * logits[i, j] + d_weights_i[i, j] = (index_score_shared[i] - attn_score_shared[i]) * logits[i, j] T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False) d_logits_qk = T.alloc_shared((block_I, heads), accum_dtype) @@ -129,8 +126,7 @@ def tl_indexer_bwd_impl( d_relu = 1.0 else: d_relu = 0.0 - d_logits_qk[i, j] = (index_score_shared[i] - - attn_score_shared[i]) * d_relu * weights_shared[j] + d_logits_qk[i, j] = (index_score_shared[i] - attn_score_shared[i]) * d_relu * weights_shared[j] # dq T.copy(d_logits_qk, d_logits_qk_cast1) @@ -157,7 +153,7 @@ def tl_indexer_bwd_impl( for i, j in T.Parallel(block_I, dim): pos = indices_shared[i] - if ((pos > -1) & (pos <= i_t)): + if (pos > -1) & (pos <= i_t): T.atomic_add(dIndexK[bos + pos, j], d_index_k_frag[i, j]) for i, j in T.Parallel(heads, dim): @@ -184,40 +180,35 @@ def indexer_bwd_interface( dweights = torch.zeros_like(weights) dk = torch.zeros_like(k) kernel = tl_indexer_bwd_impl(heads, dim, topk) - kernel(q, weights, k, dq, dweights, dk, attn_score, index_score, topk_indices, offsets, - token_indices) + kernel(q, weights, k, dq, dweights, dk, attn_score, index_score, topk_indices, offsets, token_indices) return dq, dweights, dk -def ref_indexer_bwd(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, - TopkIndices: torch.Tensor, AttnScore: torch.Tensor, - offsets: torch.Tensor) -> torch.Tensor: +def ref_indexer_bwd( + Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, AttnScore: torch.Tensor, offsets: torch.Tensor +) -> torch.Tensor: Q.requires_grad_(True) Weights.requires_grad_(True) K.requires_grad_(True) - softmax_scale = Q.shape[-1]**-0.5 + softmax_scale = Q.shape[-1] ** -0.5 all_loss = [] all_log_topk_prob = [] for i in range(offsets.shape[0] - 1): assert (offsets[i + 1] - offsets[i]).item() >= TopkIndices.shape[-1] - q = Q[offsets[i]:offsets[i + 1]] - weights = Weights[offsets[i]:offsets[i + 1]] - k = K[offsets[i]:offsets[i + 1]] - topk_indices = TopkIndices[offsets[i]:offsets[i + 1]] - attn_score = AttnScore[offsets[i]:offsets[i + 1]] + q = Q[offsets[i] : offsets[i + 1]] + weights = Weights[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + topk_indices = TopkIndices[offsets[i] : offsets[i + 1]] + attn_score = AttnScore[offsets[i] : offsets[i + 1]] s = q.shape[0] mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) - logits = einsum(q, k, 's1 h k, s2 k -> s1 h s2') * softmax_scale + logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2") * softmax_scale logits = F.relu(logits) score = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) - score = torch.where(mask, score, float('-inf')) + score = torch.where(mask, score, float("-inf")) topk_value = torch.gather(score, dim=-1, index=topk_indices.to(torch.int64)) log_topk_prob = F.log_softmax(topk_value, dim=-1, dtype=torch.float32) - loss = F.kl_div( - log_topk_prob.clip(-100, 0), - attn_score.log().clip(-100, 0), - log_target=True, - reduction="sum") + loss = F.kl_div(log_topk_prob.clip(-100, 0), attn_score.log().clip(-100, 0), log_target=True, reduction="sum") all_loss.append(loss) all_log_topk_prob.append(log_topk_prob) loss = torch.stack(all_loss).sum() @@ -244,15 +235,13 @@ def test_kernel( seq_len = (offsets[i 
+ 1] - offsets[i]).item() mask = (torch.arange(seq_len)[:, None] >= torch.arange(topk)[None, :]).to(q.device) logits = torch.ones(seq_len, topk).cuda() - logits = torch.where(mask, logits, float('-inf')) + logits = torch.where(mask, logits, float("-inf")) attn_score = F.softmax(logits, dim=-1, dtype=torch.float32) all_attn_score.append(attn_score) attn_score = torch.cat(all_attn_score, dim=0) - topk_indices = repeat( - torch.arange(topk, dtype=torch.int32).cuda(), 'k -> s k', s=S).contiguous() - index_score, ref_dq, ref_dw, ref_dk = ref_indexer_bwd(q, w, k, topk_indices, attn_score, - offsets) + topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous() + index_score, ref_dq, ref_dw, ref_dk = ref_indexer_bwd(q, w, k, topk_indices, attn_score, offsets) dq, dw, dk = indexer_bwd_interface(q, w, k, attn_score, index_score, topk_indices, offsets) @@ -261,5 +250,5 @@ def test_kernel( print(f"dq err: {get_abs_err(dk, ref_dk):.6f} ratio: {get_err_ratio(dk, ref_dk):.6f}") -if __name__ == '__main__': +if __name__ == "__main__": test_kernel() diff --git a/examples/dsa_sparse_finetune/indexer_topk_reducesum.py b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py index b7fa662763e568b139356b42a1cb3d003ee4d6bf..8e2f82ba6dd6516ace5a066493f05c9a85c4ace1 100644 --- a/examples/dsa_sparse_finetune/indexer_topk_reducesum.py +++ b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py @@ -53,8 +53,8 @@ def tl_indexer_topk_reducesum_impl( @T.macro def bitonic_sort( - topk_index_shared: T.SharedBuffer([N], dtype=INT32), - topk_value_shared: T.SharedBuffer([N], dtype=FP32), + topk_index_shared: T.SharedBuffer([N], dtype=INT32), + topk_value_shared: T.SharedBuffer([N], dtype=FP32), ): T.sync_threads() for i1 in T.serial(num_iters): @@ -62,9 +62,10 @@ def tl_indexer_topk_reducesum_impl( for i in T.Parallel(N): ascending = (i & (1 << (i1 + 1))) != 0 j = i ^ (1 << (i1 - i2)) - if i < j and \ - ((ascending and topk_value_shared[i] > topk_value_shared[j]) or ( - not ascending and topk_value_shared[i] < topk_value_shared[j])): + if i < j and ( + (ascending and topk_value_shared[i] > topk_value_shared[j]) + or (not ascending and topk_value_shared[i] < topk_value_shared[j]) + ): val = topk_value_shared[i] topk_value_shared[i] = topk_value_shared[j] topk_value_shared[j] = val @@ -75,13 +76,13 @@ def tl_indexer_topk_reducesum_impl( @T.prim_func def tl_indexer_topk_reducesum_kernel( - IndexQ: T.Tensor(index_q_shape, dtype), - Weights: T.Tensor(weights_shape, dtype), - IndexK: T.Tensor(index_k_shape, dtype), - TopkIndices: T.Tensor(topk_indices_shape, INT32), - ReduceSum: T.Tensor(topk_indices_shape, FP32), - Offsets: T.Tensor(offsets_shape, INT32), - TokenIndices: T.Tensor(token_indices_shape, INT32), + IndexQ: T.Tensor(index_q_shape, dtype), + Weights: T.Tensor(weights_shape, dtype), + IndexK: T.Tensor(index_k_shape, dtype), + TopkIndices: T.Tensor(topk_indices_shape, INT32), + ReduceSum: T.Tensor(topk_indices_shape, FP32), + Offsets: T.Tensor(offsets_shape, INT32), + TokenIndices: T.Tensor(token_indices_shape, INT32), ): with T.Kernel(seq_len, threads=num_threads) as (bx): i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1] @@ -92,7 +93,7 @@ def tl_indexer_topk_reducesum_impl( topk_value_shared = T.alloc_shared([N], dtype=FP32) T.fill(topk_index_shared, -1) - T.fill(topk_value_shared, float('-inf')) + T.fill(topk_value_shared, float("-inf")) T.sync_threads() index_q_shared = T.alloc_shared([heads, dim], dtype=dtype) @@ -113,8 +114,7 @@ def tl_indexer_topk_reducesum_impl( 
index_k_shared = T.alloc_shared([block_K, dim], dtype=dtype) for i, j in T.Parallel(block_K, dim): - index_k_shared[i, j] = T.if_then_else(k_st + i < k_ed, IndexK[bos + k_st + i, - j], 0) + index_k_shared[i, j] = T.if_then_else(k_st + i < k_ed, IndexK[bos + k_st + i, j], 0) T.sync_threads() logits = T.alloc_fragment((block_K, heads), FP32) @@ -144,7 +144,7 @@ def tl_indexer_topk_reducesum_impl( T.sync_threads() for i in T.Parallel(block_K): if k_st + i > i_t: - logits_sum[i] = float('-inf') + logits_sum[i] = float("-inf") j = offset + i topk_index_shared[j] = k_st + i topk_value_shared[j] = logits_sum[i] @@ -209,22 +209,21 @@ def indexer_topk_reducesum_interface( return topk_indices, topk_score -def ref_index_score(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, topk: int, - offsets: torch.Tensor) -> torch.Tensor: +def ref_index_score(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, topk: int, offsets: torch.Tensor) -> torch.Tensor: all_topk_indices = [] all_topk_score = [] for i in range(offsets.shape[0] - 1): assert (offsets[i + 1] - offsets[i]).item() >= topk - q = Q[offsets[i]:offsets[i + 1]] - weights = Weights[offsets[i]:offsets[i + 1]] - k = K[offsets[i]:offsets[i + 1]] - softmax_scale = q.shape[-1]**-0.5 + q = Q[offsets[i] : offsets[i + 1]] + weights = Weights[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + softmax_scale = q.shape[-1] ** -0.5 s = q.shape[0] mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) - logits = einsum(q, k, 's1 h k, s2 k -> s1 h s2') + logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2") logits = F.relu(logits) logits = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * softmax_scale - logits = torch.where(mask, logits, float('-inf')) + logits = torch.where(mask, logits, float("-inf")) topk_logits, topk_indices = torch.topk(logits, k=topk, dim=-1) topk_score = F.softmax(topk_logits, dim=-1, dtype=torch.float32) all_topk_indices.append(topk_indices) @@ -265,13 +264,10 @@ def test_kernel( set_trt = set(trt_np[mask]) intersection = set_ref & set_trt - print("idx:", j, "selected/all:", len(intersection), "/", len(set_ref), "=", - len(intersection) / len(set_ref)) + print("idx:", j, "selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref)) - print( - f"err: {get_abs_err(ref_np_val, trt_np_val):.6f} ratio: {get_err_ratio(ref_np_val, trt_np_val):.6f}" - ) + print(f"err: {get_abs_err(ref_np_val, trt_np_val):.6f} ratio: {get_err_ratio(ref_np_val, trt_np_val):.6f}") -if __name__ == '__main__': +if __name__ == "__main__": test_kernel() diff --git a/examples/dsa_sparse_finetune/sparse_mla_bwd.py b/examples/dsa_sparse_finetune/sparse_mla_bwd.py index 33c21cb44e3257d9a4e40c0f6ced7c543f07edb9..0b085516e25c4f5c227b70142aaebad6a0a42dec 100644 --- a/examples/dsa_sparse_finetune/sparse_mla_bwd.py +++ b/examples/dsa_sparse_finetune/sparse_mla_bwd.py @@ -19,15 +19,15 @@ def preprocess( assert dtype == "bfloat16" assert accum_dtype == "float" - S = T.symbolic('S') + S = T.symbolic("S") shape = [S, H, D] @T.prim_func def preprocess_kernel( - O: T.Tensor(shape, dtype), - dO: T.Tensor(shape, dtype), - Delta: T.Tensor([S, H], accum_dtype), + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + Delta: T.Tensor([S, H], accum_dtype), ): with T.Kernel(H, T.ceildiv(S, block_ND)) as (bx, by): o = T.alloc_fragment([block_ND, block_ND], accum_dtype) @@ -36,13 +36,12 @@ def preprocess( acc = T.alloc_fragment([block_ND, block_ND], accum_dtype) T.clear(acc) for k in 
T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): - T.copy(O[by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], o) - T.copy(dO[by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], - do) + T.copy(O[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o) + T.copy(dO[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do) for i, j in T.Parallel(block_ND, block_ND): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[by * block_ND:(by + 1) * block_ND, bx]) + T.copy(delta, Delta[by * block_ND : (by + 1) * block_ND, bx]) return preprocess_kernel @@ -59,19 +58,19 @@ def postprocess( ): assert dtype == "bfloat16" assert accum_dtype == "float" - S_kv = T.symbolic('S_kv') + S_kv = T.symbolic("S_kv") dkv_shape = [S_kv, kv_group, D + D_tail] @T.prim_func def postprocess_kernel( - dKV: T.Tensor(dkv_shape, accum_dtype), - dKV_out: T.Tensor(dkv_shape, dtype), + dKV: T.Tensor(dkv_shape, accum_dtype), + dKV_out: T.Tensor(dkv_shape, dtype), ): with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, threads=threads) as (bx, by): T.copy( - dKV[bx * block_N:(bx + 1) * block_N, by, :], - dKV_out[bx * block_N:(bx + 1) * block_N, by, :], + dKV[bx * block_N : (bx + 1) * block_N, by, :], + dKV_out[bx * block_N : (bx + 1) * block_N, by, :], ) return postprocess_kernel @@ -82,7 +81,8 @@ def postprocess( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, +) def bwd( H, D, @@ -98,17 +98,17 @@ def bwd( dtype="bfloat16", accum_dtype="float", ): - assert is_causal == True, 'non-casual is not supported now' - assert topk % block_size == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded' + assert is_causal == True, "non-casual is not supported now" + assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" assert dtype == "bfloat16" assert accum_dtype == "float" assert indices_dtype == "int32" if sm_scale is None: - sm_scale = (D + D_tail)**(-0.5) + sm_scale = (D + D_tail) ** (-0.5) - B_plus_one = T.symbolic('B_plus_one') - S = T.symbolic('S') + B_plus_one = T.symbolic("B_plus_one") + S = T.symbolic("S") H_kv = H // kv_group q_shape = [S, H, D + D_tail] @@ -132,16 +132,16 @@ def bwd( @T.prim_func def sparse_mla_bwd_kernel( - Q: T.Tensor(q_shape, dtype), - KV: T.Tensor(k_shape, dtype), - dO: T.Tensor(o_shape, dtype), - Indices: T.Tensor(indices_shape, indices_dtype), - Lse: T.Tensor(lse_shape, accum_dtype), - Delta: T.Tensor(delta_shape, accum_dtype), - Offsets: T.Tensor(offsets_shape, indices_dtype), - TokenIndices: T.Tensor(token_indices_shape, indices_dtype), - dQ: T.Tensor(q_shape, dtype), - dKV: T.Tensor(k_shape, accum_dtype), + Q: T.Tensor(q_shape, dtype), + KV: T.Tensor(k_shape, dtype), + dO: T.Tensor(o_shape, dtype), + Indices: T.Tensor(indices_shape, indices_dtype), + Lse: T.Tensor(lse_shape, accum_dtype), + Delta: T.Tensor(delta_shape, accum_dtype), + Offsets: T.Tensor(offsets_shape, indices_dtype), + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), + dQ: T.Tensor(q_shape, dtype), + dKV: T.Tensor(k_shape, accum_dtype), ): with T.Kernel(S, kv_group, threads=threads) as (b_s_i, bz): Q_shared = T.alloc_shared([padded_H, D], dtype) @@ -163,32 +163,32 @@ def bwd( acc_dkv = T.alloc_fragment([BS, D], accum_dtype) acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype) acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], 
dtype=accum_dtype) - acc_dkv_tail_shared = T.view( - KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype) + acc_dkv_tail_shared = T.view(KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype) b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] bos, eos = Offsets[b_i], Offsets[b_i + 1] max_kv_i = s_i - T.copy(Q[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D], Q_shared) - T.copy(Q[bos + s_i, bz * padded_H:(bz + 1) * padded_H, D:], Q_tail_shared) - T.copy(dO[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D], dO_shared) + T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], Q_shared) + T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:], Q_tail_shared) + T.copy(dO[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], dO_shared) T.clear(acc_dq) T.clear(acc_dq_tail) - T.annotate_layout({ - dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared), - dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared), - }) + T.annotate_layout( + { + dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared), + dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared), + } + ) # Process each block of indices for i_i in T.Pipelined(NS, num_stages=num_stages): # Check which indices are valid for bi_i in T.Parallel(BS): - mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & ( - Indices[bos + s_i, bz, i_i * BS + bi_i] != -1) + mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & (Indices[bos + s_i, bz, i_i * BS + bi_i] != -1) # Compute attention scores for h_i, bi_i in T.Parallel(padded_H, BS): @@ -196,65 +196,33 @@ def bwd( # Load KV, V for this block of indices for bi_i, d_i in T.Parallel(BS, D): - KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, - d_i] + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, d_i] - T.gemm( - Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) for bi_i, d_i in T.Parallel(BS, D_tail): - KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], - bz, D + d_i] - T.gemm( - Q_tail_shared, - KV_tail_shared, - acc_p, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol) + KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, D + d_i] + T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) for h_i, bi_i in T.Parallel(padded_H, BS): - acc_p[h_i, bi_i] = T.exp(acc_p[h_i, bi_i] * sm_scale - - Lse[bos + s_i, bz * padded_H + h_i]) + acc_p[h_i, bi_i] = T.exp(acc_p[h_i, bi_i] * sm_scale - Lse[bos + s_i, bz * padded_H + h_i]) T.copy(acc_p, P_shared_cast) - T.gemm( - dO_shared, - KV_shared, - acc_dp, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol, - clear_accum=True) + T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) for h_i, bi_i in T.Parallel(padded_H, BS): - acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * ( - acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale + acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale T.copy(acc_dp, dP_shared_cast) T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol) T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol) - T.gemm( - dP_shared_cast, - Q_shared, - acc_dkv, - transpose_A=True, - 
policy=T.GemmWarpPolicy.FullCol, - clear_accum=True) - T.gemm( - P_shared_cast, - dO_shared, - acc_dkv, - transpose_A=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) T.clear(acc_dkv_tail) - T.gemm( - dP_shared_cast, - Q_tail_shared, - acc_dkv_tail, - transpose_A=True, - policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) for s in range(split_store): for bi_i, d_i in T.Parallel(BS, D): @@ -263,44 +231,32 @@ def bwd( for bi_i, d_i in T.Parallel(BS, D_tail): if bi_i < BS // split_store: - acc_dkv_tail_shared[bi_i, - d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), - d_i] + acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i] for bi_i, d_i in T.Parallel(BS // split_store, D // 4): T.atomic_addx4( - dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * - (BS // split_store)], bz, d_i * 4], - acc_dkv_shared[bi_i, d_i * 4]) + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4], + acc_dkv_shared[bi_i, d_i * 4], + ) # Atomically update dKV, dKV_tail tensors for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): T.atomic_addx4( - dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * - (BS // split_store)], bz, D + d_i * 4], - acc_dkv_tail_shared[bi_i, d_i * 4]) + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4], + acc_dkv_tail_shared[bi_i, d_i * 4], + ) # Store the accumulated dQ T.copy(acc_dq, dQ_shared) T.copy(acc_dq_tail, dQ_tail_shared) - T.copy(dQ_shared, dQ[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D]) - T.copy(dQ_tail_shared, dQ[bos + s_i, bz * padded_H:(bz + 1) * padded_H, D:]) + T.copy(dQ_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D]) + T.copy(dQ_tail_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:]) return sparse_mla_bwd_kernel -def sparse_mla_bwd(q, - kv, - o, - do, - indices, - lse, - offsets, - sm_scale=None, - is_casual=True, - return_kernel=False, - delta=None): +def sparse_mla_bwd(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True, return_kernel=False, delta=None): assert q.is_contiguous() assert kv.is_contiguous() assert indices.is_contiguous() @@ -333,16 +289,9 @@ def sparse_mla_bwd(q, return dq, dkv -def ref_sparse_mla_bwd_interface(q, - kv, - o, - do, - indices, - lse, - offsets, - sm_scale=None, - is_casual=True): +def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True): from sparse_mla_fwd import ref_sparse_mla_fwd_interface + q = q.detach().clone() kv = kv.detach().clone() q.requires_grad = True @@ -352,32 +301,25 @@ def ref_sparse_mla_bwd_interface(q, return q.grad, kv.grad -def test_sparse_mla_bwd(B=1, - S=2048, - H=64, - HKV=1, - DQKV=576, - DV=512, - topk=512, - dtype=torch.bfloat16, - check_correctness=True): +def test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True): # Prepare data - q = torch.randn((S, H, DQKV), dtype=dtype, device='cuda').requires_grad_(True) - kv = torch.randn((S, HKV, DQKV), dtype=dtype, device='cuda').requires_grad_(True) - do = torch.randn((S, H, DV), dtype=dtype, device='cuda') + q = torch.randn((S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((S, HKV, 
DQKV), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((S, H, DV), dtype=dtype, device="cuda") offsets = torch.tensor([0, S], dtype=torch.int32, device="cuda") - indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device='cuda') + indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda") for i in range(offsets.shape[0] - 1): seq_len = (offsets[i + 1] - offsets[i]).item() assert seq_len >= topk for t in range(seq_len): for h in range(HKV): i_i = torch.randperm(max(1, t))[:topk] - indices[offsets[i] + t, h, :len(i_i)] = i_i + indices[offsets[i] + t, h, : len(i_i)] = i_i # Forward from sparse_mla_fwd import sparse_mla_fwd_interface + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets) tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets) @@ -388,13 +330,15 @@ def test_sparse_mla_bwd(B=1, assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv") print("assert_tensors_similar passed") - per_token_flop = 2 * sum([ - H * DV * topk, - H * DQKV * topk, - H * DQKV * topk, - H * DQKV * topk, - H * DV * topk, - ]) + per_token_flop = 2 * sum( + [ + H * DV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DV * topk, + ] + ) from tilelang.profiler import do_bench def fn(): @@ -402,19 +346,9 @@ def test_sparse_mla_bwd(B=1, ms = do_bench(fn, rep=100, warmup=250) print(f"Average time: {ms:.3f} ms") - print(f'bwd io bandwidth = ', - (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) - print(f'bwd tflops = ', per_token_flop * S / (ms * 1e-3) / 1e12) + print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) + print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12) if __name__ == "__main__": - test_sparse_mla_bwd( - B=1, - S=2048, - H=64, - HKV=1, - DQKV=576, - DV=512, - topk=512, - dtype=torch.bfloat16, - check_correctness=True) + test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True) diff --git a/examples/dsa_sparse_finetune/sparse_mla_fwd.py b/examples/dsa_sparse_finetune/sparse_mla_fwd.py index 5f03dfbb68321b95c4ebf23e029c28fcb600f99d..6ec3caa7b628db36829d45540feea3b0b5c1c7b2 100644 --- a/examples/dsa_sparse_finetune/sparse_mla_fwd.py +++ b/examples/dsa_sparse_finetune/sparse_mla_fwd.py @@ -27,15 +27,12 @@ def sparse_mla_fwd( num_stages=2, threads=128, ): - assert dim == tilelang.math.next_power_of_2( - dim), f"haven't check padding correctness yet, dim={dim}" - assert tail_dim == tilelang.math.next_power_of_2( - tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" assert is_causal == True, "non-casual is not supported" - assert (topk % - block_I == 0), "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" if sm_scale is None: - sm_scale = (1.0 / (dim + tail_dim))**0.5 + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 else: sm_scale = sm_scale @@ -58,9 +55,9 @@ def sparse_mla_fwd( H = head_kv padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) if padded_H != H: - assert ( - kv_group == 1 - ), "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, 
use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) BI = block_I NI = tilelang.cdiv(topk, block_I) D = dim @@ -76,19 +73,18 @@ def sparse_mla_fwd( @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), # type: ignore - KV: T.Tensor(kv_shape, dtype), # type: ignore - Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore - Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore - TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore - Output: T.Tensor(o_shape, dtype), # type: ignore - Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore ): - with T.Kernel( - seq_len * REPLICATE_H, kv_group, threads=threads) as ( - bx, - by, - ): + with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as ( + bx, + by, + ): Q_shared = T.alloc_shared([H_per_block, D], dtype) Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) KV_shared = T.alloc_shared([BI, D], dtype) @@ -122,17 +118,13 @@ def sparse_mla_fwd( T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared) for i_i in T.Pipelined(NI, num_stages=num_stages): - for bi_i in T.Parallel(BI): - mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & ( - Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) + mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) for bi_i, d_i in T.Parallel(BI, D): - KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, - d_i] + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i] for bi_i, d_i in T.Parallel(BI, D_tail): - K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], - g_i, D + d_i] + K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] for h_i, bi_i in T.Parallel(H_per_block, BI): acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) @@ -177,16 +169,9 @@ def sparse_mla_fwd( return main -def sparse_mla_fwd_interface(q, - kv, - indices, - offsets, - sm_scale=None, - return_p_sum: bool = False, - d_v=512, - block_I=32, - num_stages=2, - threads=128): +def sparse_mla_fwd_interface( + q, kv, indices, offsets, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=32, num_stages=2, threads=128 +): is_casual = True assert return_p_sum == False, "This kernel file is for fwd only" assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() @@ -205,16 +190,8 @@ def sparse_mla_fwd_interface(q, token_indices = prepare_token_indices(offsets) kernel = sparse_mla_fwd( - heads, - dim, - tail_dim, - topk, - kv_group, - sm_scale, - is_casual, - block_I=block_I, - num_stages=num_stages, - threads=threads) + heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual, block_I=block_I, num_stages=num_stages, threads=threads + ) out, lse = kernel(q, kv, indices, offsets, token_indices) return out, lse @@ -224,9 
+201,9 @@ def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casu KV = KV.float() all_o = [] for i in range(offsets.shape[0] - 1): - q = Q[None, offsets[i]:offsets[i + 1]] - kv = KV[None, offsets[i]:offsets[i + 1]] - indices = Indices[None, offsets[i]:offsets[i + 1]].clone() + q = Q[None, offsets[i] : offsets[i + 1]] + kv = KV[None, offsets[i] : offsets[i + 1]] + indices = Indices[None, offsets[i] : offsets[i + 1]].clone() indices = indices.transpose(1, 2) b, sq, h, dim_q = q.shape @@ -240,15 +217,15 @@ def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casu b, _, _, dim_v = v.shape g_index = g h_index = h // g - compressed_casual_mask = torch.arange( - 0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( - 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda").view(1, -1) + compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( + 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda" + ).view(1, -1) indices[indices > sk] = sk mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) mask = mask[..., :-1] mask = mask & compressed_casual_mask.view(1, 1, sq, sk) - mask[:, :, :1 - 1, 0] = True + mask[:, :, : 1 - 1, 0] = True mask = mask.view(b, g_index, 1, sq, sk) q = q.view(b, sq, g, -1, dim_q) @@ -265,18 +242,20 @@ def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casu return o.to(torch.bfloat16) -def test_sparse_mla_fwd(B=1, - S=4096, - H=128, - HKV=1, - DQK=576, - DV=512, - topk=2048, - dtype=torch.bfloat16, - check_correctness=True, - block_I=64, - num_stages=2, - threads=256): +def test_sparse_mla_fwd( + B=1, + S=4096, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, +): torch.random.manual_seed(0) q = torch.randn((S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) kv = torch.randn((S, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) @@ -289,10 +268,9 @@ def test_sparse_mla_fwd(B=1, for t in range(seq_len): for h in range(HKV): i_i = torch.randperm(max(1, t))[:topk] - indices[offsets[i] + t, h, :len(i_i)] = i_i + indices[offsets[i] + t, h, : len(i_i)] = i_i - tl_out, tl_lse = sparse_mla_fwd_interface( - q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) if check_correctness: # otherwise may cause out of memory @@ -301,8 +279,7 @@ def test_sparse_mla_fwd(B=1, print("assert_tensors_similar passed") def fn(): - return sparse_mla_fwd_interface( - q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) + return sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) from tilelang.profiler import do_bench @@ -329,4 +306,5 @@ if __name__ == "__main__": check_correctness=True, block_I=64, num_stages=2, - threads=256) + threads=256, + ) diff --git a/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py index 94bdb8fbe041f03b0de0e5e9f811c379a16f9216..6675215c7bf5136414fe588f95065cef18e8609f 100644 --- a/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py +++ b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py @@ -30,14 +30,11 @@ def tl_sparse_mla_topk_reducesum_impl( num_stages=2, threads=128, ): - assert 
dim == tilelang.math.next_power_of_2( - dim), f"haven't check padding correctness yet, dim={dim}" - assert tail_dim == tilelang.math.next_power_of_2( - tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" - assert (topk % - block_I == 0), "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" if sm_scale is None: - sm_scale = (1.0 / (dim + tail_dim))**0.5 + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 batch_plus_one = T.symbolic("batch_plus_one") seq_len = T.symbolic("seq_len") @@ -52,9 +49,9 @@ def tl_sparse_mla_topk_reducesum_impl( H = head_kv padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) if padded_H != H: - assert ( - kv_group == 1 - ), "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) BI = block_I NI = tilelang.cdiv(topk, block_I) D = dim @@ -78,19 +75,18 @@ def tl_sparse_mla_topk_reducesum_impl( @T.prim_func def tl_sparse_mla_topk_reducesum_kernel( - Q: T.Tensor(q_shape, dtype), # type: ignore - KV: T.Tensor(kv_shape, dtype), # type: ignore - Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore - Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore - Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore - TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore - ReduceSum: T.Tensor(reducesum_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore + ReduceSum: T.Tensor(reducesum_shape, accum_dtype), # type: ignore ): - with T.Kernel( - seq_len * REPLICATE_H, kv_group, threads=threads) as ( - bx, - by, - ): + with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as ( + bx, + by, + ): Q_shared = T.alloc_shared([H_per_block, D], dtype) Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) KV_shared = T.alloc_shared([BI, D], dtype) @@ -119,17 +115,13 @@ def tl_sparse_mla_topk_reducesum_impl( T.copy(Lse[bos + s_i, H0:H1], lse) for i_i in T.Pipelined(NI, num_stages=num_stages): - for bi_i in T.Parallel(BI): - mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & ( - Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) + mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) for bi_i, d_i in T.Parallel(BI, D): - KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, - d_i] + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i] for bi_i, d_i in T.Parallel(BI, D_tail): - K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + 
bi_i], - g_i, D + d_i] + K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] for h_i, bi_i in T.Parallel(H_per_block, BI): acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) @@ -150,7 +142,7 @@ def tl_sparse_mla_topk_reducesum_impl( for h_i, bi_i in T.Parallel(H_per_block, BI): acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - lse[h_i]) T.reduce_sum(acc_s, reducesum, dim=0) - T.copy(reducesum, ReduceSum[bos + s_i, g_i, r_i, i_i * BI:i_i * BI + BI]) + T.copy(reducesum, ReduceSum[bos + s_i, g_i, r_i, i_i * BI : i_i * BI + BI]) return tl_sparse_mla_topk_reducesum_kernel @@ -178,29 +170,26 @@ def sparse_mla_topk_reducesum_interface( return attn_score -def ref_mla_topk_softmax(Q: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, - offsets: torch.Tensor): +def ref_mla_topk_softmax(Q: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, offsets: torch.Tensor): # q: [batch, seq_len, heads, dim] # k: [batch, seq_len, dim] - sm_scale = Q.shape[-1]**-0.5 + sm_scale = Q.shape[-1] ** -0.5 all_lse = [] all_topk_score = [] for i in range(offsets.shape[0] - 1): - q = Q[offsets[i]:offsets[i + 1]] - k = K[offsets[i]:offsets[i + 1]] - topk_indices = TopkIndices[offsets[i]:offsets[i + 1]] + q = Q[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + topk_indices = TopkIndices[offsets[i] : offsets[i + 1]] seq_len = q.shape[0] - mask = (torch.arange(seq_len)[:, None] - >= torch.arange(seq_len)[None, :]).unsqueeze(-2).cuda() - logits = einsum(q, k, 's1 h d, s2 d -> s1 h s2') * sm_scale - logits = torch.where(mask, logits, float('-inf')) + mask = (torch.arange(seq_len)[:, None] >= torch.arange(seq_len)[None, :]).unsqueeze(-2).cuda() + logits = einsum(q, k, "s1 h d, s2 d -> s1 h s2") * sm_scale + logits = torch.where(mask, logits, float("-inf")) score = F.softmax(logits, dim=-1, dtype=torch.float32) score_sum = score.sum(dim=-2) topk_score = torch.gather(score_sum, dim=-1, index=topk_indices.to(torch.int64)) topk_score = topk_score / topk_score.sum(dim=-1, keepdim=True) max_logits = logits.amax(dim=-1).to(torch.float32) - lse = torch.log( - (logits - max_logits.unsqueeze(-1).to(torch.float32)).exp().sum(dim=-1)) + max_logits + lse = torch.log((logits - max_logits.unsqueeze(-1).to(torch.float32)).exp().sum(dim=-1)) + max_logits all_lse.append(lse) all_topk_score.append(topk_score) lse = torch.cat(all_lse, dim=0) @@ -222,20 +211,16 @@ def test_kernel( kv = torch.randn((S, D + tail_D)).cuda().bfloat16() offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda() - topk_indices = repeat( - torch.arange(topk, dtype=torch.int32).cuda(), 'k -> s k', s=S).contiguous() + topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous() lse, ref_attn_score = ref_mla_topk_softmax(q, kv, topk_indices, offsets) kv = kv.unsqueeze(-2) topk_indices = topk_indices.unsqueeze(-2) - attn_score = sparse_mla_topk_reducesum_interface( - q, kv, topk_indices, lse, offsets, dim_v=D).squeeze(-2) - print( - f"attn_score err: {get_abs_err(attn_score, ref_attn_score):.6f} ratio: {get_err_ratio(attn_score, ref_attn_score):.6f}" - ) + attn_score = sparse_mla_topk_reducesum_interface(q, kv, topk_indices, lse, offsets, dim_v=D).squeeze(-2) + print(f"attn_score err: {get_abs_err(attn_score, ref_attn_score):.6f} ratio: {get_err_ratio(attn_score, ref_attn_score):.6f}") -if __name__ == '__main__': +if __name__ == "__main__": test_kernel() diff --git a/examples/dsa_sparse_finetune/utils.py 
b/examples/dsa_sparse_finetune/utils.py index 691af64dc3041f6048138bac95d39d0bfba326da..96afd064dc0f83f0e813fa4093f10d2fd309dfce 100644 --- a/examples/dsa_sparse_finetune/utils.py +++ b/examples/dsa_sparse_finetune/utils.py @@ -66,10 +66,8 @@ def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): raise_assert: Whether to raise assertion error on failure """ sim = calculate_tensor_similarity(x, y, name) - diff = 1. - sim + diff = 1.0 - sim if not (0 <= diff <= eps): - print( - f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m" - ) + print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m") if raise_assert: assert False # noqa: B011 diff --git a/examples/dynamic_shape/example_dynamic.py b/examples/dynamic_shape/example_dynamic.py index be018c8b70c76e8ac22d4c6154aca5acc0324a2c..97ce7d9b33d24d4c5e3fd3666666651caca242da 100644 --- a/examples/dynamic_shape/example_dynamic.py +++ b/examples/dynamic_shape/example_dynamic.py @@ -29,9 +29,9 @@ def matmul_dynamic_mnk( @T.prim_func def dynamic_matmul( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -53,15 +53,14 @@ def matmul_dynamic_mnk( return dynamic_matmul -def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, - accum_dtype, num_stages, threads): +def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads): print( f"M: {M}, N: {N}, K: {K}, block_M: {block_M}, block_N: {block_N}, block_K: {block_K}, trans_A: {trans_A}, trans_B: {trans_B}, in_dtype: {in_dtype}, out_dtype: {out_dtype}, accum_dtype: {accum_dtype}, num_stages: {num_stages}, threads: {threads}" ) - kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, - accum_dtype, num_stages, threads) + kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) import torch + if trans_A: A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype)) else: @@ -103,8 +102,7 @@ def main(M=16384, N=16384, K=16384): accum_dtype = "float32" num_stages = 3 threads = 128 - matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, - accum_dtype, num_stages, threads) + matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) if __name__ == "__main__": diff --git a/examples/elementwise/example_elementwise_add.py b/examples/elementwise/example_elementwise_add.py index bc9bb4df5bbc9daf13a963d5714478c8d5b54f25..464312ced0974aa6b347ddd8912fb038042f89ce 100644 --- a/examples/elementwise/example_elementwise_add.py +++ b/examples/elementwise/example_elementwise_add.py @@ -12,10 +12,8 @@ def ref_program(x, y): @tilelang.jit(out_idx=[-1]) def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): - @T.prim_func - def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor( - (M, N), out_dtype)): + def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, 
block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_N), in_dtype) B_shared = T.alloc_shared((block_M, block_N), in_dtype) @@ -24,7 +22,7 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): T.copy(A[by * block_M, bx * block_N], A_shared) T.copy(B[by * block_M, bx * block_N], B_shared) - for (local_y, local_x) in T.Parallel(block_M, block_N): + for local_y, local_x in T.Parallel(block_M, block_N): C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x] T.copy(C_local, C_shared) T.copy(C_shared, C[by * block_M, bx * block_N]) @@ -41,19 +39,21 @@ def get_configs(M, N): def get_best_config(M, N): - def kernel(block_M=None, block_N=None, threads=None): return elementwise_add(M, N, block_M, block_N, "float32", "float32", threads) - autotuner = AutoTuner.from_kernel( - kernel=kernel, configs=get_configs(M, N)).set_compile_args( + autotuner = ( + AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N)) + .set_compile_args( out_idx=[-1], target="cuda", - ).set_profile_args( + ) + .set_profile_args( supply_type=tilelang.TensorSupplyType.Auto, ref_prog=ref_program, skip_check=False, ) + ) return autotuner.run(warmup=3, rep=20) diff --git a/examples/flash_attention/bert_padding.py b/examples/flash_attention/bert_padding.py index 7058fd773d6a1250eac24083a124e8c98543028c..15c4097ce77a21ebcd2060b53c629e7a89972b88 100644 --- a/examples/flash_attention/bert_padding.py +++ b/examples/flash_attention/bert_padding.py @@ -6,7 +6,6 @@ from einops import rearrange, repeat class IndexFirstAxis(torch.autograd.Function): - @staticmethod def forward(ctx, input, indices): ctx.save_for_backward(indices) @@ -15,9 +14,7 @@ class IndexFirstAxis(torch.autograd.Function): second_dim = other_shape.numel() # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. # return input[indices] - return torch.gather( - rearrange(input, "b ... -> b (...)"), 0, - repeat(indices, "z -> z d", d=second_dim)).reshape(-1, *other_shape) + return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape(-1, *other_shape) @staticmethod def backward(ctx, grad_output): @@ -40,14 +37,12 @@ index_first_axis = IndexFirstAxis.apply class IndexPutFirstAxis(torch.autograd.Function): - @staticmethod def forward(ctx, values, indices, first_axis_dim): ctx.save_for_backward(indices) assert indices.ndim == 1 assert values.ndim >= 2 - output = torch.zeros( - first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) + output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. output[indices] = values # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) @@ -66,7 +61,6 @@ index_put_first_axis = IndexPutFirstAxis.apply class IndexFirstAxisResidual(torch.autograd.Function): - @staticmethod def forward(ctx, input, indices): ctx.save_for_backward(indices) @@ -128,7 +122,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng """ Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model). The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286). 
- + For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is: ``` [ @@ -177,9 +171,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng """ length = attention_mask_in_length.sum(dim=-1) seqlen = attention_mask_in_length.size(-1) - attention_mask_2d = torch.arange( - seqlen, device=length.device, dtype=length.dtype).expand(len(length), - seqlen) < length.unsqueeze(1) + attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1) real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten() seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx] indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten() diff --git a/examples/flash_attention/example_gqa_bwd.py b/examples/flash_attention/example_gqa_bwd.py index 968d1de334fe1e561fc589651fd6d8abe8851552..d1f5843e3c34dc0e80f0749266b13979e01a90e8 100644 --- a/examples/flash_attention/example_gqa_bwd.py +++ b/examples/flash_attention/example_gqa_bwd.py @@ -6,11 +6,13 @@ import argparse @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] @@ -20,11 +22,11 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -40,25 +42,21 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc logsum = T.alloc_fragment([block_M], accum_dtype) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: for i, j in 
T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_M): @@ -76,18 +74,20 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim_v): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): dtype = "float16" accum_dtype = "float" @@ -96,9 +96,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -107,26 +107,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim_v, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk): dtype = "float16" accum_dtype = "float" @@ -135,35 +136,27 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk): @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with 
T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, bx * blk:(bx + 1) * blk, by, :], - dQ_out[bz, bx * blk:(bx + 1) * blk, by, :], + dQ[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_atomic_add(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] @@ -173,15 +166,15 @@ def flashattn_bwd_atomic_add(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -201,35 +194,36 @@ def flashattn_bwd_atomic_add(batch, dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype) dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + } + ) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 
1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -241,29 +235,21 @@ def flashattn_bwd_atomic_add(batch, for i, j in T.Parallel(block_N, dim_qk): T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) - T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared) T.copy(dk, dk_shared) - T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared) return flash_bwd -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_split(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] @@ -275,15 +261,15 @@ def flashattn_bwd_split(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(dk_shape, dtype), # type: ignore - dV: T.Tensor(dv_shape, dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: ignore + dV: T.Tensor(dv_shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -303,37 +289,38 @@ def flashattn_bwd_split(batch, dv_shared = T.alloc_shared([block_M, dim_v], dtype) dk_shared = T.alloc_shared([block_M, dim_qk], dtype) - 
T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -346,16 +333,15 @@ def flashattn_bwd_split(batch, T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) - T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) + T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) T.copy(dk, dk_shared) - T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) + T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) return flash_bwd @torch.compile class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): BATCH, N_CTX, H, D_HEAD_QK = q.shape @@ -373,7 +359,10 @@ class _attention(torch.autograd.Function): def backward(ctx, do): q, k, v, o, lse = ctx.saved_tensors BATCH, N_CTX, H, D_HEAD_QK = q.shape - HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1] + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] groups = H // HEAD_KV def maybe_contiguous(x): @@ -390,17 +379,8 @@ class _attention(torch.autograd.Function): if ctx.use_atomic: kernel = flashattn_bwd_atomic_add( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) shape_q = [BATCH, N_CTX, H, 
D_HEAD_QK]
            shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
            shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
@@ -413,17 +393,8 @@ class _attention(torch.autograd.Function):
             dv = dv.to(torch.float16)
         else:
             kernel = flashattn_bwd_split(
-                BATCH,
-                H,
-                N_CTX,
-                D_HEAD_QK,
-                D_HEAD_V,
-                ctx.causal,
-                block_M,
-                block_N,
-                threads=256,
-                num_stages=2,
-                groups=groups)
+                BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
+            )
             shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
             shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK]  # sum after kernel
             shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V]  # sum after kernel
@@ -445,53 +416,45 @@ def ref_program(Q, K, V, is_causal, groups=1):
     # K: [B, T, HK, D_QK]
     # V: [B, T, HV, D_V]
     # HQ = HKV * groups
-    assert Q.size(2) == K.size(
-        2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
-    assert Q.size(2) == V.size(
-        2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
+    assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
+    assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
     dim_qk = Q.size(-1)
     K = K.repeat_interleave(groups, dim=2)
     V = V.repeat_interleave(groups, dim=2)
-    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
+    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
     scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
     if is_causal:
         seq_len = Q.size(1)
         mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
         mask = mask.unsqueeze(0).unsqueeze(0)
-        scores = scores.masked_fill(mask == 0, float('-inf'))
+        scores = scores.masked_fill(mask == 0, float("-inf"))
     attention_weights = F.softmax(scores, dim=-1)
-    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
+    output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
     return output


-def main(BATCH: int = 1,
-         H: int = 32,
-         N_CTX: int = 256,
-         D_HEAD_QK: int = 192,
-         D_HEAD_V: int = 128,
-         groups: int = 16,
-         causal: bool = False,
-         use_atomic: bool = True):
+def main(
+    BATCH: int = 1,
+    H: int = 32,
+    N_CTX: int = 256,
+    D_HEAD_QK: int = 192,
+    D_HEAD_V: int = 128,
+    groups: int = 16,
+    causal: bool = False,
+    use_atomic: bool = True,
+):
     flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
     flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
     total_flops = 3 * flops_per_qk + 2 * flops_per_v
     if causal:
         total_flops *= 0.5

-    Q = (
-        torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half,
-                    device="cuda").normal_().requires_grad_())
+    Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
     head_kv = H // groups
-    K = (
-        torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half,
-                    device="cuda").normal_().requires_grad_())
-    V = (
-        torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half,
-                    device="cuda").normal_().requires_grad_())
-    dO = (
-        torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
-                    device="cuda").normal_().requires_grad_())
+    K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
+    V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
+    dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()

     O = attention(Q, K, V, causal, groups, use_atomic)
     O.backward(dO, retain_graph=True)
     dQ, Q.grad = Q.grad.clone(), None
@@ -508,7 +471,7 @@ def main(BATCH: int = 1,
     torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
-    print('All checks passed.✅')
+    print("All checks passed.✅")

     def run():
         O_ref.backward(dO, retain_graph=True)
@@ -528,17 +491,15 @@ def main(BATCH: int = 1,

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='Batch size')
-    parser.add_argument('--h', type=int, default=32, help='Number of heads')
-    parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
-    parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
-    parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
-    parser.add_argument('--causal', action='store_true', help='Causal flag')
-    parser.add_argument('--groups', type=int, default=16, help='groups')
-    parser.add_argument(
-        '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
-    parser.add_argument(
-        '--use_split', action='store_true', default=False, help='Use split for dK/dV')
+    parser.add_argument("--batch", type=int, default=8, help="Batch size")
+    parser.add_argument("--h", type=int, default=32, help="Number of heads")
+    parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
+    parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
+    parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
+    parser.add_argument("--causal", action="store_true", help="Causal flag")
+    parser.add_argument("--groups", type=int, default=16, help="groups")
+    parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV")
+    parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV")
     args = parser.parse_args()

     # Handle backward compatibility and logic
@@ -550,5 +511,4 @@ if __name__ == "__main__":
         # Default: use atomic
         use_atomic = True

-    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
-         use_atomic)
+    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)
diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce.py b/examples/flash_attention/example_gqa_bwd_tma_reduce.py
index c427908a6366f4e438b1dfee97c95f2bfa757837..c6cf336dffaadb951717c07420757cb5b84651ee 100644
--- a/examples/flash_attention/example_gqa_bwd_tma_reduce.py
+++ b/examples/flash_attention/example_gqa_bwd_tma_reduce.py
@@ -9,11 +9,13 @@ tilelang.disable_cache()


 @tilelang.jit(
-    out_idx=[3, 4], pass_configs={
+    out_idx=[3, 4],
+    pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    })
+    },
+)
 def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
-    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
+    scale = (1.0 / dim_qk) ** 0.5 * 1.44269504  # log2(e)
     head_kv = heads // groups
     q_shape = [batch, seq_len, heads, dim_qk]
     k_shape = [batch, seq_len, head_kv, dim_qk]
@@ -23,11 +25,11 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc

     @T.prim_func
     def flash_fwd(
-            Q: T.Tensor(q_shape, dtype),  # type: ignore
-            K: T.Tensor(k_shape, dtype),  # type: ignore
-            V: T.Tensor(v_shape, dtype),  # type: ignore
-            Output: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
-            lse: T.Tensor([batch,
heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -43,27 +45,23 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc logsum = T.alloc_fragment([block_M], accum_dtype) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) # Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops # We should set it to negative large number instead T.fill(scores_max, T.Cast(accum_dtype, -1e30)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - T.Cast(accum_dtype, -1e30)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, T.Cast(accum_dtype, -1e30)) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_M): @@ -81,18 +79,20 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim_v): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): dtype = "float16" accum_dtype = "float" @@ -101,9 +101,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], 
accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -112,12 +112,12 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim_v, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep @@ -128,9 +128,11 @@ def make_dq_layout(dQ): @tilelang.jit( - out_idx=[3, 4, 5], pass_configs={ + out_idx=[3, 4, 5], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v): dtype = "float16" accum_dtype = "float" @@ -141,46 +143,37 @@ def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v): @T.prim_func def flash_bwd_post( - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(q_shape, dtype), # type: ignore - dK_out: T.Tensor(k_shape, dtype), # type: ignore - dV_out: T.Tensor(v_shape, dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(q_shape, dtype), # type: ignore + dK_out: T.Tensor(k_shape, dtype), # type: ignore + dV_out: T.Tensor(v_shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) - T.copy(dQ[bz, bx * blk:(bx + 1) * blk, by, :], dQ_out[bz, bx * blk:(bx + 1) * blk, - by, :]) + T.copy(dQ[bz, bx * blk : (bx + 1) * blk, by, :], dQ_out[bz, bx * blk : (bx + 1) * blk, by, :]) with T.Kernel(T.ceildiv(seq_len, blk), head_kv, batch, threads=128) as (bx, by, bz): - T.annotate_layout({ - dK: make_dq_layout(dK), - dV: make_dq_layout(dV), - }) - T.copy(dK[bz, bx * blk:(bx + 1) * blk, by, :], dK_out[bz, bx * blk:(bx + 1) * blk, - by, :]) - T.copy(dV[bz, bx * blk:(bx + 1) * blk, by, :], dV_out[bz, bx * blk:(bx + 1) * blk, - by, :]) + T.annotate_layout( + { + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + } + ) + T.copy(dK[bz, bx * blk : (bx + 1) * blk, by, :], dK_out[bz, bx * blk : (bx + 1) * blk, by, :]) + T.copy(dV[bz, bx * blk : (bx + 1) * blk, by, :], dV_out[bz, bx * blk : (bx + 1) * blk, by, :]) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_atomic_add(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = 
heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] @@ -190,15 +183,15 @@ def flashattn_bwd_atomic_add(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -219,37 +212,38 @@ def flashattn_bwd_atomic_add(batch, dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - dK: make_dq_layout(dK), - dV: make_dq_layout(dV), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - }) - - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -259,33 +253,23 
@@ def flashattn_bwd_atomic_add(batch, T.clear(dq) T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.copy(dq, dq_shared) - T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared, use_tma=True) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared, use_tma=True) T.copy(dv, dv_shared) - T.atomic_add( - dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True) T.copy(dk, dk_shared) - T.atomic_add( - dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True) return flash_bwd -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_split_novarlen(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split_novarlen(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] @@ -297,15 +281,15 @@ def flashattn_bwd_split_novarlen(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(dk_shape, dtype), # type: ignore - dV: T.Tensor(dv_shape, dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: ignore + dV: T.Tensor(dv_shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -325,37 +309,38 @@ def flashattn_bwd_split_novarlen(batch, dv_shared = T.alloc_shared([block_M, dim_v], dtype) dk_shared = T.alloc_shared([block_M, dim_qk], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + 
} + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -368,16 +353,15 @@ def flashattn_bwd_split_novarlen(batch, T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) - T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) + T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) T.copy(dk, dk_shared) - T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) + T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) return flash_bwd @torch.compile class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): BATCH, N_CTX, H, D_HEAD_QK = q.shape @@ -395,7 +379,10 @@ class _attention(torch.autograd.Function): def backward(ctx, do): q, k, v, o, lse = ctx.saved_tensors BATCH, N_CTX, H, D_HEAD_QK = q.shape - HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1] + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] groups = H // HEAD_KV def maybe_contiguous(x): @@ -412,17 +399,8 @@ class _attention(torch.autograd.Function): if ctx.use_atomic: kernel = flashattn_bwd_atomic_add( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] @@ -433,17 +411,8 @@ class _attention(torch.autograd.Function): dq, dk, dv = mod_post(dq, dk, dv) else: kernel = flashattn_bwd_split_novarlen( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel shape_v = 
[groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel @@ -451,8 +420,7 @@ class _attention(torch.autograd.Function): dk = torch.empty(shape_k, dtype=torch.float16, device=q.device) dv = torch.empty(shape_v, dtype=torch.float16, device=q.device) kernel(q, k, v, do, lse, delta, dq, dk, dv) - dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), - torch.zeros_like(v, dtype=torch.float32)) + dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32)) dk, dv = dk.sum(0), dv.sum(0) return dq, dk, dv, None, None, None @@ -466,53 +434,45 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D_QK] # V: [B, T, HV, D_V] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim_qk = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output -def main(BATCH: int = 1, - H: int = 32, - N_CTX: int = 256, - D_HEAD_QK: int = 192, - D_HEAD_V: int = 128, - groups: int = 16, - causal: bool = False, - use_atomic: bool = True): +def main( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 256, + D_HEAD_QK: int = 192, + D_HEAD_V: int = 128, + groups: int = 16, + causal: bool = False, + use_atomic: bool = True, +): flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V total_flops = 3 * flops_per_qk + 2 * flops_per_v if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() head_kv = H // groups - K = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - V = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - dO = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() O = attention(Q, K, V, causal, groups, use_atomic) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None @@ -529,7 +489,7 @@ def main(BATCH: int = 1, torch.testing.assert_close(dV, 
dV_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) - print('All checks passed.✅') + print("All checks passed.✅") def run(): O_ref.backward(dO, retain_graph=True) @@ -552,17 +512,15 @@ if __name__ == "__main__": print(f"Detected GPU compute capability: {arch}") assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0" parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') - parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') - parser.add_argument('--causal', action='store_true', help='Causal flag') - parser.add_argument('--groups', type=int, default=16, help='groups') - parser.add_argument( - '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') - parser.add_argument( - '--use_split', action='store_true', default=False, help='Use split for dK/dV') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV") + parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV") args = parser.parse_args() # Handle backward compatibility and logic @@ -574,5 +532,4 @@ if __name__ == "__main__": # Default: use atomic use_atomic = True - main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, - use_atomic) + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py index a9604f4de4ce9335046ae3a2024e47266bbd8973..112438f767395c7fa720635e9fa96e592085451a 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py @@ -15,32 +15,21 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): if mode == "full": lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) elif mode == "random": - lengths = torch.randint( - max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) elif mode == "third": lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) - padding_mask = ( - repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths) + padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths return padding_mask @tilelang.jit( - out_idx=[5, 
6], pass_configs={ + out_idx=[5, 6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn_fwd(batch, - total_q, - total_kv, - N_CTX, - heads, - max_seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - groups=1): - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn_fwd(batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [total_q, heads, dim_qk] k_shape = [total_kv, head_kv, dim_qk] @@ -51,13 +40,13 @@ def flashattn_fwd(batch, @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore - Output: T.Tensor(o_shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(max_seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -102,15 +91,17 @@ def flashattn_fwd(batch, if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and - (bx * block_M + i < q_current_seqlen and - k * block_N + j < k_current_seqlen), 0, - T.Cast(accum_dtype, -1e30)) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= k * block_N + j) + and (bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen), + 0, + T.Cast(accum_dtype, -1e30), + ) else: for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.if_then_else( - bx * block_M + i < q_current_seqlen and - k * block_N + j < k_current_seqlen, 0, T.Cast(accum_dtype, -1e30)) + bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen, 0, T.Cast(accum_dtype, -1e30) + ) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, d in T.Parallel(block_N, dim_v): if k * block_N + i < k_current_seqlen: @@ -148,9 +139,11 @@ def flashattn_fwd(batch, @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v): dtype = "float16" accum_dtype = "float" @@ -159,10 +152,10 @@ def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v): @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(max_seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -201,9 +194,11 
@@ def make_dq_layout(dQ): @tilelang.jit( - out_idx=[3, 4, 5], pass_configs={ + out_idx=[3, 4, 5], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v): dtype = "float16" accum_dtype = "float" @@ -214,46 +209,39 @@ def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v): @T.prim_func def flash_bwd_post( - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(q_shape, dtype), # type: ignore - dK_out: T.Tensor(k_shape, dtype), # type: ignore - dV_out: T.Tensor(v_shape, dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(q_shape, dtype), # type: ignore + dK_out: T.Tensor(k_shape, dtype), # type: ignore + dV_out: T.Tensor(v_shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(total_q, blk), heads, threads=128) as (bx, by): T.annotate_layout({dQ: make_dq_layout(dQ)}) - T.copy(dQ[bx * blk:(bx + 1) * blk, by, :], dQ_out[bx * blk:(bx + 1) * blk, by, :]) + T.copy(dQ[bx * blk : (bx + 1) * blk, by, :], dQ_out[bx * blk : (bx + 1) * blk, by, :]) with T.Kernel(T.ceildiv(total_kv, blk), head_kv, threads=128) as (bx, by): - T.annotate_layout({ - dK: make_dq_layout(dK), - dV: make_dq_layout(dV), - }) - T.copy(dK[bx * blk:(bx + 1) * blk, by, :], dK_out[bx * blk:(bx + 1) * blk, by, :]) - T.copy(dV[bx * blk:(bx + 1) * blk, by, :], dV_out[bx * blk:(bx + 1) * blk, by, :]) + T.annotate_layout( + { + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + } + ) + T.copy(dK[bx * blk : (bx + 1) * blk, by, :], dK_out[bx * blk : (bx + 1) * blk, by, :]) + T.copy(dV[bx * blk : (bx + 1) * blk, by, :], dV_out[bx * blk : (bx + 1) * blk, by, :]) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_atomic_add(batch, - total_q, - total_kv, - N_CTX, - heads, - max_seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add( + batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1 +): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [total_q, heads, dim_qk] k_shape = [total_kv, head_kv, dim_qk] @@ -264,20 +252,19 @@ def flashattn_bwd_atomic_add(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor(do_shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # 
type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor(do_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore ): - with T.Kernel( - heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): + with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) dsT_shared = T.alloc_shared([block_M, block_N], dtype) q = T.alloc_shared([block_N, dim_qk], dtype) @@ -303,58 +290,54 @@ def flashattn_bwd_atomic_add(batch, q_current_seqlen = q_end_idx - q_start_idx k_current_seqlen = k_end_idx - k_start_idx - T.annotate_layout({ - dQ: make_dq_layout(dQ), - dK: make_dq_layout(dK), - dV: make_dq_layout(dV), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + } + ) - T.copy(K[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :], - K_shared) - T.copy(V[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :], - V_shared) + T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) - loop_st = T.min( - T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, - block_N)) if is_causal else 0 + loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0 loop_ed = T.ceildiv(q_current_seqlen, block_N) for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy( - Q[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :], - q) + T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k_base * block_N:(k_base + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else((by * block_M + i <= k_base * block_N + j) and - (by * block_M + i < k_current_seqlen and - k_base * block_N + j < q_current_seqlen), - qkT[i, j], 0) + qkT[i, j] = T.if_then_else( + (by * block_M + i <= k_base * block_N + j) + and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) else: for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.if_then_else( - by * block_M + i < k_current_seqlen and - k_base * block_N + j < q_current_seqlen, qkT[i, j], 0) + by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0 + ) - T.copy( - dO[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :], - do) + T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do) 
T.clear(dsT) # dsT: (block_kv, block_q) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k_base * block_N:(k_base + 1) * block_N], delta) + T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) @@ -364,49 +347,40 @@ def flashattn_bwd_atomic_add(batch, T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.copy(dq, dq_shared) T.atomic_add( - dQ[q_start_idx + k_base * block_N:q_start_idx + k_base * block_N + block_N, - bx, :], + dQ[q_start_idx + k_base * block_N : q_start_idx + k_base * block_N + block_N, bx, :], dq_shared, memory_order="relaxed", - use_tma=True) + use_tma=True, + ) T.copy(dv, dv_shared) T.atomic_add( - dV[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, - bx // groups, :], + dV[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :], dv_shared, memory_order="relaxed", - use_tma=True) + use_tma=True, + ) T.copy(dk, dk_shared) T.atomic_add( - dK[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, - bx // groups, :], + dK[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :], dk_shared, memory_order="relaxed", - use_tma=True) + use_tma=True, + ) return flash_bwd -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd_split(batch, - total_q, - total_kv, - N_CTX, - heads, - max_seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split( + batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1 +): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [total_q, heads, dim_qk] k_shape = [total_kv, head_kv, dim_qk] @@ -419,20 +393,19 @@ def flashattn_bwd_split(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor(do_shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(dk_shape, dtype), # type: ignore - dV: T.Tensor(dv_shape, dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor(do_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: ignore + dV: T.Tensor(dv_shape, 
dtype), # type: ignore ): - with T.Kernel( - heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): + with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) dsT_shared = T.alloc_shared([block_M, block_N], dtype) q = T.alloc_shared([block_N, dim_qk], dtype) @@ -457,59 +430,55 @@ def flashattn_bwd_split(batch, q_current_seqlen = q_end_idx - q_start_idx k_current_seqlen = k_end_idx - k_start_idx - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) - T.copy(K[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :], - K_shared) - T.copy(V[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :], - V_shared) + T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) - loop_st = T.min( - T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, - block_N)) if is_causal else 0 + loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0 loop_ed = T.ceildiv(q_current_seqlen, block_N) for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): # Note: The padding zero of varlen should be considered in T.copy - T.copy( - Q[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :], - q) + T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy( - dO[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :], - do) + T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k_base * block_N:(k_base + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else((by * block_M + i <= k_base * block_N + j) and - (by * block_M + i < k_current_seqlen and - k_base * block_N + j < q_current_seqlen), - qkT[i, j], 0) + qkT[i, j] = T.if_then_else( + (by * block_M + i <= k_base * block_N + j) + and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) else: for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.if_then_else( - by * block_M + i < k_current_seqlen and - k_base * block_N + j < q_current_seqlen, qkT[i, j], 0) + by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0 + ) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k_base * block_N:(k_base + 1) * block_N], delta) + T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * 
block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -520,62 +489,37 @@ def flashattn_bwd_split(batch, T.gemm(dsT_shared, K_shared, dq, transpose_A=True) for i, j in T.Parallel(block_N, dim_qk): if k_base * block_N + i < q_current_seqlen: - T.atomic_add( - dQ[q_start_idx + k_base * block_N + i, bx, j], - dq[i, j], - memory_order="relaxed") + T.atomic_add(dQ[q_start_idx + k_base * block_N + i, bx, j], dq[i, j], memory_order="relaxed") T.copy(dv, dv_shared) - T.copy( - dv_shared, - dV[bx % groups, k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, - bx // groups, :]) + T.copy(dv_shared, dV[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :]) T.copy(dk, dk_shared) - T.copy( - dk_shared, - dK[bx % groups, k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, - bx // groups, :]) + T.copy(dk_shared, dK[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :]) return flash_bwd @torch.compile class _attention(torch.autograd.Function): - @staticmethod - def forward(ctx, - q, - k, - v, - seqlens_q, - seqlens_k, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - causal, - groups=1, - use_atomic=True): + def forward( + ctx, q, k, v, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups=1, use_atomic=True + ): BATCH, N_CTX, H, D_HEAD_QK = q.shape D_HEAD_V = v.shape[-1] block_M = 128 block_N = 64 - q_unpad, indices_q, _, _ = unpad_input( - q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) - k_unpad, indices_k, _, _ = unpad_input( - k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) - v_unpad, _, _, _ = unpad_input( - v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) + q_unpad, indices_q, _, _ = unpad_input(q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) + k_unpad, indices_k, _, _ = unpad_input(k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) + v_unpad, _, _, _ = unpad_input(v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) total_q = q_unpad.shape[0] total_kv = k_unpad.shape[0] - mod = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, - causal, block_M, block_N, groups) + mod = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups) o_unpad, lse = mod(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k) o = pad_input(o_unpad, indices_q, BATCH, N_CTX) - ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, - cu_seqlens_q, cu_seqlens_k) + ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k) ctx.batch = BATCH ctx.causal = causal ctx.use_atomic = use_atomic @@ -590,8 +534,7 @@ class _attention(torch.autograd.Function): N_CTX = do.shape[1] q, k, v, o, lse_clone, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors # lse_clone = lse.clone() - do_unpad, _, _, _ = unpad_input( - do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) + do_unpad, _, _, _ = unpad_input(do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) total_q, H, D_HEAD_QK = q.shape total_kv, HEAD_KV, D_HEAD_V = v.shape groups = H // HEAD_KV @@ -624,7 +567,8 @@ class 
_attention(torch.autograd.Function): block_N, threads=256, num_stages=2, - groups=groups) + groups=groups, + ) dq = torch.zeros_like(q, dtype=torch.float32) dk = torch.zeros_like(k, dtype=torch.float32) dv = torch.zeros_like(v, dtype=torch.float32) @@ -645,13 +589,13 @@ class _attention(torch.autograd.Function): block_N, threads=256, num_stages=2, - groups=groups) + groups=groups, + ) dq = torch.zeros_like(q, dtype=torch.float32) dk = torch.empty(groups, *k.shape, dtype=torch.float16, device=q.device) dv = torch.empty(groups, *v.shape, dtype=torch.float16, device=q.device) kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv) - dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), - torch.zeros_like(v, dtype=torch.float32)) + dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32)) dk, dv = dk.sum(0), dv.sum(0) dq = pad_input(dq, ctx.indices_q, BATCH, N_CTX) @@ -670,15 +614,13 @@ def ref_program(Q, K, V, padding_mask, is_causal, groups=1): # HQ = HKV * groups # To handle precision issue Q, K, V = Q.float(), K.float(), V.float() - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim_qk = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) if padding_mask is not None: scores.masked_fill_(rearrange(~padding_mask, "b s -> b 1 1 s"), float("-inf")) @@ -686,41 +628,35 @@ def ref_program(Q, K, V, padding_mask, is_causal, groups=1): seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) if padding_mask is not None: output.masked_fill_(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0) return output -def main(BATCH: int = 1, - H: int = 32, - N_CTX: int = 256, - D_HEAD_QK: int = 192, - D_HEAD_V: int = 128, - groups: int = 16, - causal: bool = False, - use_atomic: bool = True): +def main( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 256, + D_HEAD_QK: int = 192, + D_HEAD_V: int = 128, + groups: int = 16, + causal: bool = False, + use_atomic: bool = True, +): flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V total_flops = 3 * flops_per_qk + 2 * flops_per_v if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() head_kv = H // groups - K = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - V = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, 
dtype=torch.half, - device="cuda").normal_().requires_grad_()) - dO = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() padding_mask = generate_random_padding_mask(N_CTX, BATCH, "cuda", mode="random") seqlens_q = padding_mask.sum(dim=-1, dtype=torch.int32) cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0)) @@ -729,8 +665,7 @@ def main(BATCH: int = 1, # In training backward pass, seqlens_k should be the same as seqlens_q seqlens_k, cu_seqlens_k, max_seqlen_k = seqlens_q, cu_seqlens_q, max_seqlen_q - O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, - max_seqlen_k, causal, groups, use_atomic) + O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups, use_atomic) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None dK, K.grad = K.grad.clone(), None @@ -772,17 +707,15 @@ if __name__ == "__main__": print(f"Detected GPU compute capability: {arch}") assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0" parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') - parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') - parser.add_argument('--causal', action='store_true', help='Causal flag') - parser.add_argument('--groups', type=int, default=16, help='groups') - parser.add_argument( - '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') - parser.add_argument( - '--use_split', action='store_true', default=False, help='Use split for dK/dV') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV") + parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV") args = parser.parse_args() # Can be set to True/False for testing args.causal = True @@ -796,5 +729,4 @@ if __name__ == "__main__": # Default: use atomic use_atomic = True - main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, - use_atomic) + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py 
b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py index e916812f522edeb203962e9d2c16cf9654aec257..adb7e06a8b1ce59250dee39187c9bc4d76328ebd 100644 --- a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py @@ -6,11 +6,13 @@ import argparse @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] @@ -20,11 +22,11 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc @T.prim_func def flash_fwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -40,25 +42,21 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc logsum = T.alloc_fragment([block_M], accum_dtype) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_M): @@ -76,18 +74,20 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim_v): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * 
block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): dtype = "float16" accum_dtype = "float" @@ -96,9 +96,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -107,32 +107,24 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim_v, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) -def flashattn_bwd(batch, - heads, - seq_len, - dim_qk, - dim_v, - is_causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=1): - sm_scale = (1.0 / dim_qk)**0.5 - scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] @@ -142,15 +134,15 @@ def flashattn_bwd(batch, @T.prim_func def flash_bwd( - Q: T.Tensor(q_shape, dtype), # type: ignore - K: T.Tensor(k_shape, dtype), # type: ignore - V: T.Tensor(v_shape, dtype), # type: ignore - dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(q_shape, accum_dtype), # type: ignore - dK: T.Tensor(k_shape, accum_dtype), # type: ignore - dV: T.Tensor(v_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: 
T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) @@ -171,45 +163,39 @@ def flashattn_bwd(batch, dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - }) - - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) - T.gemm( - K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) - T.gemm( - V_shared, - do, - dsT, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow, - wg_wait=-1) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) T.wait_wgmma(1) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -221,18 +207,17 @@ def flashattn_bwd(batch, T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) T.wait_wgmma(0) T.copy(dq, dq_shared) - T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared) T.copy(dv, dv_shared) - T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared) T.copy(dk, dk_shared) - T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], 
dk_shared) return flash_bwd @torch.compile class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): BATCH, N_CTX, H, D_HEAD_QK = q.shape @@ -250,7 +235,10 @@ class _attention(torch.autograd.Function): def backward(ctx, do): q, k, v, o, lse = ctx.saved_tensors BATCH, N_CTX, H, D_HEAD_QK = q.shape - HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1] + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] groups = H // HEAD_KV def maybe_contiguous(x): @@ -264,18 +252,7 @@ class _attention(torch.autograd.Function): mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) delta = mod_prep(o, do) - kernel = flashattn_bwd( - BATCH, - H, - N_CTX, - D_HEAD_QK, - D_HEAD_V, - ctx.causal, - block_M, - block_N, - threads=256, - num_stages=2, - groups=groups) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] @@ -298,52 +275,36 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D_QK] # V: [B, T, HV, D_V] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim_qk = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output -def main(BATCH: int = 1, - H: int = 32, - N_CTX: int = 256, - D_HEAD_QK: int = 192, - D_HEAD_V: int = 128, - groups: int = 16, - causal: bool = False): +def main(BATCH: int = 1, H: int = 32, N_CTX: int = 256, D_HEAD_QK: int = 192, D_HEAD_V: int = 128, groups: int = 16, causal: bool = False): flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V total_flops = 3 * flops_per_qk + 2 * flops_per_v if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() head_kv = H // groups - K = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - V = ( - torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) - dO = ( - torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, 
device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() O = attention(Q, K, V, causal, groups) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None @@ -360,7 +321,7 @@ def main(BATCH: int = 1, torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) - print('All checks passed.✅') + print("All checks passed.✅") def run(): O_ref.backward(dO, retain_graph=True) @@ -380,13 +341,13 @@ def main(BATCH: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') - parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') - parser.add_argument('--causal', action='store_true', help='Causal flag') - parser.add_argument('--groups', type=int, default=16, help='groups') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal) diff --git a/examples/flash_attention/example_gqa_fwd_bshd.py b/examples/flash_attention/example_gqa_fwd_bshd.py index a6d3b5f203b0a1510eb5d8c0cc48184cffa84175..408d6e50796dca7b5cc27897520e89a7f9390372 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd.py +++ b/examples/flash_attention/example_gqa_fwd_bshd.py @@ -9,7 +9,6 @@ from functools import partial class FlashAttentionTuneSpace: - def __init__( self, block_sizes=(64, 128, 256), @@ -40,7 +39,7 @@ def get_configs(user_config=None): warp_M = block_M // warp_count warp_N = block_N // warp_count - if (warp_M % config.warp_alignment != 0 or warp_N % config.warp_alignment != 0): + if warp_M % config.warp_alignment != 0 or warp_N % config.warp_alignment != 0: continue shared_mem = 2 * config.dtype_bytes * config.dim * (block_M + block_N) @@ -48,31 +47,26 @@ def get_configs(user_config=None): continue for num_stages in config.num_stages_range: - valid_configs.append({ - "block_M": block_M, - "block_N": block_N, - "num_stages": num_stages, - "threads": threads, - }) + valid_configs.append( + { + "block_M": block_M, + "block_N": block_N, + "num_stages": num_stages, + "threads": threads, + } + ) return valid_configs @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_len, - dim, - is_causal, - groups=1, - block_M=64, - block_N=64, - num_stages=0, - threads=128): - scale = 
(1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, groups=1, block_M=64, block_N=64, num_stages=0, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] @@ -90,15 +84,13 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), - 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @T.macro @@ -111,18 +103,18 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -148,18 +140,18 @@ def flashattn(batch, @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -175,25 +167,24 @@ def flashattn(batch, scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, 
block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main @@ -203,50 +194,34 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D] # V: [B, T, HV, D] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output -def main(batch: int = 1, - heads: int = 64, - seq_len: int = 4096, - dim: int = 128, - is_causal: bool = False, - groups: int = 16, - tune: bool = False): +def main( + batch: int = 1, heads: int = 64, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 16, tune: bool = False +): flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim total_flops = 2 * flops_per_matmul if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_len, - dim, - is_causal, - groups=groups, - block_M=64, - block_N=64, - num_stages=2, - threads=128) + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=64, block_N=64, num_stages=2, threads=128) ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -270,12 +245,12 @@ def main(batch: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - 
parser.add_argument('--tune', action='store_true', help='tune configs') - parser.add_argument('--groups', type=int, default=16, help='groups') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + parser.add_argument("--groups", type=int, default=16, help="groups") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune) diff --git a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py index 03ad15e9411bc26a6681754616e763c9cba6f511..3492be7646cf29e4c7b089913f934fe5f0d442b6 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py @@ -24,9 +24,11 @@ def get_configs(): rep=10, ) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn( batch, heads, @@ -39,7 +41,7 @@ def flashattn( num_stages=0, threads=128, ): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] @@ -57,15 +59,13 @@ def flashattn( by: T.int32, bz: T.int32, ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), - 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @T.macro @@ -78,18 +78,18 @@ def flashattn( by: T.int32, bz: T.int32, ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): 
T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -115,18 +115,18 @@ def flashattn( @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -142,30 +142,30 @@ def flashattn( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined( - loop_range, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main @@ -175,23 +175,21 @@ def ref_program(Q, K, V, is_causal, groups=1): # K: [B, T, HK, D] # V: [B, T, HV, D] # HQ = HKV * groups - assert Q.size(2) == K.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" - assert Q.size(2) == V.size( - 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" dim = Q.size(-1) K = K.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - 
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -209,18 +207,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_len, - dim, - is_causal, - groups=groups, - block_M=128, - block_N=128, - num_stages=2, - threads=256) + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -244,12 +232,12 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') - parser.add_argument('--groups', type=int, default=16, help='groups') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + parser.add_argument("--groups", type=int, default=16, help="groups") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune) diff --git a/examples/flash_attention/example_gqa_fwd_varlen.py b/examples/flash_attention/example_gqa_fwd_varlen.py index ccc50e413f6acd92f91f85f6f7d0b4474d755a44..87b11f71bebd119a0971ce68a88637ccd904d792 100644 --- a/examples/flash_attention/example_gqa_fwd_varlen.py +++ b/examples/flash_attention/example_gqa_fwd_varlen.py @@ -10,14 +10,14 @@ from varlen_utils import generate_random_padding_mask, generate_qkv def attention_ref( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - causal=False, - window_size=(-1, -1), - upcast=True, + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + causal=False, + window_size=(-1, -1), + upcast=True, ): if causal: window_size = (window_size[0], 0) @@ -26,7 +26,7 @@ def attention_ref( q, k, v = q.float(), k.float(), v.float() b, T, Hq, D = q.shape S = k.shape[1] - scale = (1.0 / D)**0.5 + scale = (1.0 / D) ** 0.5 k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2]) v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2]) scores = torch.einsum("bthd,bshd->bhts", q, k) @@ -54,21 +54,13 @@ def attention_ref( @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch_size, - groups, - UQ, - UKV, - heads, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=1, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch_size, groups, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 
0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [UQ, heads, dim] kv_shape = [UKV, head_kv, dim] @@ -78,17 +70,15 @@ def flashattn(batch_size, @T.prim_func def main( - Q_unpad: T.Tensor(q_shape, dtype), - K_unpad: T.Tensor(kv_shape, dtype), - V_unpad: T.Tensor(kv_shape, dtype), - cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), - cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), - max_seqlen_q: T.int32, - Output_unpad: T.Tensor(o_shape, dtype), + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), + cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), ): - with T.Kernel( - T.ceildiv(max_seqlen_q, block_M), heads, batch_size, - threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -102,10 +92,12 @@ def flashattn(batch_size, scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - }) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + } + ) batch_idx = bz head_idx = by @@ -119,36 +111,34 @@ def flashattn(batch_size, q_current_seqlen = q_end_idx - q_start_idx kv_current_seqlen = k_end_idx - kv_start_idx - T.copy( - Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :], - Q_shared) + T.copy(Q_unpad[q_start_idx + bx * block_M : q_start_idx + (bx + 1) * block_M, head_idx, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min( - T.ceildiv(q_current_seqlen + - (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N)) - if is_causal else T.ceildiv(kv_current_seqlen, block_N)) + T.min(T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N)) + if is_causal + else T.ceildiv(kv_current_seqlen, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): - T.copy( - K_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N, - kv_head_idx, :], K_shared) + T.copy(K_unpad[kv_start_idx + k * block_N : kv_start_idx + (k + 1) * block_N, kv_head_idx, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, - j] = T.if_then_else((bx * block_M + i < k * block_N + j) or - (bx * block_M + i >= q_current_seqlen or - k * block_N + j >= kv_current_seqlen), -1e9, 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i < k * block_N + j) + or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), + -1e9, + 0, + ) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or - k * block_N + j >= kv_current_seqlen), -1e9, - 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), -1e9, 0 + ) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -170,9 +160,7 @@ def flashattn(batch_size, for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] - T.copy( - 
V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N, - kv_head_idx, :], V_shared) + T.copy(V_unpad[kv_start_idx + k * block_N : kv_start_idx + (k + 1) * block_N, kv_head_idx, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @@ -187,13 +175,9 @@ def flashattn(batch_size, return main -def main(batch: int = 1, - heads: int = 64, - q_seqlen: int = 2048, - k_seqlen: int = 2048, - dim: int = 128, - groups: int = 16, - is_causal: bool = False): +def main( + batch: int = 1, heads: int = 64, q_seqlen: int = 2048, k_seqlen: int = 2048, dim: int = 128, groups: int = 16, is_causal: bool = False +): assert heads % groups == 0, "heads must be divisible by groups" flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim @@ -231,24 +215,12 @@ def main(batch: int = 1, output_pad_fn, _, _, - ) = generate_qkv( - q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) UQ = q_unpad.shape[0] UKV = k_unpad.shape[0] - kernel = flashattn( - batch, - groups, - UQ, - UKV, - heads, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256) + kernel = flashattn(batch, groups, UQ, UKV, heads, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) out = output_pad_fn(out_unpad) @@ -263,23 +235,19 @@ def main(batch: int = 1, ) torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2) print("All checks passed.✅") - latency = do_bench( - lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q), - _n_warmup=5, - _n_repeat=5) + latency = do_bench(lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q), _n_warmup=5, _n_repeat=5) print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='query heads') - parser.add_argument('--groups', type=int, default=16, help='groups') - parser.add_argument('--q_seqlen', type=int, default=2048, help='query sequence length') - parser.add_argument('--k_seqlen', type=int, default=2048, help='key/value sequence length') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument('--is_causal', action='store_true', help='causal attention') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="query heads") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--q_seqlen", type=int, default=2048, help="query sequence length") + parser.add_argument("--k_seqlen", type=int, default=2048, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=128, help="head dim") + parser.add_argument("--is_causal", action="store_true", help="causal attention") args = parser.parse_args() - main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, - args.is_causal) + main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, args.is_causal) diff --git a/examples/flash_attention/example_mha_bwd_bhsd.py b/examples/flash_attention/example_mha_bwd_bhsd.py index 
d91d1770fda5ecd48ae930e8cd9f17ba52d6fd91..81eb6d1e5cb2551c90200aedba2d2c7b0bf5b72d 100644 --- a/examples/flash_attention/example_mha_bwd_bhsd.py +++ b/examples/flash_attention/example_mha_bwd_bhsd.py @@ -7,22 +7,24 @@ import argparse @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] dtype = "float16" accum_dtype = "float" @T.prim_func def flash_fwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - Output: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -39,28 +41,24 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): logsum = T.alloc_fragment([block_M], accum_dtype) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) # T.copy(Q_shared, Q_local) # for i, j in T.Parallel(block_M, dim): # Q_local[i, j] *= scale - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_M): @@ -78,18 +76,20 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, 
lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): dtype = "float16" accum_dtype = "float" @@ -98,9 +98,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -109,26 +109,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do) + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, seq_len, dim): dtype = "float16" accum_dtype = "float" @@ -137,40 +138,42 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim): @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, by, bx * blk:(bx + 1) * blk, :], - dQ_out[bz, by, bx * blk:(bx + 1) * blk, :], + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - sm_scale = (1.0 / dim)**0.5 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] dtype = "float16" accum_dtype = "float" @T.prim_func def flash_bwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: 
T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dK: T.Tensor(shape, dtype), # type: ignore - dV: T.Tensor(shape, dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -194,38 +197,39 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): dv_shared = T.alloc_shared([block_M, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - }) - T.copy(K[bz, bx, by * block_M:(by + 1) * block_M, :], K_shared) - T.copy(V[bz, bx, by * block_M:(by + 1) * block_M, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + T.copy(K[bz, bx, by * block_M : (by + 1) * block_M, :], K_shared) + T.copy(V[bz, bx, by * block_M : (by + 1) * block_M, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=2): - T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) # We don't need to handle OOB positions for non-causal cases, # since OOB values won't affect other positions here. 
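The backward loop above recomputes the attention probabilities instead of storing them: qkT is rebuilt as exp2(qkT * scale - lse), where lse is the base-2 log-sum-exp saved by flash_fwd and scale folds log2(e) into the 1/sqrt(dim) factor. A minimal standalone PyTorch sketch of that identity (sizes and names here are illustrative assumptions, not taken from the kernels):

import torch

# Illustrative check of the base-2 softmax recomputation used in flash_bwd.
# `scale` folds log2(e) into the 1/sqrt(dim) factor, exactly as the kernels do,
# so exp2(s * scale - lse) reproduces softmax(s / sqrt(dim)) row by row.
torch.manual_seed(0)
dim = 64
sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504  # log2(e)

s = torch.randn(128, 128, dtype=torch.float64)  # stand-in for Q @ K^T scores

# Forward-style statistics: row-wise max and base-2 log-sum-exp, as stored in `lse`.
row_max = (s * scale).amax(dim=-1, keepdim=True)
lse = torch.log2(torch.exp2(s * scale - row_max).sum(dim=-1, keepdim=True)) + row_max

# Backward-style recomputation: exp2(s * scale - lse) equals softmax(s * sm_scale).
p_recomputed = torch.exp2(s * scale - lse)
p_reference = torch.softmax(s * sm_scale, dim=-1)
assert torch.allclose(p_recomputed, p_reference)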
- T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], do) + T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -238,14 +242,13 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j]) T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :]) - T.copy(dk_shared, dK[bz, bx, by * block_M:(by + 1) * block_M, :]) + T.copy(dv_shared, dV[bz, bx, by * block_M : (by + 1) * block_M, :]) + T.copy(dk_shared, dK[bz, bx, by * block_M : (by + 1) * block_M, :]) return flash_bwd class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal): BATCH, H, N_CTX, D_HEAD = q.shape @@ -287,15 +290,15 @@ attention = _attention.apply def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(2) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) return output @@ -310,9 +313,7 @@ def main( total_flops = 5 * flops_per_matmul if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() K = torch.empty_like(Q).normal_().requires_grad_() V = torch.empty_like(Q).normal_().requires_grad_() dO = torch.randn_like(Q) @@ -353,10 +354,10 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head', type=int, default=64, help='Head dimension') - parser.add_argument('--causal', type=bool, default=False, help='Causal flag') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/examples/flash_attention/example_mha_bwd_bshd.py b/examples/flash_attention/example_mha_bwd_bshd.py index 7c85f982e4e0e68a9909bd843096b83b999e7611..427a0f694a91bbe3148668fe988f530f88731f5c 100644 --- a/examples/flash_attention/example_mha_bwd_bshd.py +++ 
b/examples/flash_attention/example_mha_bwd_bshd.py @@ -7,22 +7,24 @@ import argparse @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] dtype = "float16" accum_dtype = "float" @T.prim_func def flash_fwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - Output: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -38,25 +40,21 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_M): @@ -74,18 +72,20 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): dtype = 
"float16" accum_dtype = "float" @@ -94,9 +94,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -105,26 +105,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep def make_dq_layout(dQ): # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment - return T.Layout(dQ.shape, - lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) @tilelang.jit( - out_idx=[1], pass_configs={ + out_idx=[1], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_bwd_postprocess(batch, heads, seq_len, dim): dtype = "float16" accum_dtype = "float" @@ -133,40 +134,42 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim): @T.prim_func def flash_bwd_post( - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dQ_out: T.Tensor(shape, dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): T.annotate_layout({dQ: make_dq_layout(dQ)}) T.copy( - dQ[bz, bx * blk:(bx + 1) * blk, by, :], - dQ_out[bz, bx * blk:(bx + 1) * blk, by, :], + dQ[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], ) return flash_bwd_post -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - sm_scale = (1.0 / dim)**0.5 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] dtype = "float16" accum_dtype = "float" @T.prim_func def flash_bwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dK: T.Tensor(shape, dtype), # type: ignore - dV: T.Tensor(shape, dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, 
dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -190,35 +193,36 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): dv_shared = T.alloc_shared([block_M, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], dtype) - T.annotate_layout({ - dQ: make_dq_layout(dQ), - }) - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared) + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=2): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) # We don't need to handle OOB positions for non-causal cases, # since OOB values won't affect other positions here. 
- T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -231,14 +235,13 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :]) - T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :]) + T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :]) + T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :]) return flash_bwd class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal): BATCH, N_CTX, H, D_HEAD = q.shape @@ -280,15 +283,15 @@ attention = _attention.apply def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -303,9 +306,7 @@ def main( total_flops = 5 * flops_per_matmul if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() K = torch.empty_like(Q).normal_().requires_grad_() V = torch.empty_like(Q).normal_().requires_grad_() dO = torch.randn_like(Q) @@ -344,10 +345,10 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head', type=int, default=64, help='Head dimension') - parser.add_argument('--causal', type=bool, default=False, help='Causal flag') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py index e8ee5d97393b00465a6d81c43fb6de2b0bcedfd1..813f379ca0afa3b349c2f91a5cc03d356e19d657 100644 --- a/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py 
+++ b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py @@ -7,22 +7,24 @@ import argparse @tilelang.jit( - out_idx=[3, 4], pass_configs={ + out_idx=[3, 4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] dtype = "float16" accum_dtype = "float" @T.prim_func def flash_fwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - Output: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -38,26 +40,22 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): logsum = T.alloc_fragment([block_M], accum_dtype) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - loop_range = ( - T.ceildiv( - (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_M): @@ -75,18 +73,20 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) return flash_fwd @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def 
flashattn_bwd_preprocess(batch, heads, seq_len, dim): dtype = "float16" accum_dtype = "float" @@ -95,9 +95,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): @T.prim_func def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): o = T.alloc_fragment([blk, blk], dtype) @@ -106,37 +106,39 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): delta = T.alloc_fragment([blk], accum_dtype) T.clear(acc) for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) - T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) for i, j in T.Parallel(blk, blk): acc[i, j] += o[i, j] * do[i, j] T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) return flash_bwd_prep -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, -}) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - sm_scale = (1.0 / dim)**0.5 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] dtype = "float16" accum_dtype = "float" @T.prim_func def flash_bwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dK: T.Tensor(shape, dtype), # type: ignore - dV: T.Tensor(shape, dtype), # type: ignore + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) @@ -161,49 +163,43 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): dk_shared = T.alloc_shared([block_M, dim], dtype) dq_shared = T.alloc_shared([block_N, dim], accum_dtype) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), - }) - - T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared) - T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared) + T.annotate_layout( + 
{ + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared) T.clear(dv) T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_st, loop_ed, num_stages=2): - T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) T.clear(qkT) - T.gemm( - K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) T.clear(dsT) - T.gemm( - V_shared, - do, - dsT, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow, - wg_wait=-1) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) T.wait_wgmma(1) - T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) if is_causal: for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], - 0) + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) # We don't need to handle OOB positions for non-causal cases, # since OOB values won't affect other positions here. 
T.wait_wgmma(0) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) for i, j in T.Parallel(block_M, block_N): dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale @@ -214,17 +210,16 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) T.wait_wgmma(0) T.copy(dq, dq_shared) - T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared) T.copy(dv, dv_shared) T.copy(dk, dk_shared) - T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :]) - T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :]) + T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :]) + T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :]) return flash_bwd class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal): BATCH, N_CTX, H, D_HEAD = q.shape @@ -266,15 +261,15 @@ attention = _attention.apply def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -289,9 +284,7 @@ def main( total_flops = 5 * flops_per_matmul if causal: total_flops *= 0.5 - Q = ( - torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, - device="cuda").normal_().requires_grad_()) + Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() K = torch.empty_like(Q).normal_().requires_grad_() V = torch.empty_like(Q).normal_().requires_grad_() dO = torch.randn_like(Q) @@ -311,7 +304,7 @@ def main( assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) - print('All checks passed.✅') + print("All checks passed.✅") def run(): O_ref.backward(dO, retain_graph=True) @@ -329,10 +322,10 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='Batch size') - parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') - parser.add_argument('--d_head', type=int, default=64, help='Head dimension') - parser.add_argument('--causal', type=bool, default=False, help='Causal flag') + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, 
args.causal) diff --git a/examples/flash_attention/example_mha_fwd_bhsd.py b/examples/flash_attention/example_mha_fwd_bhsd.py index e0e0bca22e2f20517bedf637d4e229ebdef6b6aa..7fa5549d0e96f8335e20014bdbd1ffb590aff9f8 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd.py +++ b/examples/flash_attention/example_mha_fwd_bhsd.py @@ -15,20 +15,13 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=1, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] dtype = "float16" @@ -48,7 +41,7 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i + past_len @@ -70,18 +63,18 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -110,18 +103,18 @@ def flashattn(batch, @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -137,43 +130,42 @@ def flashattn(batch, scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) 
T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv( - (bx + 1) * block_M + - past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + if is_causal + else T.ceildiv(seq_kv, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_q = Q.size(2) seq_kv = K.size(2) mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) return output @@ -191,18 +183,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=1, - threads=128) + if not tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler() @@ -227,12 +209,12 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=1, help='heads') - parser.add_argument('--seq_q', type=int, default=256, help='query sequence length') - parser.add_argument('--seq_kv', type=int, default=256, help='key/value sequence length') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal', default=False) - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=1, help="heads") + parser.add_argument("--seq_q", type=int, default=256, help="query sequence length") + parser.add_argument("--seq_kv", type=int, default=256, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal", default=False) + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py 
b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py index b797bbcc6401f992889566cd1e840dd61ac4908d..440a2cd74d5da17dc0c109e9954bb640b10e7529 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py @@ -15,20 +15,13 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] dtype = "float16" @@ -48,7 +41,7 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i + past_len @@ -70,18 +63,18 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -108,18 +101,18 @@ def flashattn(batch, @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -135,48 +128,48 @@ def flashattn(batch, scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) 
T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv( - (bx + 1) * block_M + - past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + if is_causal + else T.ceildiv(seq_kv, block_N) + ) for k in T.Pipelined( - loop_range, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_q = Q.size(2) seq_kv = K.size(2) mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) return output @@ -194,18 +187,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256) + if not tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler() @@ -230,12 +213,12 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_q', type=int, default=4096, help='query sequence length') - parser.add_argument('--seq_kv', type=int, default=4096, help='key/value sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="query sequence length") + parser.add_argument("--seq_kv", type=int, default=4096, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", 
action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_bshd.py b/examples/flash_attention/example_mha_fwd_bshd.py index b5b7282876478b428933ed798342b14b050e4265..888914c9bb86310c8471861c6f73ae24b32d9d37 100644 --- a/examples/flash_attention/example_mha_fwd_bshd.py +++ b/examples/flash_attention/example_mha_fwd_bshd.py @@ -15,19 +15,13 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_len, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=1, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] dtype = "float16" accum_dtype = "float" @@ -43,16 +37,14 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: # We shall fill -inf for OOB positions for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), - 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @T.macro @@ -65,18 +57,18 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -102,18 +94,18 @@ def flashattn(batch, @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - 
V: T.Tensor(shape, dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -129,40 +121,39 @@ def flashattn(batch, scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -179,17 +170,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_len, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=1, - threads=128) + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=1, threads=128) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler() profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -213,11 +195,11 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, 
help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py index 02d8baef28589ea6fc3904c9d629296b1c77895d..b54d3e626fc2699193f6439231e6d9c8fc174ee4 100644 --- a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py @@ -15,19 +15,13 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_len, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] dtype = "float16" accum_dtype = "float" @@ -43,16 +37,14 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: # We shall fill -inf for OOB positions for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), - 0) + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @T.macro @@ -65,18 +57,18 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -102,18 +94,18 @@ def flashattn(batch, @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: 
T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -129,45 +121,45 @@ def flashattn(batch, scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined( - loop_range, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_len = Q.size(1) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -184,17 +176,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_len, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256) + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -218,11 +201,11 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - 
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_varlen.py b/examples/flash_attention/example_mha_fwd_varlen.py index bbb4546ca94b89184b630733cdb65e40aca3968c..f7bb36f71d1d3d800b948686bb696d09e62fa8f6 100644 --- a/examples/flash_attention/example_mha_fwd_varlen.py +++ b/examples/flash_attention/example_mha_fwd_varlen.py @@ -11,14 +11,14 @@ from varlen_utils import generate_random_padding_mask, generate_qkv def attention_ref( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - causal=False, - window_size=(-1, -1), # -1 means infinite window size - upcast=True, + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + upcast=True, ): """ Arguments: @@ -47,7 +47,7 @@ def attention_ref( if upcast: q, k, v = q.float(), k.float(), v.float() dim = q.shape[-1] - scale = (1.0 / dim)**0.5 # log2(e) + scale = (1.0 / dim) ** 0.5 # log2(e) k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) scores = torch.einsum("bthd,bshd->bhts", q, k) @@ -68,20 +68,13 @@ def attention_ref( @tilelang.jit( - out_idx=[6], pass_configs={ + out_idx=[6], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch_size, - UQ, - UKV, - heads, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=0, - threads=32): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch_size, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=32): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [UQ, heads, dim] k_shape = [UKV, heads, dim] v_shape = [UKV, heads, dim] @@ -92,17 +85,15 @@ def flashattn(batch_size, @T.prim_func def main( - Q_unpad: T.Tensor(q_shape, dtype), - K_unpad: T.Tensor(k_shape, dtype), - V_unpad: T.Tensor(v_shape, dtype), - cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), - cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), - max_seqlen_q: T.int32, - Output_unpad: T.Tensor(o_shape, dtype), + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(k_shape, dtype), + V_unpad: T.Tensor(v_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), + cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), ): - with T.Kernel( - T.ceildiv(max_seqlen_q, block_M), heads, batch_size, - threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype, "shared") K_shared = T.alloc_shared([block_N, dim], dtype, "shared") V_shared = 
T.alloc_shared([block_N, dim], dtype, "shared") @@ -151,15 +142,17 @@ def flashattn(batch_size, K_shared[i, d] = 0 if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and - (bx * block_M + i >= q_current_seqlen or - k * block_N + j >= k_current_seqlen), - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= k * block_N + j) + and (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) else: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or - k * block_N + j >= k_current_seqlen), - -T.infinity(acc_s.dtype), 0) + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), -T.infinity(acc_s.dtype), 0 + ) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -244,8 +237,7 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): output_pad_fn, dq_pad_fn, dk_pad_fn, - ) = generate_qkv( - q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) UQ = q_unpad.shape[0] # unpadded query length UK = k_unpad.shape[0] # unpadded key length @@ -287,10 +279,10 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=64, help='heads') - parser.add_argument('--seq_len', type=int, default=2048, help='sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=2048, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") args = parser.parse_args() main(args.batch, args.heads, args.seq_len, args.dim) diff --git a/examples/flash_attention/test_example_flash_attention.py b/examples/flash_attention/test_example_flash_attention.py index b184fc601bdfa33d3685c51a2cf8fc9816ceae9a..da172bb62a4dee0ada293fc1249812905ded8de1 100644 --- a/examples/flash_attention/test_example_flash_attention.py +++ b/examples/flash_attention/test_example_flash_attention.py @@ -62,14 +62,12 @@ def test_example_mha_bwd_wgmma_pipelined(): @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_gqa_fwd_bshd_wgmma_pipelined(): - example_gqa_fwd_bshd_wgmma_pipelined.main( - batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) + example_gqa_fwd_bshd_wgmma_pipelined.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) @tilelang.testing.requires_cuda def test_example_gqa_fwd_bshd(): - example_gqa_fwd_bshd.main( - batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) + example_gqa_fwd_bshd.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) @tilelang.testing.requires_cuda diff --git a/examples/flash_attention/varlen_utils.py b/examples/flash_attention/varlen_utils.py index 4301215d554b23f18c241391a58886e452d7531a..43e21cc3b80ce72eaa582407024ec2c42015731e 100644 --- 
a/examples/flash_attention/varlen_utils.py +++ b/examples/flash_attention/varlen_utils.py @@ -9,22 +9,14 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): if mode == "full": lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) elif mode == "random": - lengths = torch.randint( - max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) elif mode == "third": lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) - padding_mask = ( - repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths) + padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths return padding_mask -def generate_qkv(q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - kvpacked=False, - qkvpacked=False): +def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): """ Arguments: q: (batch_size, seqlen_q, nheads, d) @@ -39,15 +31,12 @@ def generate_qkv(q, if query_padding_mask is not None: q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) - output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q - ) + output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q) else: q_unpad = rearrange(q, "b s h d -> (b s) h d") - cu_seqlens_q = torch.arange( - 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device) + cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device) max_seqlen_q = seqlen_q - output_pad_fn = lambda output_unpad: rearrange( - output_unpad, "(b s) h d -> b s h d", b=batch_size) + output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size) if key_padding_mask is not None: k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) @@ -55,8 +44,7 @@ def generate_qkv(q, else: k_unpad = rearrange(k, "b s h d -> (b s) h d") v_unpad = rearrange(v, "b s h d -> (b s) h d") - cu_seqlens_k = torch.arange( - 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device) + cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device) max_seqlen_k = seqlen_k if qkvpacked: @@ -67,8 +55,7 @@ def generate_qkv(q, if query_padding_mask is not None: dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) else: - dqkv_pad_fn = lambda dqkv_unpad: rearrange( - dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + dqkv_pad_fn = lambda dqkv_unpad: rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) return ( qkv_unpad.detach().requires_grad_(), cu_seqlens_q, @@ -84,8 +71,7 @@ def generate_qkv(q, if key_padding_mask is not None: dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) else: - dkv_pad_fn = lambda dkv_unpad: rearrange( - dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + dkv_pad_fn = lambda dkv_unpad: rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) return ( q_unpad.detach().requires_grad_(), kv_unpad.detach().requires_grad_(), diff --git a/examples/flash_decoding/example_gqa_decode.py b/examples/flash_decoding/example_gqa_decode.py index 
7ccd98397f5192f6ba13490cf3c038b284f26d22..136a5129215cd9be70d660a3ed632ccf2749785c 100644 --- a/examples/flash_decoding/example_gqa_decode.py +++ b/examples/flash_decoding/example_gqa_decode.py @@ -20,13 +20,7 @@ def get_configs(): threads = [128] _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) - configs = [{ - 'block_N': c[0], - 'block_H': c[1], - 'num_split': c[2], - 'num_stages': c[3], - 'threads': c[4] - } for c in _configs] + configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] return configs @@ -48,17 +42,13 @@ def get_heuristic_config() -> Tuple[Dict, int]: # TODO(lei): fix warp specialized and tma lower pass def get_pass_configs(): - return { - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - } + return {tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[6], pass_configs=get_pass_configs()) -def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, - threads): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) +def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, threads): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape_q = [batch, heads, dim] shape_k = [batch, seqlen_kv, groups, dim] shape_v = [batch, seqlen_kv, groups, dim] @@ -73,11 +63,11 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, @T.macro def flash_attn( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -98,20 +88,19 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, hid = by cur_kv_head = hid // (kv_group_num // valid_block_H) - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = T.ceildiv((seqlen_kv // num_split), block_N) for k in T.Pipelined(loop_range, num_stages=num_stages): - T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared) - T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local) + T.copy(K[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_shared) + T.copy(mask[bid, k * block_N : (k + 1) * block_N, cur_kv_head], mask_local) T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], - -T.infinity(accum_dtype)) + acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], -T.infinity(accum_dtype)) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -127,23 +116,23 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, 
block_N, block_H, num_split, T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] - T.copy(V[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], V_shared) + T.copy(V[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale T.copy(acc_o[:valid_block_H, :], O_shared) - T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :]) + T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) @T.macro def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), ): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -165,7 +154,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, sid = bz cur_kv_head = hid // (kv_group_num // valid_block_H) - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -174,19 +163,26 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy( - K[bid, (seqlen_kv // num_split) * sid + - k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, - cur_kv_head, :], K_shared) + K[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, + :, + ], + K_shared, + ) T.copy( - mask[bid, (seqlen_kv // num_split) * sid + - k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, - cur_kv_head], mask_local) + mask[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, + ], + mask_local, + ) T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): - acc_s[i, - j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split), - acc_s[i, j], -T.infinity(accum_dtype)) + acc_s[i, j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split), acc_s[i, j], -T.infinity(accum_dtype)) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -203,9 +199,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] T.copy( - V[bid, (seqlen_kv // num_split) * sid + - k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, - cur_kv_head, :], V_shared) + V[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * 
valid_block_N, + cur_kv_head, + :, + ], + V_shared, + ) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, dim): acc_o[i, j] /= logsum[i] @@ -216,14 +217,13 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, if i < valid_block_H: glse[bid, hid * valid_block_H + i, sid] = logsum[i] T.copy(acc_o[:valid_block_H, :], O_shared) - T.copy(O_shared, Output_partial[bid, hid * valid_block_H:(hid + 1) * valid_block_H, - sid, :]) + T.copy(O_shared, Output_partial[bid, hid * valid_block_H : (hid + 1) * valid_block_H, sid, :]) @T.macro def combine( - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_o, dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), ): with T.Kernel(heads, batch, threads=128) as (by, bz): po_local = T.alloc_fragment([dim], dtype) @@ -233,12 +233,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, lse_max_local = T.alloc_fragment([128], accum_dtype) scale_local = T.alloc_fragment([128], accum_dtype) - T.annotate_layout({ - lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i), - # lse_local: (local_id, thread_id) - lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)), - }) + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i), + # lse_local: (local_id, thread_id) + lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)), + } + ) T.clear(lse_logsum_local) T.clear(o_accum_local) @@ -263,26 +265,26 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, @T.prim_func def flashattn_gqa_decode_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_o, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), ): flash_attn_split(Q, K, V, mask, glse, Output_partial) combine(glse, Output_partial, Output) @T.prim_func def flashattn_gqa_decode_no_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_o, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), ): flash_attn(Q, K, V, mask, Output) @@ -305,27 +307,21 @@ def ref_program(query, key, value, mask, glse, Output_partial): dim = query.shape[-1] num_head_groups = query.shape[1] // key.shape[2] scale = dim**0.5 - key = rearrange(key, 'b n h d -> b h n d') # 
[batch_size, groups, seqlen_kv, dim] - value = rearrange(value, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] + key = rearrange(key, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - query = rearrange( - query, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] if mask is not None: - mask = rearrange(mask, 'b s h -> b h s') + mask = rearrange(mask, "b s h -> b h s") mask = mask.unsqueeze(1) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, value, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -339,16 +335,12 @@ def flash_split_ref(Q, K, V, mask): seqlen_kv = K.size(1) num_head_groups = nheads // groups - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) acc_s = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float) - acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), - device="cuda", - dtype=torch.float16) + acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float16) acc_o = torch.empty((batch, num_head_groups, groups, dim), device="cuda", dtype=torch.float) scores_max = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) - scores_max_prev = torch.empty((batch, num_head_groups, groups), - device="cuda", - dtype=torch.float) + scores_max_prev = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) scores_scale = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) scores_sum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) logsum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) @@ -356,25 +348,25 @@ def flash_split_ref(Q, K, V, mask): glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float) Q_ = Q * scale - Q_ = rearrange(Q_, 'b (h g) d -> b g h d', g=num_head_groups) + Q_ = rearrange(Q_, "b (h g) d -> b g h d", g=num_head_groups) for ks in range(num_split): acc_o.fill_(0) logsum.fill_(0) - scores_max.fill_(float('-inf')) - scores_max_prev.fill_(float('-inf')) + scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) for i in range(int((seqlen_kv // num_split) / block_N)): acc_s.fill_(0) - acc_s = torch.einsum('bghd,bkhd->bghk', Q_, - K[:, (seqlen_kv // num_split) * ks + - i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) # [batch, nheads, block_N] + acc_s = torch.einsum( 
+ "bghd,bkhd->bghk", + Q_, + K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, nheads, block_N] if mask is not None: - mask_local = mask[:, (seqlen_kv // num_split) * ks + - i * block_N:(seqlen_kv // num_split) * ks + (i + 1) * block_N, :] - mask_local = rearrange(mask_local, 'b s h -> b h s') + mask_local = mask[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :] + mask_local = rearrange(mask_local, "b s h -> b h s") mask_local = mask_local.unsqueeze(1) - acc_s = acc_s.masked_fill(mask_local == 0, float('-inf')) + acc_s = acc_s.masked_fill(mask_local == 0, float("-inf")) scores_max_prev = scores_max scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads] scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads] @@ -382,15 +374,16 @@ def flash_split_ref(Q, K, V, mask): acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N] acc_o += torch.einsum( - 'bghk,bkhd->bghd', acc_s_cast, - V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) + "bghk,bkhd->bghd", + acc_s_cast, + V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) scores_sum = acc_s.sum(dim=-1, keepdim=False) logsum = logsum * scores_scale + scores_sum - acc_o_out = rearrange(acc_o, 'b g h d->b (h g) d') - logsum_out = rearrange(logsum, 'b g h->b (h g)') + acc_o_out = rearrange(acc_o, "b g h d->b (h g) d") + logsum_out = rearrange(logsum, "b g h->b (h g)") acc_o_out /= logsum_out[:, :, None] - logsum_out = torch.log2(logsum_out) + rearrange(scores_max, 'b g h->b (h g)') + logsum_out = torch.log2(logsum_out) + rearrange(scores_max, "b g h->b (h g)") gacc_o[ks, :, :, :] = acc_o_out glogsum[ks, :, :] = logsum_out @@ -426,7 +419,7 @@ def calc_sim(x, y, name="tensor"): x, y = x.data.double(), y.data.double() denominator = (x * x + y * y).sum() if denominator == 0: - print_red_warning(f'{name} all zero') + print_red_warning(f"{name} all zero") return 1 sim = 2 * (x * y).sum() / denominator return sim @@ -434,28 +427,23 @@ def calc_sim(x, y, name="tensor"): def assert_similar(x, y, eps=1e-2, name="tensor", assert_=False, print_=True): sim = calc_sim(x, y, name) - diff = 1. 
- sim + diff = 1.0 - sim if not (0 <= diff <= eps): - print_red_warning(f'{name} Error: {diff}') + print_red_warning(f"{name} Error: {diff}") if assert_: - raise AssertionError(f'{name} Error: {diff}') + raise AssertionError(f"{name} Error: {diff}") else: if print_: - print(f'passed: {name} diff={diff}') + print(f"passed: {name} diff={diff}") -def main(batch: int = 1, - heads: int = 32, - groups: int = 8, - kv_seqlen: int = 8192, - dim: int = 128, - tune: bool = False): +def main(batch: int = 1, heads: int = 32, groups: int = 8, kv_seqlen: int = 8192, dim: int = 128, tune: bool = False): batch, heads, groups, kv_seqlen, dim = batch, heads, groups, kv_seqlen, dim qk_flops = 2 * batch * heads * kv_seqlen * dim pv_flops = 2 * batch * heads * kv_seqlen * dim total_flops = qk_flops + pv_flops - if (not tune): + if not tune: config, sm_version = get_heuristic_config() kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) @@ -497,11 +485,11 @@ def main(batch: int = 1, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--groups', type=int, default=8, help='groups') - parser.add_argument('--kv_seqlen', type=int, default=8192, help='kv sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--kv_seqlen", type=int, default=8192, help="kv sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() main(args.batch, args.heads, args.groups, args.kv_seqlen, args.dim, args.tune) diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits.py b/examples/flash_decoding/example_gqa_decode_varlen_logits.py index 16924ebe89bc1bfe24552997cb9053eae93e6fcd..0fdd52919616d827cfced39d65415a2c6f108a79 100644 --- a/examples/flash_decoding/example_gqa_decode_varlen_logits.py +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits.py @@ -19,8 +19,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, - head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -74,14 +73,9 @@ def _fwd_inner( return m_i, l_i, acc - @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [4, 8]\ - for num_stages in [2, 4]\ - ], - key=['gqa_group_size', 'BLOCK_N', 'BLOCK_D', 'BLOCK_H'], + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [4, 8] for num_stages in [2, 4]], + key=["gqa_group_size", "BLOCK_N", "BLOCK_D", "BLOCK_H"], ) @triton.jit def _fwd_kernel_varlen( @@ -107,13 +101,12 @@ def _fwd_kernel_varlen( stride_od, stride_sb, stride_sh, - 
stride_sn, #bmask shape [b, q_h, seq/BLOCK_N] + stride_sn, # bmask shape [b, q_h, seq/BLOCK_N] gqa_group_size: tl.constexpr, BLOCK_H: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr, ): - off_z = tl.program_id(0) off_h_for_kv = tl.program_id(1) off_h_q = off_h_for_kv * gqa_group_size @@ -134,8 +127,7 @@ def _fwd_kernel_varlen( S_ptrs = S + off_z * stride_sb + off_h_q * stride_sh mask_h = offs_h < gqa_group_size - q = tl.load( - Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None]) + q = tl.load(Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None]) if s_aux is not None: sink = tl.load(s_aux + off_h_q + offs_h, mask=mask_h).to(tl.float32) @@ -189,14 +181,12 @@ def _fwd_kernel_varlen( acc = acc.to(O.dtype.element_ty) - tl.store( - O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od, - acc, - mask=mask_h[:, None]) + tl.store(O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od, acc, mask=mask_h[:, None]) def get_configs(): import itertools + block_N = [64, 128] block_H = [64] num_split = [1] @@ -204,31 +194,16 @@ def get_configs(): threads = [128] _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) - configs = [{ - 'block_N': c[0], - 'block_H': c[1], - 'num_split': c[2], - 'num_stages': c[3], - 'threads': c[4] - } for c in _configs] + configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] return configs @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding") -def flashattn(batch, - heads, - k_heads, - max_seqlen_kv, - total_seqlen_k, - dim, - has_sink, - block_N=128, - block_H=64, - num_split=1, - num_stages=1, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) +def flashattn( + batch, heads, k_heads, max_seqlen_kv, total_seqlen_k, dim, has_sink, block_N=128, block_H=64, num_split=1, num_stages=1, threads=128 +): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape_q = [batch, heads, dim] shape_k = [total_seqlen_k, k_heads, dim] shape_v = [total_seqlen_k, k_heads, dim] @@ -243,13 +218,13 @@ def flashattn(batch, @T.macro def flash_attn( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - cu_seqlens_k: T.Tensor([batch + 1], "int32"), - s_aux: T.Tensor([heads], "float32"), - Output: T.Tensor([batch, heads, dim], dtype), - S: T.Tensor(shape_s, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], "int32"), + s_aux: T.Tensor([heads], "float32"), + Output: T.Tensor([batch, heads, dim], dtype), + S: T.Tensor(shape_s, dtype), ): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -268,13 +243,15 @@ def flashattn(batch, # S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype) s_aux_shared = T.alloc_shared([block_H], "float32") - T.annotate_layout({ - # Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - # K_shared: tilelang.layout.make_swizzled_layout(K_shared), - # V_shared: tilelang.layout.make_swizzled_layout(V_shared), - # O_shared: tilelang.layout.make_swizzled_layout(O_shared), - # S_shared: tilelang.layout.make_swizzled_layout(S_shared), - }) + T.annotate_layout( + { + # Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + # K_shared: 
tilelang.layout.make_swizzled_layout(K_shared), + # V_shared: tilelang.layout.make_swizzled_layout(V_shared), + # O_shared: tilelang.layout.make_swizzled_layout(O_shared), + # S_shared: tilelang.layout.make_swizzled_layout(S_shared), + } + ) bid = bx hid = by @@ -284,7 +261,7 @@ def flashattn(batch, cur_end_k = cu_seqlens_k[bid + 1] cur_seqlen_k = cur_end_k - cur_start_k - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -292,15 +269,13 @@ def flashattn(batch, # loop_range = T.ceildiv((seqlen_kv // num_split), block_N) loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) for k in T.Pipelined(loop_range, num_stages=num_stages): - T.copy(K[cur_start_k + k * block_N:cur_start_k + (k + 1) * block_N, cur_kv_head, :], - K_shared) + T.copy(K[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], K_shared) T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): # acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j], # -T.infinity(accum_dtype)) - acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], - -T.infinity(accum_dtype)) + acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype)) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -320,12 +295,11 @@ def flashattn(batch, T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] - T.copy(V[cur_start_k + k * block_N:cur_start_k + (k + 1) * block_N, cur_kv_head, :], - V_shared) + T.copy(V[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) if has_sink: - T.copy(s_aux[hid * valid_block_H:hid * valid_block_H + block_H], s_aux_shared) + T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared) for i in T.Parallel(block_H): logsum[i] += s_aux_shared[i] for i, j in T.Parallel(block_H, dim): @@ -338,20 +312,19 @@ def flashattn(batch, for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale T.copy(acc_o[:valid_block_H, :], O_shared) - T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :]) + T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) # T.copy(S_fragment, S_shared) - T.copy(S_shared[:valid_block_H, :], S[bid, - hid * valid_block_H:(hid + 1) * valid_block_H, :]) + T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) @T.prim_func def flashattn_gqa_decode_no_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - cu_seqlens_k: T.Tensor([batch + 1], "int32"), - s_aux: T.Tensor([heads], "float32"), - Output: T.Tensor(shape_o, dtype), - S: T.Tensor(shape_s, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], "int32"), + s_aux: T.Tensor([heads], "float32"), + Output: T.Tensor(shape_o, dtype), + S: T.Tensor(shape_s, dtype), ): flash_attn(Q, K, V, cu_seqlens_k, s_aux, Output, S) @@ -388,9 +361,7 @@ def flash_attn_with_attn_pool_decode_tilelang( 
gqa_group_size = q_h // k_h O_tl = torch.zeros_like(Q) - S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), - dtype=Q.dtype, - device=Q.device) + S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device) O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux) if use_per_kv_head_sparse_index: @@ -433,9 +404,7 @@ def flash_attn_with_attn_pool_decode( BLOCK_H = 64 O = torch.zeros_like(Q) - S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)), - dtype=Q.dtype, - device=Q.device) + S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)), dtype=Q.dtype, device=Q.device) def grid(META): return (batch, k_h) @@ -483,15 +452,15 @@ def test_equal_seqlen_decode_main(args): dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 # For decode, query is just 1 token per batch - q = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) - k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) - v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) + q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) + v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) softmax_scale = 1.0 / math.sqrt(head_size) # Generate sink values if needed sink = None if args.test_sink: - sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values print(f"Using sink attention with sink values: {sink}") # Convert to varlen format for K, V @@ -499,8 +468,7 @@ def test_equal_seqlen_decode_main(args): v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size) # Generate cumulative sequence lengths - cu_seqlens_k = torch.arange( - 0, (batch_size + 1) * k_seqlen, k_seqlen, device='cuda', dtype=torch.int32) + cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32) max_seqlen_k = k_seqlen print(f"q shape: {q.shape}") @@ -510,8 +478,7 @@ def test_equal_seqlen_decode_main(args): num_tokens, q_h, head_size = q.shape batch = cu_seqlens_k.size(0) - 1 k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, - args.test_sink) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) # Test our decode kernel O_triton, S_triton = flash_attn_with_attn_pool_decode( @@ -524,7 +491,8 @@ def test_equal_seqlen_decode_main(args): args.num_split, softmax_scale, s_aux=sink, - block_size=block_size) + block_size=block_size, + ) O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( q, k_varlen, @@ -539,9 +507,7 @@ def test_equal_seqlen_decode_main(args): tl_kernel=tl_kernel, ) for i in range(batch_size): - S_tilelang[i, :, - math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / - block_size):] = 0 + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 # Compute torch reference q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size] @@ -550,14 +516,12 @@ def test_equal_seqlen_decode_main(args): if sink is None: # Standard scaled dot-product attention - logits = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [batch, q_heads, 1, 
seqlen_k] + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] attn_weights = torch.softmax(logits, dim=-1) O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size] else: # s_aux attention - logits = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] logits_max = torch.max(logits, dim=-1, keepdim=True).values @@ -566,15 +530,15 @@ def test_equal_seqlen_decode_main(args): unnormalized_scores = torch.exp(logits - logits_or_sinks_max) normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks attn_weights = unnormalized_scores / normalizer - O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), - v_repeat).squeeze(2) # [batch, q_heads, head_size] + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size] # Compute attention score pooling attn_score_pooled = torch.max_pool2d( attn_weights.squeeze(2), # [b, q_heads, k_seqlen] kernel_size=(q_heads, block_size), stride=(q_heads, block_size), - ceil_mode=True).to(torch.float16) + ceil_mode=True, + ).to(torch.float16) print("S_tilelang", S_tilelang) print("attn_score_pooled", attn_score_pooled) @@ -588,15 +552,10 @@ def test_equal_seqlen_decode_main(args): print(f"Max difference in S: {max_diff_s.item()}") print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}") print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}") - assert torch.allclose( - O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" - assert torch.allclose( - S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" - assert torch.allclose( - O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}" - assert torch.allclose( - S_tilelang, attn_score_pooled, atol=1e-2, - rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}" + assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}" + assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}" print("✅ All tests passed!") @@ -616,7 +575,7 @@ def test_varlen_decode_main(args): # Generate sink values if needed sink = None if args.test_sink: - sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values print(f"Using sink attention with sink values: {sink}") # Generate variable length k sequences @@ -624,7 +583,7 @@ def test_varlen_decode_main(args): print(f"k_seqlens: {k_seqlens}") # Generate cumulative sequence lengths for k - cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) total_k_tokens = 0 for i in range(batch_size): cu_seqlens_k[i] = total_k_tokens @@ -634,9 +593,9 @@ def test_varlen_decode_main(args): print(f"cu_seqlens_k: {cu_seqlens_k}") # Generate tensors - Q is 
[batch_size, q_heads, head_size] for decode - q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) - k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) - v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) softmax_scale = 1.0 / math.sqrt(head_size) max_seqlen_k = int(k_seqlens.max()) @@ -649,8 +608,7 @@ def test_varlen_decode_main(args): num_tokens, q_h, head_size = q_decode.shape batch = cu_seqlens_k.size(0) - 1 k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, - args.test_sink) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) # Test our decode kernel O_triton, S_triton = flash_attn_with_attn_pool_decode( @@ -663,7 +621,8 @@ def test_varlen_decode_main(args): args.num_split, softmax_scale, s_aux=sink, - block_size=block_size) + block_size=block_size, + ) O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( q_decode, k_varlen, @@ -678,9 +637,7 @@ def test_varlen_decode_main(args): tl_kernel=tl_kernel, ) for i in range(batch_size): - S_tilelang[i, :, - math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / - block_size):] = 0 + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 # Create torch reference - pad tensors for comparison k_padded_list = [] @@ -694,8 +651,8 @@ def test_varlen_decode_main(args): k_end = cu_seqlens_k[i + 1] # Pad to max_seqlen_k - k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) - v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) + k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) + v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) k_padded[:actual_k_len] = k_varlen[k_start:k_end] v_padded[:actual_k_len] = v_varlen[k_start:k_end] @@ -704,10 +661,8 @@ def test_varlen_decode_main(args): v_padded_list.append(v_padded) # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] - k_padded_batched = torch.stack( - k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] - v_padded_batched = torch.stack( - v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] # Expand q to match kv heads: [b, q_heads, 1, head_size] q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] @@ -717,20 +672,17 @@ def test_varlen_decode_main(args): print(f"v_padded_batched shape: {v_padded_batched.shape}") # Compute torch reference - k_repeat = repeat_kv(k_padded_batched, - q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] - v_repeat = repeat_kv(v_padded_batched, - q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, 
head_size] if sink is None: # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] - attn_score = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] # Apply sequence length masking for i in range(batch_size): actual_k_len = k_seqlens[i] - attn_score[i, :, :, actual_k_len:] = float('-inf') + attn_score[i, :, :, actual_k_len:] = float("-inf") attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] @@ -743,13 +695,12 @@ def test_varlen_decode_main(args): O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] else: # s_aux attention - logits = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] # Apply sequence length masking for i in range(batch_size): actual_k_len = k_seqlens[i] - logits[i, :, :, actual_k_len:] = float('-inf') + logits[i, :, :, actual_k_len:] = float("-inf") sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] logits_max = torch.max(logits, dim=-1, keepdim=True).values @@ -765,8 +716,7 @@ def test_varlen_decode_main(args): attn_weights[i, :, :, actual_k_len:] = 0.0 # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] - O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), - v_repeat) # [b, q_heads, 1, head_size] + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size] O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] @@ -775,7 +725,8 @@ def test_varlen_decode_main(args): attn_weights.squeeze(2), # [b, q_heads, max_seqlen] kernel_size=(q_heads, block_size), stride=(q_heads, block_size), - ceil_mode=True).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] + ceil_mode=True, + ).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] print(f"O_triton shape: {O_triton.shape}") print(f"O_tilelang shape: {O_tilelang.shape}") @@ -791,22 +742,16 @@ def test_varlen_decode_main(args): print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) - max_diff_s_tl = torch.max( - torch.abs(S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)] - attn_score_pooled)) + max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled)) print(f"Max difference in S: {max_diff_s.item()}") print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") - assert torch.allclose( - O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" - assert torch.allclose( - S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" - assert torch.allclose( - O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" - assert torch.allclose( - S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)], - attn_score_pooled, - atol=1e-2, - rtol=1e-2), f"Score mismatch: {max_diff_s_tl.item()}" + assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" 
+ assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), ( + f"Score mismatch: {max_diff_s_tl.item()}" + ) print("✅ All tests passed!") @@ -865,7 +810,7 @@ def speed_benchmark_decode_comparison(args): k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int) # Generate cumulative sequence lengths for k - cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) total_k_tokens = 0 for i in range(batch_size): cu_seqlens_k[i] = total_k_tokens @@ -873,9 +818,9 @@ def speed_benchmark_decode_comparison(args): cu_seqlens_k[batch_size] = total_k_tokens # Generate tensors - q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) - k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) - v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) softmax_scale = 1.0 / math.sqrt(head_size) max_seqlen_k = int(k_seqlens.max()) @@ -883,7 +828,7 @@ def speed_benchmark_decode_comparison(args): # Generate sink values if needed sink = None if args.test_sink: - sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values print(" Using sink attention with sink values") print("Setup complete:") @@ -896,8 +841,7 @@ def speed_benchmark_decode_comparison(args): num_tokens, q_h, head_size = q_decode.shape batch = cu_seqlens_k.size(0) - 1 k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, - args.test_sink) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) # Benchmark print("⚡ Benchmarking Tilelang kernel (100 iterations)...") @@ -920,36 +864,41 @@ def speed_benchmark_decode_comparison(args): # Benchmark print("⚡ Benchmarking Triton kernel (100 iterations)...") - triton_time = do_bench(flash_attn_with_attn_pool_decode, q_decode, k_varlen, v_varlen, - cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale, sink, - block_size) + triton_time = do_bench( + flash_attn_with_attn_pool_decode, + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + args.k_seqlen, + 1, + softmax_scale, + sink, + block_size, + ) print(f"Average decode kernel time Triton: {triton_time:.3f} ms") print(f"Speedup: {(triton_time / tilelang_time):.3f}") if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Flash Attention Decode with Attention Pooling') - parser.add_argument('--batch_size', type=int, default=1, help='Batch size') - parser.add_argument('--q_heads', type=int, default=32, help='Number of query heads') - parser.add_argument('--kv_heads', type=int, default=8, help='Number of key-value heads') - parser.add_argument('--k_seqlen', type=int, default=8192, help='Key sequence length') - parser.add_argument( - '--head_size', type=int, default=128, choices=[64, 128, 256], help='Head dimension') - parser.add_argument('--block_size', type=int, default=64, help='Block size for computation') - parser.add_argument( - '--dtype', type=str, default='bfloat16', 
choices=['float16', 'bfloat16'], help='Data type') - parser.add_argument( - '--test_varlen', action='store_true', help='Test with truly variable sequence lengths') - parser.add_argument( - '--test_sink', action='store_true', help='Test with sink attention mechanism') - parser.add_argument('--benchmark', action='store_true', help='Run speed benchmark') - parser.add_argument( - '--num_split', type=int, default=1, choices=[1, 16], help='Number of splits') + parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads") + parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads") + parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length") + parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension") + parser.add_argument("--block_size", type=int, default=64, help="Block size for computation") + parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"], help="Data type") + parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths") + parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism") + parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark") + parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits") args = parser.parse_args() args.test_sink = True args.test_varlen = False - args.dtype = 'float16' + args.dtype = "float16" args.num_split = 1 if args.benchmark: diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py b/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py index e565cbeb5f6f6c937ef5f524ca7f9d86bbbd93a0..3537e5af049d33f30d49fc9d59c55dd541d76411 100644 --- a/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py @@ -10,6 +10,7 @@ torch.manual_seed(0) def get_configs(): import itertools + block_N = [64, 128] block_H = [64] num_split = [1] @@ -17,32 +18,28 @@ def get_configs(): threads = [128] _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) - configs = [{ - 'block_N': c[0], - 'block_H': c[1], - 'num_split': c[2], - 'num_stages': c[3], - 'threads': c[4] - } for c in _configs] + configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] return configs # @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding") -def flashattn(batch, - heads, - k_heads, - max_seqlen_kv, - total_seqlen_k, - dim, - has_sink, - page_block_size, - block_N=128, - block_H=64, - num_split=1, - num_stages=1, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) +def flashattn( + batch, + heads, + k_heads, + max_seqlen_kv, + total_seqlen_k, + dim, + has_sink, + page_block_size, + block_N=128, + block_H=64, + num_split=1, + num_stages=1, + threads=128, +): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape_q = [batch, heads, dim] shape_k = [total_seqlen_k, k_heads, dim] shape_v = [total_seqlen_k, k_heads, dim] @@ -51,21 +48,23 @@ def flashattn(batch, dtype = "float16" accum_dtype = "float" kv_group_num = heads // k_heads - 
assert page_block_size >= block_N and page_block_size % block_N == 0, "page_block_size must be larger than block_N and a multiple of block_N" + assert page_block_size >= block_N and page_block_size % block_N == 0, ( + "page_block_size must be larger than block_N and a multiple of block_N" + ) valid_block_H = min(block_H, kv_group_num) # TODO: check if max_seqlen_kv is correct for varlen case @T.macro def flash_attn( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - cu_seqlens_k: T.Tensor([batch + 1], "int32"), - s_aux: T.Tensor([heads], "float32"), - BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / block_N)], "int32"), - Output: T.Tensor([batch, heads, dim], dtype), - S: T.Tensor(shape_s, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], "int32"), + s_aux: T.Tensor([heads], "float32"), + BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / block_N)], "int32"), + Output: T.Tensor([batch, heads, dim], dtype), + S: T.Tensor(shape_s, dtype), ): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -91,7 +90,7 @@ def flashattn(batch, cur_end_k = cu_seqlens_k[bid + 1] cur_seqlen_k = cur_end_k - cur_start_k - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -99,15 +98,12 @@ def flashattn(batch, # loop_range = T.ceildiv((seqlen_kv // num_split), block_N) loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) for k in T.Pipelined(loop_range, num_stages=num_stages): - k_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + ( - k * block_N) % page_block_size - T.copy(K[cur_start_k + k_start:cur_start_k + k_start + block_N, cur_kv_head, :], - K_shared) + k_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size + T.copy(K[cur_start_k + k_start : cur_start_k + k_start + block_N, cur_kv_head, :], K_shared) T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], - -T.infinity(accum_dtype)) + acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype)) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -127,14 +123,12 @@ def flashattn(batch, T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] - v_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + ( - k * block_N) % page_block_size - T.copy(V[cur_start_k + v_start:cur_start_k + v_start + block_N, cur_kv_head, :], - V_shared) + v_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size + T.copy(V[cur_start_k + v_start : cur_start_k + v_start + block_N, cur_kv_head, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) if has_sink: - T.copy(s_aux[hid * valid_block_H:hid * valid_block_H + block_H], s_aux_shared) + T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared) for i in T.Parallel(block_H): 
logsum[i] += s_aux_shared[i] for i, j in T.Parallel(block_H, dim): @@ -144,20 +138,19 @@ def flashattn(batch, for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale T.copy(acc_o[:valid_block_H, :], O_shared) - T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :]) - T.copy(S_shared[:valid_block_H, :], S[bid, - hid * valid_block_H:(hid + 1) * valid_block_H, :]) + T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) @T.prim_func def flashattn_gqa_decode_no_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - cu_seqlens_k: T.Tensor([batch + 1], "int32"), - s_aux: T.Tensor([heads], "float32"), - BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], "int32"), - Output: T.Tensor(shape_o, dtype), - S: T.Tensor(shape_s, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], "int32"), + s_aux: T.Tensor([heads], "float32"), + BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], "int32"), + Output: T.Tensor(shape_o, dtype), + S: T.Tensor(shape_s, dtype), ): flash_attn(Q, K, V, cu_seqlens_k, s_aux, BLOCK_TABLE, Output, S) @@ -195,9 +188,7 @@ def flash_attn_with_attn_pool_decode_tilelang( gqa_group_size = q_h // k_h O_tl = torch.zeros_like(Q) - S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), - dtype=Q.dtype, - device=Q.device) + S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device) O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux, block_table) if use_per_kv_head_sparse_index: @@ -223,15 +214,15 @@ def test_equal_seqlen_decode_main(args): dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 # For decode, query is just 1 token per batch - q = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) - k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) - v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) + q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) + v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) softmax_scale = 1.0 / math.sqrt(head_size) # Generate sink values if needed sink = None if args.test_sink: - sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values print(f"Using sink attention with sink values: {sink}") # Convert to varlen format for K, V @@ -239,8 +230,7 @@ def test_equal_seqlen_decode_main(args): v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous() # Generate cumulative sequence lengths - cu_seqlens_k = torch.arange( - 0, (batch_size + 1) * k_seqlen, k_seqlen, device='cuda', dtype=torch.int32) + cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32) max_seqlen_k = k_seqlen print(f"q shape: {q.shape}") @@ -250,11 +240,9 @@ def test_equal_seqlen_decode_main(args): num_tokens, q_h, head_size = q.shape batch = cu_seqlens_k.size(0) - 1 k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, 
k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, - args.test_sink, page_block_size) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size) - block_table = torch.zeros( - batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32) + block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32) block_cnt = 0 for i in range(batch): cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() @@ -274,7 +262,8 @@ def test_equal_seqlen_decode_main(args): args.num_split, softmax_scale, s_aux=sink, - block_size=block_size) + block_size=block_size, + ) O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( q, k_varlen, @@ -290,9 +279,7 @@ def test_equal_seqlen_decode_main(args): block_table=block_table, ) for i in range(batch_size): - S_tilelang[i, :, - math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / - block_size):] = 0 + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 # Compute torch reference q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size] @@ -301,14 +288,12 @@ def test_equal_seqlen_decode_main(args): if sink is None: # Standard scaled dot-product attention - logits = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] attn_weights = torch.softmax(logits, dim=-1) O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size] else: # s_aux attention - logits = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] logits_max = torch.max(logits, dim=-1, keepdim=True).values @@ -317,15 +302,15 @@ def test_equal_seqlen_decode_main(args): unnormalized_scores = torch.exp(logits - logits_or_sinks_max) normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks attn_weights = unnormalized_scores / normalizer - O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), - v_repeat).squeeze(2) # [batch, q_heads, head_size] + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size] # Compute attention score pooling attn_score_pooled = torch.max_pool2d( attn_weights.squeeze(2), # [b, q_heads, k_seqlen] kernel_size=(q_heads, block_size), stride=(q_heads, block_size), - ceil_mode=True).to(torch.float16) + ceil_mode=True, + ).to(torch.float16) print("S_tilelang", S_tilelang) print("attn_score_pooled", attn_score_pooled) @@ -339,15 +324,10 @@ def test_equal_seqlen_decode_main(args): print(f"Max difference in S: {max_diff_s.item()}") print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}") print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}") - assert torch.allclose( - O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" - assert torch.allclose( - S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" - assert torch.allclose( - O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}" - assert torch.allclose( - S_tilelang, attn_score_pooled, atol=1e-2, - rtol=1e-2), f"Score 
mismatch: {max_diff_s_tilelang.item()}" + assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}" + assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}" print("✅ All tests passed!") @@ -368,7 +348,7 @@ def test_varlen_decode_main(args): # Generate sink values if needed sink = None if args.test_sink: - sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values print(f"Using sink attention with sink values: {sink}") # Generate variable length k sequences @@ -376,7 +356,7 @@ def test_varlen_decode_main(args): print(f"k_seqlens: {k_seqlens}") # Generate cumulative sequence lengths for k - cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) total_k_tokens = 0 for i in range(batch_size): cu_seqlens_k[i] = total_k_tokens @@ -386,9 +366,9 @@ def test_varlen_decode_main(args): print(f"cu_seqlens_k: {cu_seqlens_k}") # Generate tensors - Q is [batch_size, q_heads, head_size] for decode - q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) - k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) - v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) softmax_scale = 1.0 / math.sqrt(head_size) max_seqlen_k = int(k_seqlens.max()) @@ -401,11 +381,9 @@ def test_varlen_decode_main(args): num_tokens, q_h, head_size = q_decode.shape batch = cu_seqlens_k.size(0) - 1 k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, - args.test_sink, page_block_size) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size) - block_table = torch.zeros( - batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32) + block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32) block_cnt = 0 for i in range(batch): cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() @@ -425,7 +403,8 @@ def test_varlen_decode_main(args): args.num_split, softmax_scale, s_aux=sink, - block_size=block_size) + block_size=block_size, + ) O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( q_decode, k_varlen, @@ -441,9 +420,7 @@ def test_varlen_decode_main(args): block_table=block_table, ) for i in range(batch_size): - S_tilelang[i, :, - math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / - block_size):] = 0 + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 # Create torch reference - pad tensors for comparison k_padded_list = [] @@ -457,8 +434,8 @@ def test_varlen_decode_main(args): k_end = cu_seqlens_k[i + 1] # Pad to 
max_seqlen_k - k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) - v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) + k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) + v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) k_padded[:actual_k_len] = k_varlen[k_start:k_end] v_padded[:actual_k_len] = v_varlen[k_start:k_end] @@ -467,10 +444,8 @@ def test_varlen_decode_main(args): v_padded_list.append(v_padded) # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] - k_padded_batched = torch.stack( - k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] - v_padded_batched = torch.stack( - v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] # Expand q to match kv heads: [b, q_heads, 1, head_size] q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] @@ -480,20 +455,17 @@ def test_varlen_decode_main(args): print(f"v_padded_batched shape: {v_padded_batched.shape}") # Compute torch reference - k_repeat = repeat_kv(k_padded_batched, - q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] - v_repeat = repeat_kv(v_padded_batched, - q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] if sink is None: # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] - attn_score = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] # Apply sequence length masking for i in range(batch_size): actual_k_len = k_seqlens[i] - attn_score[i, :, :, actual_k_len:] = float('-inf') + attn_score[i, :, :, actual_k_len:] = float("-inf") attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] @@ -506,13 +478,12 @@ def test_varlen_decode_main(args): O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] else: # s_aux attention - logits = torch.matmul(q_expanded, k_repeat.transpose( - -2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] # Apply sequence length masking for i in range(batch_size): actual_k_len = k_seqlens[i] - logits[i, :, :, actual_k_len:] = float('-inf') + logits[i, :, :, actual_k_len:] = float("-inf") sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] logits_max = torch.max(logits, dim=-1, keepdim=True).values @@ -528,8 +499,7 @@ def test_varlen_decode_main(args): attn_weights[i, :, :, actual_k_len:] = 0.0 # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] - O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), - v_repeat) # [b, q_heads, 1, head_size] + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size] O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] @@ -538,7 +508,8 @@ def test_varlen_decode_main(args): attn_weights.squeeze(2), # 
[b, q_heads, max_seqlen] kernel_size=(q_heads, block_size), stride=(q_heads, block_size), - ceil_mode=True).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] + ceil_mode=True, + ).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] print(f"O_triton shape: {O_triton.shape}") print(f"O_tilelang shape: {O_tilelang.shape}") @@ -554,22 +525,16 @@ def test_varlen_decode_main(args): print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) - max_diff_s_tl = torch.max( - torch.abs(S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)] - attn_score_pooled)) + max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled)) print(f"Max difference in S: {max_diff_s.item()}") print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") - assert torch.allclose( - O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" - assert torch.allclose( - S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" - assert torch.allclose( - O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" - assert torch.allclose( - S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)], - attn_score_pooled, - atol=1e-2, - rtol=1e-2), f"Score mismatch: {max_diff_s_tl.item()}" + assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" + assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), ( + f"Score mismatch: {max_diff_s_tl.item()}" + ) print("✅ All tests passed!") @@ -605,7 +570,7 @@ def speed_benchmark_decode_comparison(args): k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int) # Generate cumulative sequence lengths for k - cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) total_k_tokens = 0 for i in range(batch_size): cu_seqlens_k[i] = total_k_tokens @@ -613,9 +578,9 @@ def speed_benchmark_decode_comparison(args): cu_seqlens_k[batch_size] = total_k_tokens # Generate tensors - q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) - k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) - v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) softmax_scale = 1.0 / math.sqrt(head_size) max_seqlen_k = int(k_seqlens.max()) @@ -623,7 +588,7 @@ def speed_benchmark_decode_comparison(args): # Generate sink values if needed sink = None if args.test_sink: - sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values print(" Using sink attention with sink values") print("Setup complete:") @@ -636,11 +601,9 @@ def speed_benchmark_decode_comparison(args): num_tokens, q_h, head_size = 
q_decode.shape batch = cu_seqlens_k.size(0) - 1 k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, - args.test_sink, page_block_size) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size) - block_table = torch.zeros( - batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32) + block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32) block_cnt = 0 for i in range(batch): cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() @@ -671,36 +634,41 @@ def speed_benchmark_decode_comparison(args): # Benchmark print("⚡ Benchmarking Triton kernel (100 iterations)...") - triton_time = do_bench(flash_attn_with_attn_pool_decode, q_decode, k_varlen, v_varlen, - cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale, sink, - block_size) + triton_time = do_bench( + flash_attn_with_attn_pool_decode, + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + args.k_seqlen, + 1, + softmax_scale, + sink, + block_size, + ) print(f"Average decode kernel time Triton: {triton_time:.3f} ms") print(f"Speedup: {(triton_time / tilelang_time):.3f}") if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Flash Attention Decode with Attention Pooling') - parser.add_argument('--batch_size', type=int, default=1, help='Batch size') - parser.add_argument('--q_heads', type=int, default=32, help='Number of query heads') - parser.add_argument('--kv_heads', type=int, default=8, help='Number of key-value heads') - parser.add_argument('--k_seqlen', type=int, default=8192, help='Key sequence length') - parser.add_argument( - '--head_size', type=int, default=128, choices=[64, 128, 256], help='Head dimension') - parser.add_argument('--block_size', type=int, default=128, help='Block size for computation') - parser.add_argument( - '--dtype', type=str, default='bfloat16', choices=['float16', 'bfloat16'], help='Data type') - parser.add_argument( - '--test_varlen', action='store_true', help='Test with truly variable sequence lengths') - parser.add_argument( - '--test_sink', action='store_true', help='Test with sink attention mechanism') - parser.add_argument('--benchmark', action='store_true', help='Run speed benchmark') - parser.add_argument( - '--num_split', type=int, default=1, choices=[1, 16], help='Number of splits') - parser.add_argument('--page_block_size', type=int, default=128, help='Page block size') + parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads") + parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads") + parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length") + parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension") + parser.add_argument("--block_size", type=int, default=128, help="Block size for computation") + parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"], help="Data type") + parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths") + parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism") + 
parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark") + parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits") + parser.add_argument("--page_block_size", type=int, default=128, help="Page block size") args = parser.parse_args() args.test_sink = True args.test_varlen = True - args.dtype = 'float16' + args.dtype = "float16" args.num_split = 1 if args.benchmark: diff --git a/examples/flash_decoding/example_mha_inference.py b/examples/flash_decoding/example_mha_inference.py index 0360b3e2b82dae777c77290f1bfb8e1bf84c72ec..d0381bc4ac94734fda6269e0c4d6903860a580f8 100644 --- a/examples/flash_decoding/example_mha_inference.py +++ b/examples/flash_decoding/example_mha_inference.py @@ -10,7 +10,7 @@ num_split = 4 @tilelang.jit(out_idx=[5]) def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape_q = [batch, seqlen_q, heads, dim] shape_kv = [batch, seqlen_kv, heads, dim] part_shape = [batch, seqlen_q, heads, num_split, dim] @@ -29,14 +29,11 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ bid: T.int32, sid: T.int32, ): - T.copy( - K[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid + - (k + 1) * block_N, hid, :], K_shared) + T.copy(K[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], K_shared) # TODO: Handle causal split case if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -52,20 +49,18 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ bid: T.int32, sid: T.int32, ): - T.copy( - V[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid + - (k + 1) * block_N, hid, :], V_shared) + T.copy(V[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -91,23 +86,21 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + 
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.macro def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_kv, dtype), - V: T.Tensor(shape_kv, dtype), - glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), - Output_partial: T.Tensor(part_shape, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_kv, dtype), + V: T.Tensor(shape_kv, dtype), + glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), + Output_partial: T.Tensor(part_shape, dtype), ): - with T.Kernel( - T.ceildiv(seqlen_q, block_M), heads * batch, num_split, - threads=128) as (bx, by, bz): + with T.Kernel(T.ceildiv(seqlen_q, block_M), heads * batch, num_split, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -128,39 +121,36 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ # NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently # disable relevant tma copy and use SIMT as fallback for now - T.copy(Q[bid, mid * block_M:(mid + 1) * block_M, hid, :], Q_shared, disable_tma=True) + T.copy(Q[bid, mid * block_M : (mid + 1) * block_M, hid, :], Q_shared, disable_tma=True) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) # TODO: Handle causal split case loop_range = ( - T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv( - (mid + 1) * block_M, block_N)) if is_causal else T.ceildiv( - (seqlen_kv // num_split), block_N)) + T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv((mid + 1) * block_M, block_N)) + if is_causal + else T.ceildiv((seqlen_kv // num_split), block_N) + ) for k in T.Pipelined(loop_range, num_stages=2): MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] for i in T.Parallel(block_M): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, glse[bid, hid, sid, mid * block_M:(mid + 1) * block_M]) + T.copy(logsum, glse[bid, hid, sid, mid * block_M : (mid + 1) * block_M]) T.copy(acc_o, O_shared) - T.copy( - O_shared, - Output_partial[bid, mid * block_M:(mid + 1) * block_M, hid, sid, :], - disable_tma=True) + T.copy(O_shared, Output_partial[bid, mid * block_M : (mid + 1) * block_M, hid, sid, :], disable_tma=True) @T.macro def combine( - glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_q, dtype), + glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_q, dtype), ): with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz): po_local = T.alloc_fragment([block_M, dim], dtype) @@ -173,20 +163,25 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ lse_max_local = T.alloc_fragment([block_M], accum_dtype) scale_local = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({ - o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i), - o_shared: 
tilelang.layout.make_swizzled_layout(o_shared), - po_shared: tilelang.layout.make_swizzled_layout(po_shared), - }) + T.annotate_layout( + { + o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i), + o_shared: tilelang.layout.make_swizzled_layout(o_shared), + po_shared: tilelang.layout.make_swizzled_layout(po_shared), + } + ) T.clear(lse_logsum_local) T.clear(o_accum_local) - T.copy(glse[ - bz, - by, - :, - bx * block_M:(bx + 1) * block_M, - ], lse_local) + T.copy( + glse[ + bz, + by, + :, + bx * block_M : (bx + 1) * block_M, + ], + lse_local, + ) T.reduce_max(lse_local, lse_max_local, dim=0, clear=False) for k in T.Pipelined(num_split): T.copy(lse_local[k, :], lse_local_split) @@ -195,10 +190,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ for i in T.Parallel(block_M): lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i] for k in T.Pipelined(num_split, num_stages=2): - T.copy( - Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :], - po_shared, - disable_tma=True) + T.copy(Output_partial[bz, bx * block_M : (bx + 1) * block_M, by, k, :], po_shared, disable_tma=True) T.copy(po_shared, po_local) for i in T.Parallel(block_M): lse_local_split[i] = lse_local[k, i] @@ -207,16 +199,16 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ for i, j in T.Parallel(block_M, dim): o_accum_local[i, j] += po_local[i, j] * scale_local[i] T.copy(o_accum_local, o_shared) - T.copy(o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :], disable_tma=True) + T.copy(o_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :], disable_tma=True) @T.prim_func def flashattn_mha_inference( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_kv, dtype), - V: T.Tensor(shape_kv, dtype), - glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), - Output_partial: T.Tensor(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim] - Output: T.Tensor(shape_q, dtype), + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_kv, dtype), + V: T.Tensor(shape_kv, dtype), + glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), + Output_partial: T.Tensor(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim] + Output: T.Tensor(shape_q, dtype), ): flash_attn_split(Q, K, V, glse, Output_partial) combine(glse, Output_partial, Output) @@ -227,10 +219,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ def ref_program(Q, K, V, glse, Output_partial, causal): assert causal is False dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) return output @@ -258,7 +250,7 @@ def flash_split_ref(Q, K, V, causal): block_N = 128 seqlen_kv = K.size(1) - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float) acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16) acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float) @@ -275,14 +267,15 @@ def flash_split_ref(Q, K, V, causal): for ks in range(num_split): acc_o.fill_(0) logsum.fill_(0) - scores_max.fill_(float('-inf')) - 
scores_max_prev.fill_(float('-inf')) + scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) for i in range(int((seqlen_kv // num_split) / block_N)): acc_s.fill_(0) - acc_s = torch.einsum('bqhd,bkhd->bhqk', Q_, - K[:, (seqlen_kv // num_split) * ks + - i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) # [batch, seqlen, nheads, block_N] + acc_s = torch.einsum( + "bqhd,bkhd->bhqk", + Q_, + K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, seqlen, nheads, block_N] scores_max_prev = scores_max scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM] scores_scale = torch.exp2(scores_max_prev - scores_max) @@ -290,9 +283,10 @@ def flash_split_ref(Q, K, V, causal): acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) acc_s_cast = acc_s.to(torch.float16) acc_o += torch.einsum( - 'bhqk,bkhd->bqhd', acc_s_cast, - V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + - (i + 1) * block_N, :, :]) + "bhqk,bkhd->bqhd", + acc_s_cast, + V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) scores_sum = acc_s.sum(dim=-1, keepdim=False) logsum = logsum * scores_scale + scores_sum acc_o /= logsum[:, :, :, None].transpose(1, 2) @@ -300,8 +294,7 @@ def flash_split_ref(Q, K, V, causal): gacc_o[ks, :, :, :, :] = acc_o glogsum[ks, :, :, :] = logsum - return glogsum.to(torch.float16).permute(1, 2, 0, - 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4) + return glogsum.to(torch.float16).permute(1, 2, 0, 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4) def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False): diff --git a/examples/fusedmoe/example_fusedmoe_tilelang.py b/examples/fusedmoe/example_fusedmoe_tilelang.py index a8d6849659118cb630ab0d97dad7b0233abec0b7..b737f30aae25274cf33df9d3a0215b320e619286 100644 --- a/examples/fusedmoe/example_fusedmoe_tilelang.py +++ b/examples/fusedmoe/example_fusedmoe_tilelang.py @@ -9,17 +9,18 @@ from example_fusedmoe_torch import * @tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) -def moe_forward_tilelang_shared(d_hidden, - d_expert, - n_shared_experts, - dtype, - num_tokens, - block_token=128, - block_dhidden=128, - block_dexpert=128, - threads=256, - num_stages=1): - +def moe_forward_tilelang_shared( + d_hidden, + d_expert, + n_shared_experts, + dtype, + num_tokens, + block_token=128, + block_dhidden=128, + block_dexpert=128, + threads=256, + num_stages=1, +): scale = 1.44269504 # log2(e) # Parameters @@ -36,17 +37,15 @@ def moe_forward_tilelang_shared(d_hidden, @T.prim_func def kernel_shared( - input: T.Tensor(input_shape, dtype), # type: ignore - shared_W_gate: T.Tensor(shared_W_gate_shape, dtype), # type: ignore - shared_W_up: T.Tensor(shared_W_up_shape, dtype), # type: ignore - shared_W_down: T.Tensor(shared_W_down_shape, dtype), # type: ignore - up_logits: T.Tensor((num_tokens, dexpert), dtype), # type: ignore - output: T.Tensor(input_shape, dtype), # type: ignore + input: T.Tensor(input_shape, dtype), # type: ignore + shared_W_gate: T.Tensor(shared_W_gate_shape, dtype), # type: ignore + shared_W_up: T.Tensor(shared_W_up_shape, dtype), # type: ignore + shared_W_down: T.Tensor(shared_W_down_shape, dtype), # type: ignore + up_logits: T.Tensor((num_tokens, dexpert), dtype), # type: ignore + output: T.Tensor(input_shape, dtype), # type: ignore ): # Step 1: Compute gate and up logits - with T.Kernel( - 
T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by): # Split the block to shared experts and routed experts input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype) W_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype) @@ -70,16 +69,13 @@ def moe_forward_tilelang_shared(d_hidden, # Fuse with SiLU and element-wise product for i, j in T.Parallel(block_token, block_dexpert): - gate_logits_local[i, j] = gate_logits_local[i, j] * ( - 1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) + gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j] T.copy(up_logits_local, up_logits[bx * block_token, by * block_dexpert]) # Step 2: Compute down logits - with T.Kernel( - T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by): up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype) W_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype) output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_type) @@ -98,20 +94,21 @@ def moe_forward_tilelang_shared(d_hidden, @tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) -def moe_forward_tilelang_routed(d_hidden, - d_expert, - n_routed_experts, - dtype, - group_sum, - group_count, - block_token=128, - block_dhidden=128, - block_dexpert=128, - threads=256, - num_stages=1, - k_pack=1, - coalesced_width=None): - +def moe_forward_tilelang_routed( + d_hidden, + d_expert, + n_routed_experts, + dtype, + group_sum, + group_count, + block_token=128, + block_dhidden=128, + block_dexpert=128, + threads=256, + num_stages=1, + k_pack=1, + coalesced_width=None, +): scale = 1.44269504 # log2(e) # Parameters @@ -132,22 +129,22 @@ def moe_forward_tilelang_routed(d_hidden, routed_expert_gate_shape = (n_routed_experts, dexpert, dhidden) routed_expert_up_shape = (n_routed_experts, dexpert, dhidden) routed_expert_down_shape = (n_routed_experts, dhidden, dexpert) - routed_expert_weights_shape = (group_sum) - group_sizes_shape = (n_routed_experts) + routed_expert_weights_shape = group_sum + group_sizes_shape = n_routed_experts @T.prim_func def kernel( - input: T.Tensor(input_shape, dtype), # type: ignore - routed_expert_gate: T.Tensor(routed_expert_gate_shape, dtype), # type: ignore - routed_expert_up: T.Tensor(routed_expert_up_shape, dtype), # type: ignore - routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore - routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore - group_sizes: T.Tensor(group_sizes_shape, "int32"), # type: ignore - group_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore - group_padded_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore - group_idx_for_bx: T.Tensor((M,), "int32"), # type: ignore - up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore - output: T.Tensor(input_shape, dtype), # type: ignore + input: T.Tensor(input_shape, dtype), # type: ignore + routed_expert_gate: T.Tensor(routed_expert_gate_shape, dtype), # type: ignore + routed_expert_up: T.Tensor(routed_expert_up_shape, 
dtype), # type: ignore + routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore + routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore + group_sizes: T.Tensor(group_sizes_shape, "int32"), # type: ignore + group_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore + group_padded_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore + group_idx_for_bx: T.Tensor((M,), "int32"), # type: ignore + up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore + output: T.Tensor(input_shape, dtype), # type: ignore ): # Step 1: Compute gate and up logits with T.Kernel(M, T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by): @@ -168,48 +165,37 @@ def moe_forward_tilelang_routed(d_hidden, cur_group_idx[0] = group_idx_for_bx[bx] cur_group_size[0] = group_sizes[cur_group_idx[0]] - m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[ - cur_group_idx[0]] - actual_rows = T.max( - 0, - T.min(block_token, cur_group_size[0] - - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) + m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]] + actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) T.clear(gate_logits_local) T.clear(up_logits_local) for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages): T.copy( - input[m_start:m_start + block_token, k * block_dhidden:(k + 1) * block_dhidden], + input[m_start : m_start + block_token, k * block_dhidden : (k + 1) * block_dhidden], input_shared, - coalesced_width=coalesced_width) + coalesced_width=coalesced_width, + ) T.copy( - routed_expert_gate[cur_group_idx[0], - by * block_dexpert:(by + 1) * block_dexpert, - k * block_dhidden:(k + 1) * block_dhidden], - routed_expert_gate_shared, - coalesced_width=coalesced_width) - T.gemm( - input_shared, + routed_expert_gate[ + cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden + ], routed_expert_gate_shared, - gate_logits_local, - k_pack=k_pack, - transpose_B=True) + coalesced_width=coalesced_width, + ) + T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, k_pack=k_pack, transpose_B=True) T.copy( - routed_expert_up[cur_group_idx[0], by * block_dexpert:(by + 1) * block_dexpert, - k * block_dhidden:(k + 1) * block_dhidden], + routed_expert_up[ + cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden + ], routed_expert_up_shared, - coalesced_width=coalesced_width) - T.gemm( - input_shared, - routed_expert_up_shared, - up_logits_local, - k_pack=k_pack, - transpose_B=True) + coalesced_width=coalesced_width, + ) + T.gemm(input_shared, routed_expert_up_shared, up_logits_local, k_pack=k_pack, transpose_B=True) for i, j in T.Parallel(block_token, block_dexpert): - gate_logits_local[i, j] = gate_logits_local[i, j] * ( - 1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) + gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j] for i, j in T.Parallel(block_token, block_dexpert): @@ -232,50 +218,35 @@ def moe_forward_tilelang_routed(d_hidden, cur_group_idx[0] = group_idx_for_bx[bx] cur_group_size[0] = group_sizes[cur_group_idx[0]] - m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[ - cur_group_idx[0]] - 
actual_rows = T.max( - 0, - T.min(block_token, cur_group_size[0] - - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) + m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]] + actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) T.clear(output_local) for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages): T.copy( - up_logits[m_start:m_start + block_token, - k * block_dexpert:(k + 1) * block_dexpert], + up_logits[m_start : m_start + block_token, k * block_dexpert : (k + 1) * block_dexpert], up_logits_shared, - coalesced_width=coalesced_width) + coalesced_width=coalesced_width, + ) T.copy( - routed_expert_down[cur_group_idx[0], - by * block_dhidden:(by + 1) * block_dhidden, - k * block_dexpert:(k + 1) * block_dexpert], - routed_expert_down_shared, - coalesced_width=coalesced_width) - T.gemm( - up_logits_shared, + routed_expert_down[ + cur_group_idx[0], by * block_dhidden : (by + 1) * block_dhidden, k * block_dexpert : (k + 1) * block_dexpert + ], routed_expert_down_shared, - output_local, - k_pack=k_pack, - transpose_B=True) + coalesced_width=coalesced_width, + ) + T.gemm(up_logits_shared, routed_expert_down_shared, output_local, k_pack=k_pack, transpose_B=True) for i, j in T.Parallel(block_token, block_dhidden): if i < actual_rows: - output[m_start + i, by * block_dhidden + - j] = output_local[i, j] * routed_expert_weights[m_start + i] + output[m_start + i, by * block_dhidden + j] = output_local[i, j] * routed_expert_weights[m_start + i] return kernel class Expert(nn.Module): - - def __init__(self, - config: Dict, - gate: torch.Tensor, - up: torch.Tensor, - down: torch.Tensor, - d_expert: Optional[int] = None): + def __init__(self, config: Dict, gate: torch.Tensor, up: torch.Tensor, down: torch.Tensor, d_expert: Optional[int] = None): super().__init__() self.config = config self.act_fn = nn.SiLU() @@ -294,14 +265,13 @@ class Expert(nn.Module): class MoEGate(nn.Module): - def __init__(self, config: Dict, weights: Dict): super().__init__() self.top_k: int = config["n_experts_per_token"] self.num_experts: int = config["n_routed_experts"] self.d_hidden: int = config["d_hidden"] - self.W_g_weight = weights['router.weight'].t() + self.W_g_weight = weights["router.weight"].t() def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: logits = x @ self.W_g_weight @@ -312,76 +282,69 @@ class MoEGate(nn.Module): class MoE(nn.Module): - - def __init__(self, - config: Dict, - shared_kernel: tilelang.JITKernel, - routed_kernel: tilelang.JITKernel, - weights: Dict, - padding_M: int = 128): + def __init__( + self, config: Dict, shared_kernel: tilelang.JITKernel, routed_kernel: tilelang.JITKernel, weights: Dict, padding_M: int = 128 + ): super().__init__() self.config = config self.shared_kernel = shared_kernel self.routed_kernel = routed_kernel self.padding_M = padding_M - self.experts = nn.ModuleList([ - Expert( - config, - gate=weights[f'experts.{i}.0.weight'], - up=weights[f'experts.{i}.1.weight'], - down=weights[f'experts.{i}.2.weight']) for i in range(config["n_routed_experts"]) - ]) + self.experts = nn.ModuleList( + [ + Expert( + config, + gate=weights[f"experts.{i}.0.weight"], + up=weights[f"experts.{i}.1.weight"], + down=weights[f"experts.{i}.2.weight"], + ) + for i in range(config["n_routed_experts"]) + ] + ) self.device = torch.device("cuda") self.gating_network = MoEGate(config, weights).to(self.device) shared_expert_dim = 
config["d_expert"] * config["n_shared_experts"] self.shared_expert = Expert( config=config, - gate=weights['shared_experts.0.weight'], - up=weights['shared_experts.1.weight'], - down=weights['shared_experts.2.weight'], - d_expert=shared_expert_dim).to(self.device) + gate=weights["shared_experts.0.weight"], + up=weights["shared_experts.1.weight"], + down=weights["shared_experts.2.weight"], + d_expert=shared_expert_dim, + ).to(self.device) self.expert_cache = torch.zeros( - (config["batch_size"] * config["seq_len"], config["d_hidden"]), - dtype=torch.float16, - device=self.device) - self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], - dim=0) - self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts], - dim=0) - self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts], - dim=0) + (config["batch_size"] * config["seq_len"], config["d_hidden"]), dtype=torch.float16, device=self.device + ) + self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], dim=0) + self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts], dim=0) + self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts], dim=0) self.stacked_expert_tokens = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], - self.config["d_hidden"]), + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]), dtype=torch.float16, - device=self.device) + device=self.device, + ) self.stacked_expert_weights = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), - dtype=torch.float16, - device=self.device) + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.float16, device=self.device + ) self.stacked_expert_tokens_idxs = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), - dtype=torch.int64, - device=self.device) + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.int64, device=self.device + ) self.up_logits_shared = torch.empty( - (config["batch_size"] * config["seq_len"], self.config["d_expert"]), - dtype=torch.float16, - device=self.device) + (config["batch_size"] * config["seq_len"], self.config["d_expert"]), dtype=torch.float16, device=self.device + ) self.expert_output_shared = torch.empty( - (config["batch_size"] * config["seq_len"], self.config["d_hidden"]), - dtype=torch.float16, - device=self.device) + (config["batch_size"] * config["seq_len"], self.config["d_hidden"]), dtype=torch.float16, device=self.device + ) self.up_logits_routed = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], - self.config["d_expert"]), + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_expert"]), dtype=torch.float16, - device=self.device) + device=self.device, + ) self.expert_output_routed = torch.empty( - (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], - self.config["d_hidden"]), + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]), dtype=torch.float16, - device=self.device) + device=self.device, + ) @torch.no_grad() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -413,22 +376,20 @@ class MoE(nn.Module): self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens 
self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs - self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[ - idxs[start_idx:end_idx]] + self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[idxs[start_idx:end_idx]] group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device) - group_offset = torch.tensor( - tokens_per_expert - counts, dtype=torch.int32, device=self.device) + group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=self.device) group_padded_offsets = [0 for _ in range(len(group_sizes))] for i in range(1, len(group_sizes)): - group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil( - (counts[i - 1] + 1) / self.padding_M) * self.padding_M + group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / self.padding_M) * self.padding_M block_token = 128 - M = math.ceil( - self.config["batch_size"] * self.config["seq_len"] * - self.config["n_experts_per_token"] / block_token) + self.config["n_routed_experts"] + M = ( + math.ceil(self.config["batch_size"] * self.config["seq_len"] * self.config["n_experts_per_token"] / block_token) + + self.config["n_routed_experts"] + ) group_idx_for_bx = [0 for _ in range(M)] for bx in range(M): @@ -437,8 +398,7 @@ class MoE(nn.Module): if m_start_padded >= group_padded_offsets[i]: group_idx_for_bx[bx] = i - group_padded_offsets = torch.tensor( - group_padded_offsets, dtype=torch.int32, device=self.device) + group_padded_offsets = torch.tensor(group_padded_offsets, dtype=torch.int32, device=self.device) group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=self.device) # Multi-stream execution @@ -448,11 +408,19 @@ class MoE(nn.Module): with torch.cuda.stream(routed_stream): # Tilelang version: Grouped GEMM - self.routed_kernel(self.stacked_expert_tokens, self.stacked_expert_w_gate, - self.stacked_expert_w_up, self.stacked_expert_w_down, - self.stacked_expert_weights, group_sizes, group_offset, - group_padded_offsets, group_idx_for_bx, self.up_logits_routed, - self.expert_output_routed) + self.routed_kernel( + self.stacked_expert_tokens, + self.stacked_expert_w_gate, + self.stacked_expert_w_up, + self.stacked_expert_w_down, + self.stacked_expert_weights, + group_sizes, + group_offset, + group_padded_offsets, + group_idx_for_bx, + self.up_logits_routed, + self.expert_output_routed, + ) # Scatter reduce self.expert_cache = torch.scatter_reduce( @@ -460,14 +428,19 @@ class MoE(nn.Module): 0, self.stacked_expert_tokens_idxs.view(-1, 1).repeat(1, x_flat.shape[-1]), self.expert_output_routed, - reduce='sum') + reduce="sum", + ) routed_output = self.expert_cache.view(*orig_shape) with torch.cuda.stream(shared_stream): - - self.shared_kernel(x_flat, self.shared_expert.W_gate_weight, - self.shared_expert.W_up_weight, self.shared_expert.W_down_weight, - self.up_logits_shared, self.expert_output_shared) + self.shared_kernel( + x_flat, + self.shared_expert.W_gate_weight, + self.shared_expert.W_up_weight, + self.shared_expert.W_down_weight, + self.up_logits_shared, + self.expert_output_shared, + ) shared_output = self.expert_output_shared.view(*orig_shape) torch.cuda.synchronize() @@ -498,7 +471,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: config["d_expert"], config["n_shared_experts"], dtype=dtype_str, - num_tokens=config["batch_size"] * config["seq_len"]) + num_tokens=config["batch_size"] * config["seq_len"], + ) routed_kernel = moe_forward_tilelang_routed( config["d_hidden"], 
config["d_expert"], @@ -512,7 +486,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: threads=256, num_stages=1, k_pack=1, - coalesced_width=2) + coalesced_width=2, + ) moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128) @@ -521,13 +496,7 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: return output -def main(d_hidden=7168, - d_expert=2048, - n_routed_experts=8, - n_shared_experts=1, - n_experts_per_token=4, - batch_size=1, - seq_len=8192): +def main(d_hidden=7168, d_expert=2048, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=8192): config = { "dhidden": d_hidden, "dexpert": d_expert, @@ -536,7 +505,7 @@ def main(d_hidden=7168, "nexpertspertoken": n_experts_per_token, "bs": batch_size, "seqlen": seq_len, - "seed": 81394 + "seed": 81394, } data = generate_input(**config) diff --git a/examples/fusedmoe/example_fusedmoe_torch.py b/examples/fusedmoe/example_fusedmoe_torch.py index 00219c6e94b54070d14b8e98715ff6173b04b00e..6b6322aff7dce196ce12f371d83f47e5c1fa82e4 100644 --- a/examples/fusedmoe/example_fusedmoe_torch.py +++ b/examples/fusedmoe/example_fusedmoe_torch.py @@ -6,7 +6,6 @@ from typing import Dict, Tuple, Optional # Reference code in PyTorch class ExpertTorch(nn.Module): - def __init__(self, config: Dict, d_expert: Optional[int] = None): super().__init__() self.config = config @@ -25,7 +24,6 @@ class ExpertTorch(nn.Module): class MoEGateTorch(nn.Module): - def __init__(self, config: Dict): super().__init__() self.top_k: int = config["n_experts_per_token"] @@ -43,12 +41,10 @@ class MoEGateTorch(nn.Module): class MoETorch(nn.Module): - def __init__(self, config: Dict): super().__init__() self.config = config - self.experts = nn.ModuleList( - [ExpertTorch(config) for _ in range(config["n_routed_experts"])]) + self.experts = nn.ModuleList([ExpertTorch(config) for _ in range(config["n_routed_experts"])]) self.gating_network = MoEGateTorch(config) shared_expert_dim = config["d_expert"] * config["n_shared_experts"] self.shared_expert = ExpertTorch(config=config, d_expert=shared_expert_dim) @@ -67,8 +63,7 @@ class MoETorch(nn.Module): return routed_output + shared_output @torch.no_grad() - def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, - flat_expert_weights: torch.Tensor) -> torch.Tensor: + def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, flat_expert_weights: torch.Tensor) -> torch.Tensor: expert_cache = torch.zeros_like(x) # test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"])) # test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"])) @@ -91,8 +86,7 @@ class MoETorch(nn.Module): expert_out = expert(expert_tokens) expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) - expert_cache.scatter_reduce_( - 0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum') + expert_cache.scatter_reduce_(0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum") return expert_cache @@ -116,21 +110,21 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: moe = MoETorch(config) # Fill in the given weights of the model - moe.gating_network.W_g.weight = nn.Parameter(weights['router.weight']) + moe.gating_network.W_g.weight = nn.Parameter(weights["router.weight"]) for i in range(num_experts): - gate_proj_weight = weights[f'experts.{i}.0.weight'] - up_proj_weight = 
weights[f'experts.{i}.1.weight'] - down_proj_weight = weights[f'experts.{i}.2.weight'] + gate_proj_weight = weights[f"experts.{i}.0.weight"] + up_proj_weight = weights[f"experts.{i}.1.weight"] + down_proj_weight = weights[f"experts.{i}.2.weight"] # Transpose weights to match expected shape for nn.Linear moe.experts[i].W_gate.weight = nn.Parameter(gate_proj_weight.t()) moe.experts[i].W_up.weight = nn.Parameter(up_proj_weight.t()) moe.experts[i].W_down.weight = nn.Parameter(down_proj_weight.t()) - moe.shared_expert.W_gate.weight = nn.Parameter(weights['shared_experts.0.weight'].t()) - moe.shared_expert.W_up.weight = nn.Parameter(weights['shared_experts.1.weight'].t()) - moe.shared_expert.W_down.weight = nn.Parameter(weights['shared_experts.2.weight'].t()) + moe.shared_expert.W_gate.weight = nn.Parameter(weights["shared_experts.0.weight"].t()) + moe.shared_expert.W_up.weight = nn.Parameter(weights["shared_experts.1.weight"].t()) + moe.shared_expert.W_down.weight = nn.Parameter(weights["shared_experts.2.weight"].t()) output = moe(input_tensor) @@ -140,10 +134,9 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: # Input generation for the reference code -def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, - nexpertspertoken: int, bs: int, seqlen: int, - seed: int) -> Tuple[torch.Tensor, Dict, Dict]: - +def generate_input( + dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, nexpertspertoken: int, bs: int, seqlen: int, seed: int +) -> Tuple[torch.Tensor, Dict, Dict]: # Really dumb but for now _ isn't parsing correctly. d_hidden = dhidden d_expert = dexpert @@ -163,50 +156,40 @@ def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexper "seq_len": seq_len, } - gen = torch.Generator(device='cuda') + gen = torch.Generator(device="cuda") gen.manual_seed(seed) num_experts = n_routed_experts expert_dim = d_expert weights = {} - input_tensor = torch.randn((batch_size, seq_len, d_hidden), - device='cuda', - dtype=torch.float16, - generator=gen).contiguous() + input_tensor = torch.randn((batch_size, seq_len, d_hidden), device="cuda", dtype=torch.float16, generator=gen).contiguous() # Initialize router weights - weights['router.weight'] = torch.randn( - (num_experts, d_hidden), device="cuda", dtype=torch.float16, - generator=gen) / math.sqrt(d_hidden) + weights["router.weight"] = torch.randn((num_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(d_hidden) for i in range(num_experts): - weights[f'experts.{i}.0.weight'] = torch.randn( - (d_hidden, expert_dim), device='cuda', dtype=torch.float16, - generator=gen) / math.sqrt(expert_dim) - - weights[f'experts.{i}.1.weight'] = torch.randn( - (d_hidden, expert_dim), device='cuda', dtype=torch.float16, - generator=gen) / math.sqrt(expert_dim) - - weights[f'experts.{i}.2.weight'] = torch.randn( - (expert_dim, d_hidden), device='cuda', dtype=torch.float16, - generator=gen) / math.sqrt(d_hidden) - - weights['shared_experts.0.weight'] = torch.randn( - (d_hidden, expert_dim * n_shared_experts), - device='cuda', - dtype=torch.float16, - generator=gen) / math.sqrt(expert_dim * n_shared_experts) - weights['shared_experts.1.weight'] = torch.randn( - (d_hidden, expert_dim * n_shared_experts), - device='cuda', - dtype=torch.float16, - generator=gen) / math.sqrt(expert_dim * n_shared_experts) - weights['shared_experts.2.weight'] = torch.randn((expert_dim * n_shared_experts, d_hidden), - device='cuda', - dtype=torch.float16, - generator=gen) / 
math.sqrt(d_hidden) + weights[f"experts.{i}.0.weight"] = torch.randn( + (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim) + + weights[f"experts.{i}.1.weight"] = torch.randn( + (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim) + + weights[f"experts.{i}.2.weight"] = torch.randn( + (expert_dim, d_hidden), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(d_hidden) + + weights["shared_experts.0.weight"] = torch.randn( + (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim * n_shared_experts) + weights["shared_experts.1.weight"] = torch.randn( + (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim * n_shared_experts) + weights["shared_experts.2.weight"] = torch.randn( + (expert_dim * n_shared_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(d_hidden) return (input_tensor, weights, config) diff --git a/examples/fusedmoe/test_example_fusedmoe.py b/examples/fusedmoe/test_example_fusedmoe.py index 806aff49ee3741816d9c687420553173829eccee..ba8415895d52f75dc8bf029b7a97e1cabc983b03 100644 --- a/examples/fusedmoe/test_example_fusedmoe.py +++ b/examples/fusedmoe/test_example_fusedmoe.py @@ -4,13 +4,8 @@ import example_fusedmoe_tilelang def test_example_fusedmoe_tilelang(): example_fusedmoe_tilelang.main( - d_hidden=1024, - d_expert=256, - n_routed_experts=8, - n_shared_experts=1, - n_experts_per_token=4, - batch_size=1, - seq_len=1024) + d_hidden=1024, d_expert=256, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=1024 + ) if __name__ == "__main__": diff --git a/examples/gdn/example_chunk_delta_bwd.py b/examples/gdn/example_chunk_delta_bwd.py index d9ccc256543c77da92a4717725e93865114c51bc..ecda7e41b205047614f8096feffa7b20a34ac31a 100644 --- a/examples/gdn/example_chunk_delta_bwd.py +++ b/examples/gdn/example_chunk_delta_bwd.py @@ -12,6 +12,7 @@ print(tilelang.__file__, flush=True) # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__, flush=True) from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu except ImportError: @@ -49,6 +50,7 @@ def prepare_input( G = F.logsigmoid(G) try: from fla.ops.utils.cumsum import chunk_local_cumsum + G = chunk_local_cumsum(G, chunk_size) except ImportError: print("fla not found, skip cumsum") @@ -125,8 +127,11 @@ def torch_chunk_gated_delta_rule_bwd_dhu( DV = dv.shape[-1] block_S = 64 BS = S // block_S - dh, dh0, dv2 = torch.empty((B, BS, H, DK, DV), dtype=output_dtype), torch.empty( - (B, H, DK, DV), dtype=state_dtype), torch.empty((B, S, H, DV), dtype=output_dtype) + dh, dh0, dv2 = ( + torch.empty((B, BS, H, DK, DV), dtype=output_dtype), + torch.empty((B, H, DK, DV), dtype=state_dtype), + torch.empty((B, S, H, DV), dtype=output_dtype), + ) dh_tmp = torch.empty((B, H, DK, DV), dtype=accum_dtype) dv_tmp = torch.empty((B, S, H, DV), dtype=accum_dtype) Q_tmp = torch.empty((B, S, H, DK), dtype=accum_dtype) @@ -138,34 +143,30 @@ def torch_chunk_gated_delta_rule_bwd_dhu( for i_s in range(BS - 1, -1, -1): dh[:, i_s, :, :, :] = dh_tmp - dv_tmp = torch.matmul(K[:, i_s * block_S:(i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), - dh_tmp.to(K.dtype)).permute(0, 2, 1, 3) + dv_tmp = torch.matmul(K[:, i_s * block_S : (i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), dh_tmp.to(K.dtype)).permute(0, 2, 1, 3) if 
use_g: for i_bh in range(B * H): i_b, i_h = i_bh // H, i_bh % H for i_s2 in range(block_S): - if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, - i_h] <= 0: - dv_tmp[i_b, i_s2, - i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] - - G[i_b, i_s * block_S + i_s2, i_h]) + if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h] <= 0: + dv_tmp[i_b, i_s2, i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h]) else: dv_tmp[i_b, i_s2, i_h, :] = 0 - dv_tmp += dv[:, i_s * block_S:(i_s + 1) * block_S, :, :] - dv2[:, i_s * block_S:(i_s + 1) * block_S, :, :] = dv_tmp + dv_tmp += dv[:, i_s * block_S : (i_s + 1) * block_S, :, :] + dv2[:, i_s * block_S : (i_s + 1) * block_S, :, :] = dv_tmp if use_g: G_last = G[:, i_s * block_S + block_S - 1, :] for i_bh in range(B * H): i_b, i_h = i_bh // H, i_bh % H dh_tmp[i_b, i_h, :, :] *= torch.exp(G_last[i_b, i_h]) - Q_tmp = Q[:, i_s * block_S:(i_s + 1) * block_S, :, :] + Q_tmp = Q[:, i_s * block_S : (i_s + 1) * block_S, :, :] for i_s2 in range(block_S): for i_k in range(DK): Q_tmp[:, i_s2, :, i_k] *= torch.exp(G[:, i_s * block_S + i_s2, :]) Q_tmp *= scale - W_tmp = W[:, i_s * block_S:(i_s + 1) * block_S, :, :] - dO_tmp = dO[:, i_s * block_S:(i_s + 1) * block_S, :, :] + W_tmp = W[:, i_s * block_S : (i_s + 1) * block_S, :, :] + dO_tmp = dO[:, i_s * block_S : (i_s + 1) * block_S, :, :] torch.backends.cuda.matmul.allow_tf32 = True dh_tmp += torch.matmul(Q_tmp.permute(0, 2, 3, 1), dO_tmp.permute(0, 2, 1, 3)) @@ -223,19 +224,19 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( @T.prim_func def kernel( - # Input - Q: T.Tensor(Q_shape, dtype=input_dtype), - K: T.Tensor(K_shape, dtype=input_dtype), - W: T.Tensor(W_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - h0: T.Tensor(h0_shape, dtype=input_dtype), - dht: T.Tensor(dht_shape, dtype=input_dtype), - dO: T.Tensor(dO_shape, dtype=input_dtype), - dv: T.Tensor(dv_shape, dtype=input_dtype), - # Output - dh: T.Tensor(dh_shape, dtype=output_dtype), - dh0: T.Tensor(dh0_shape, dtype=state_dtype), - dv2: T.Tensor(dv2_shape, dtype=output_dtype), + # Input + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + h0: T.Tensor(h0_shape, dtype=input_dtype), + dht: T.Tensor(dht_shape, dtype=input_dtype), + dO: T.Tensor(dO_shape, dtype=input_dtype), + dv: T.Tensor(dv_shape, dtype=input_dtype), + # Output + dh: T.Tensor(dh_shape, dtype=output_dtype), + dh0: T.Tensor(dh0_shape, dtype=state_dtype), + dv2: T.Tensor(dv2_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): bb, bh = bbh // H, bbh % H @@ -269,20 +270,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( T.use_swizzle(10) - T.annotate_layout({ - b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared), - b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), - dO_shared_t: tilelang.layout.make_swizzled_layout(dO_shared_t), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32), - W_shared: tilelang.layout.make_swizzled_layout(W_shared), - }) + T.annotate_layout( + { + b_dh_shared: 
tilelang.layout.make_swizzled_layout(b_dh_shared), + b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), + dO_shared_t: tilelang.layout.make_swizzled_layout(dO_shared_t), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32), + W_shared: tilelang.layout.make_swizzled_layout(W_shared), + } + ) if use_final_state_gradient: - T.copy(dht[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_dh_shared) + T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared) T.copy(b_dh_shared, b_dh_fragment) else: T.clear(b_dh_fragment) @@ -293,17 +296,14 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( # Store the updated dh T.copy(b_dh_fragment, b_dh_shared) - T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) + T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) # Update dv - T.copy(K[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], K_shared) + T.copy(K[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], K_shared) T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True) if use_g: - T.copy( - G[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh], - G_shared, - disable_tma=True) + T.copy(G[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh], G_shared, disable_tma=True) T.copy(G_shared, G_fragment) G_last_local[0] = G_shared[block_S - 1] G_last_local_exp[0] = T.exp(G_last_local[0]) @@ -313,27 +313,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( # with T.If(G_last_local[0] - G_shared[i_s2] <= 0): with T.If(G_last_local[0] - G_fragment[i_s2] <= 0): with T.Then(): - dv_fragment[i_s2, - i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2] + dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2] with T.Else(): dv_fragment[i_s2, i_v] = 0 - T.copy( - dv[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV], dv_shared) + T.copy(dv[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dv_shared) T.copy(dv_shared, dv_fragment_2) for i_s2, i_v in T.Parallel(block_S, block_DV): dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v] # Store the updated dv T.copy(dv_fragment, dv_shared) - T.copy( - dv_shared, dv2[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV]) + T.copy(dv_shared, dv2[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) # Update dh - T.copy(Q[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], Q_shared) - T.copy(W[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], W_shared) + T.copy(Q[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], Q_shared) + T.copy(W[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], W_shared) T.clear(Q_fragment) if use_g: @@ -353,9 +348,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( for i_s2, i_k in T.Parallel(block_S, DK): Q_fragment_t[i_k, i_s2] = Q_fragment[i_s2, i_k] - T.copy( - dO[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV], dO_shared) + T.copy(dO[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dO_shared) T.copy(dO_shared, dO_fragment) for i_s2, i_v in T.Parallel(block_S, block_DV): 
dO_fragment_t[i_v, i_s2] = dO_fragment[i_s2, i_v] @@ -369,7 +362,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] - b_dh_fragment_2[i_k, i_v] if use_initial_state: - T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) + T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) return kernel @@ -444,44 +437,61 @@ def run_test( num_stages=0, use_torch=False, ): - Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - dh_ref, dh0_ref, dv2_ref = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + Q, K, W, G, h0, dht, dO, dv = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dh_ref, dh0_ref, dv2_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) # fla ref print("fla running...", flush=True) if use_g: - dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, - scale) + dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale) else: G = G.fill_(0) - dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, - scale) + dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale) # tilelang print("tilelang running...", flush=True) - kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, - chunk_size, scale, use_g, use_initial_state, - use_final_state_gradient, block_DV, threads, - num_stages) + kernel = tilelang_chunk_gated_delta_rule_bwd_dhu( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g, + use_initial_state, + use_final_state_gradient, + block_DV, + threads, + num_stages, + ) # kernel = tilelang.compile(program) print(kernel.get_kernel_source()) dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) - fla_time = do_bench( - chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size) + fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size) tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv) print(f"fla time: {fla_time} ms") @@ -496,19 +506,47 @@ def run_test( print("torch running...", flush=True) if use_g: dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu( - Q, K, W, G, h0, dht, dO, dv, scale, use_g, use_initial_state, - use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype), - getattr(torch, accum_dtype), getattr(torch, - gate_dtype), getattr(torch, state_dtype)) + Q, 
+ K, + W, + G, + h0, + dht, + dO, + dv, + scale, + use_g, + use_initial_state, + use_final_state_gradient, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) dh_ref_torch = dh_ref_torch.cuda() dh0_ref_torch = dh0_ref_torch.cuda() dv2_ref_torch = dv2_ref_torch.cuda() else: dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu( - Q, K, W, None, h0, dht, dO, dv, scale, use_g, use_initial_state, - use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype), - getattr(torch, accum_dtype), getattr(torch, - gate_dtype), getattr(torch, state_dtype)) + Q, + K, + W, + None, + h0, + dht, + dO, + dv, + scale, + use_g, + use_initial_state, + use_final_state_gradient, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) dh_ref_torch = dh_ref_torch.cuda() dh0_ref_torch = dh0_ref_torch.cuda() dv2_ref_torch = dv2_ref_torch.cuda() diff --git a/examples/gdn/example_chunk_delta_h.py b/examples/gdn/example_chunk_delta_h.py index cc384aded9afaf3ced5dba636b42023368814b9f..43f1e972b6a4ecfae665acc6b5aafb7aae7889c4 100644 --- a/examples/gdn/example_chunk_delta_h.py +++ b/examples/gdn/example_chunk_delta_h.py @@ -10,6 +10,7 @@ from tilelang.autotuner import autotune # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h except ImportError: @@ -56,6 +57,7 @@ def prepare_input( G = F.logsigmoid(G) try: from fla.ops.utils.cumsum import chunk_local_cumsum + G = chunk_local_cumsum(G, chunk_size) except ImportError: print("fla not found, skip cumsum") @@ -83,18 +85,14 @@ def prepare_output( def get_configs(): import itertools + block_DK = [32, 64, 128] block_DV = [32, 64, 128] threads = [128, 256] num_stages = [1, 2, 3] _configs = list(itertools.product(block_DK, block_DV, threads, num_stages)) - configs = [{ - 'block_DK': c[0], - 'block_DV': c[1], - 'threads': c[2], - 'num_stages': c[3] - } for c in _configs] + configs = [{"block_DK": c[0], "block_DV": c[1], "threads": c[2], "num_stages": c[3]} for c in _configs] return configs @@ -137,14 +135,14 @@ def tilelang_chunk_gated_delta_rule_fwd_h( @T.prim_func def kernel( - K: T.Tensor(K_shape, dtype=input_dtype), - W: T.Tensor(W_shape, dtype=input_dtype), - U: T.Tensor(U_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - initial_state: T.Tensor(initial_state_shape, dtype=input_dtype), - h: T.Tensor(h_shape, dtype=output_dtype), - final_state: T.Tensor(final_state_shape, dtype=state_dtype), - V_new: T.Tensor(V_shape, dtype=output_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + U: T.Tensor(U_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + initial_state: T.Tensor(initial_state_shape, dtype=input_dtype), + h: T.Tensor(h_shape, dtype=output_dtype), + final_state: T.Tensor(final_state_shape, dtype=state_dtype), + V_new: T.Tensor(V_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): bb, bh = bbh // H, bbh % H @@ -162,35 +160,35 @@ def tilelang_chunk_gated_delta_rule_fwd_h( G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype) G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype) - T.annotate_layout({ - b_h_shared: 
tilelang.layout.make_swizzled_layout(b_h_shared), - U_shared: tilelang.layout.make_swizzled_layout(U_shared), - W_shared: tilelang.layout.make_swizzled_layout(W_shared), - V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - G_shared: tilelang.layout.make_swizzled_layout(G_shared), - }) + T.annotate_layout( + { + b_h_shared: tilelang.layout.make_swizzled_layout(b_h_shared), + U_shared: tilelang.layout.make_swizzled_layout(U_shared), + W_shared: tilelang.layout.make_swizzled_layout(W_shared), + V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + G_shared: tilelang.layout.make_swizzled_layout(G_shared), + } + ) T.use_swizzle(10) if use_initial_state: - T.copy(initial_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_h_shared) + T.copy(initial_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_h_shared) T.copy(b_h_shared, b_h_fragment) else: T.clear(b_h_fragment) for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages): # Store previous result to the hidden tensor, like the epilogue - T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) + T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) # Recurrence - T.copy(W[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], W_shared) + T.copy(W[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], W_shared) T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True) # U - W * S - T.copy( - U[bb, i_s * block_S:(i_s + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV], - U_shared) + T.copy(U[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], U_shared) T.copy(U_shared, U_fragment) for i_s2, i_v in T.Parallel(block_S, block_DV): V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v] @@ -198,11 +196,9 @@ def tilelang_chunk_gated_delta_rule_fwd_h( # Save V_new if save_new_value: T.copy(V_new_fragment, dst=V_new_shared) - T.copy( - V_new_shared, V_new[bb, i_s * block_S:(i_s + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV]) + T.copy(V_new_shared, V_new[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) - T.copy(K[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], K_shared) + T.copy(K[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], K_shared) # use_g if use_g: G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh] @@ -213,7 +209,8 @@ def tilelang_chunk_gated_delta_rule_fwd_h( with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0): with T.Then(): V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp2( - (G_last_local[0] - G_fragment[i_s2, i_v]) * 1.442695) + (G_last_local[0] - G_fragment[i_s2, i_v]) * 1.442695 + ) with T.Else(): V_new_fragment[i_s2, i_v] = 0 G_last_local[0] = T.exp2(G_last_local[0] * 1.442695) @@ -228,7 +225,7 @@ def tilelang_chunk_gated_delta_rule_fwd_h( # Save final state if store_final_state: - T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) + T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) return kernel @@ -279,17 +276,24 @@ def run_test( threads=128, num_stages=0, ): - K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype)) - h_ref, final_state_ref, V_new_ref = prepare_output(B, S, H, DK, DV, chunk_size, 
- getattr(torch, output_dtype), - getattr(torch, state_dtype)) - h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, state_dtype)) + K, W, U, G, initial_state = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) + h_ref, final_state_ref, V_new_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) + ) + h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) + ) # fla ref h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h( @@ -300,13 +304,27 @@ def run_test( initial_state=initial_state, output_final_state=store_final_state, chunk_size=chunk_size, - save_new_value=save_new_value) + save_new_value=save_new_value, + ) # tilelang - kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, chunk_size, - use_g, use_initial_state, store_final_state, - save_new_value) + kernel = tilelang_chunk_gated_delta_rule_fwd_h( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_g, + use_initial_state, + store_final_state, + save_new_value, + ) h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) # (zhengju) If you want to print the generated cuda code, you can uncomment the following line # print("CUDA Code:\n", kernel.get_kernel_source()) @@ -320,19 +338,15 @@ def run_test( initial_state=initial_state, output_final_state=store_final_state, chunk_size=chunk_size, - save_new_value=save_new_value) + save_new_value=save_new_value, + ) tilelang_time = do_bench(kernel, K, W, U, G, initial_state) # check correctness try: h_ref_fp32 = h_ref.to(torch.float32) h_tilelang_fp32 = h_tilelang.to(torch.float32) - assert_similar( - h_ref_fp32, - h_tilelang_fp32, - eps=1e-5, - name="tilelang chunk gated delta rule fwd h", - raise_assert=False) + assert_similar(h_ref_fp32, h_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd h", raise_assert=False) print("tilelang chunk gated delta rule fwd h passed √") except Exception as e: print("tilelang chunk gated delta rule fwd h failed ✗") @@ -346,7 +360,8 @@ def run_test( final_state_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd final_state", - raise_assert=False) + raise_assert=False, + ) print("tilelang chunk gated delta rule fwd final_state passed √") except Exception as e: print("tilelang chunk gated delta rule fwd final_state failed ✗") @@ -355,12 +370,7 @@ def run_test( try: V_new_ref_fp32 = V_new_ref.to(torch.float32) V_new_tilelang_fp32 = V_new_tilelang.to(torch.float32) - assert_similar( - V_new_ref_fp32, - V_new_tilelang_fp32, - eps=1e-5, - name="tilelang chunk gated delta rule fwd V_new", - raise_assert=False) + assert_similar(V_new_ref_fp32, V_new_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd V_new", raise_assert=False) print("tilelang chunk gated delta rule fwd V_new passed √") except Exception as e: print("tilelang chunk gated delta rule fwd V_new failed ✗") diff --git a/examples/gdn/example_chunk_o.py b/examples/gdn/example_chunk_o.py index 1c084be70549aff839d87718d3e1a3bab688a76c..bd1e9aa2394d7030df23c99150278e78a44f9da8 100644 --- 
a/examples/gdn/example_chunk_o.py +++ b/examples/gdn/example_chunk_o.py @@ -9,6 +9,7 @@ import sys # noqa: F401 # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.common.chunk_o import chunk_fwd_o except ImportError: @@ -87,16 +88,14 @@ def tilelang_chunk_fwd_o( @T.prim_func def kernel( - Q: T.Tensor(Q_shape, dtype=input_dtype), - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - HIDDEN: T.Tensor(H_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - O: T.Tensor(O_shape, dtype=output_dtype), + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + HIDDEN: T.Tensor(H_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + O: T.Tensor(O_shape, dtype=output_dtype), ): - with T.Kernel( - T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H, - threads=threads) as (bv, bs, bbh): + with T.Kernel(T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H, threads=threads) as (bv, bs, bbh): bb, bh = bbh // H, bbh % H Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) @@ -109,28 +108,24 @@ def tilelang_chunk_fwd_o( G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared") G_diff_local = T.alloc_fragment((block_S, block_S), dtype=gate_dtype) - T.annotate_layout({ - Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - V_shared: tilelang.layout.make_swizzled_layout(V_shared), - H_shared: tilelang.layout.make_swizzled_layout(H_shared), - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) + T.annotate_layout( + { + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + V_shared: tilelang.layout.make_swizzled_layout(V_shared), + H_shared: tilelang.layout.make_swizzled_layout(H_shared), + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + } + ) T.clear(A_fragment) T.clear(O_fragment) T.disable_warp_group_reg_alloc() for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - Q[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - Q_shared) - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) - T.copy( - HIDDEN[bb, bs, bh, i_k * block_DK:(i_k + 1) * block_DK, - bv * block_DV:(bv + 1) * block_DV], H_shared) + T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], Q_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + T.copy(HIDDEN[bb, bs, bh, i_k * block_DK : (i_k + 1) * block_DK, bv * block_DV : (bv + 1) * block_DV], H_shared) T.gemm(Q_shared, H_shared, O_fragment) T.gemm(Q_shared, K_shared, A_fragment, transpose_B=True) @@ -145,8 +140,7 @@ def tilelang_chunk_fwd_o( for i_s1, i_s2 in T.Parallel(block_S, block_S): with T.If(G_diff_local[i_s1, i_s2] <= 0): with T.Then(): - A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp( - G_diff_local[i_s1, i_s2]) + A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]) with T.Else(): A_fragment[i_s1, i_s2] = 0 @@ -155,8 +149,7 @@ def tilelang_chunk_fwd_o( with T.Then(): A_fragment[i_s1, 
i_s2] = 0 - T.copy(V[bb, bs * block_S:(bs + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV], - V_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], V_shared) T.copy(A_fragment, A_shared) T.gemm(A_shared, V_shared, O_fragment) @@ -164,8 +157,7 @@ def tilelang_chunk_fwd_o( O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * scale T.copy(O_fragment, O_shared) - T.copy(O_shared, O[bb, bs * block_S:(bs + 1) * block_S, bh, - bv * block_DV:(bv + 1) * block_DV]) + T.copy(O_shared, O[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) return kernel @@ -191,8 +183,9 @@ def run_test( output_dtype_torch = getattr(torch, output_dtype) accum_dtype_torch = getattr(torch, accum_dtype) gate_dtype_torch = getattr(torch, gate_dtype) - Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, input_dtype_torch, - output_dtype_torch, accum_dtype_torch, gate_dtype_torch) + Q, K, V, HIDDEN, G = prepare_input( + B, S, H, DK, DV, chunk_size, input_dtype_torch, output_dtype_torch, accum_dtype_torch, gate_dtype_torch + ) scale = 1.0 / DK**0.5 O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch) @@ -200,9 +193,25 @@ def run_test( block_S = chunk_size O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch) - kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV, - threads, num_stages) + kernel = tilelang_chunk_fwd_o( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + chunk_size, + scale, + use_g, + block_S, + block_DK, + block_DV, + threads, + num_stages, + ) O_tilelang = kernel(Q, K, V, HIDDEN, G) try: diff --git a/examples/gdn/example_chunk_o_bwd.py b/examples/gdn/example_chunk_o_bwd.py index 20aa8414df4047326ec58ead3bf67dad148a9b92..66cb6942e8a20ac38aabbe8723230e1b1328c532 100644 --- a/examples/gdn/example_chunk_o_bwd.py +++ b/examples/gdn/example_chunk_o_bwd.py @@ -12,6 +12,7 @@ from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F4 # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.common.chunk_o import chunk_bwd_dqkwg except ImportError: @@ -108,10 +109,8 @@ def prepare_output( @tilelang.jit( out_idx=[-4, -3, -2, -1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }) + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, +) def tilelang_chunk_o_bwd_dqkwg( # task config B, @@ -155,25 +154,23 @@ def tilelang_chunk_o_bwd_dqkwg( @T.prim_func def kernel( - # input - Q: T.Tensor(Q_shape, dtype=input_dtype), - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - h: T.Tensor(h_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - dO: T.Tensor(dO_shape, dtype=input_dtype), - dh: T.Tensor(dh_shape, dtype=input_dtype), - dv: T.Tensor(dv_shape, dtype=input_dtype), - W: T.Tensor(W_shape, dtype=input_dtype), - # output - dq: T.Tensor(dq_shape, dtype=output_dtype), - dk: T.Tensor(dk_shape, dtype=output_dtype), - dw: T.Tensor(dw_shape, dtype=output_dtype), - dg: T.Tensor(dg_shape, dtype=gate_dtype), + # input + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + h: T.Tensor(h_shape, dtype=input_dtype), + G: 
T.Tensor(G_shape, dtype=gate_dtype), + dO: T.Tensor(dO_shape, dtype=input_dtype), + dh: T.Tensor(dh_shape, dtype=input_dtype), + dv: T.Tensor(dv_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + # output + dq: T.Tensor(dq_shape, dtype=output_dtype), + dk: T.Tensor(dk_shape, dtype=output_dtype), + dw: T.Tensor(dw_shape, dtype=output_dtype), + dg: T.Tensor(dg_shape, dtype=gate_dtype), ): - with T.Kernel( - T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, - threads=threads) as (bk, bs, bbh): + with T.Kernel(T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, threads=threads) as (bk, bs, bbh): bb, bh = bbh // H, bbh % H V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) @@ -212,15 +209,17 @@ def tilelang_chunk_o_bwd_dqkwg( T.use_swizzle(10) - T.annotate_layout({ - V_shared: tilelang.layout.make_swizzled_layout(V_shared), - dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), - h_shared: tilelang.layout.make_swizzled_layout(h_shared), - dh_shared: tilelang.layout.make_swizzled_layout(dh_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), - q_shared: tilelang.layout.make_swizzled_layout(q_shared), - k_shared: tilelang.layout.make_swizzled_layout(k_shared), - }) + T.annotate_layout( + { + V_shared: tilelang.layout.make_swizzled_layout(V_shared), + dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), + h_shared: tilelang.layout.make_swizzled_layout(h_shared), + dh_shared: tilelang.layout.make_swizzled_layout(dh_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + q_shared: tilelang.layout.make_swizzled_layout(q_shared), + k_shared: tilelang.layout.make_swizzled_layout(k_shared), + } + ) T.clear(dg_last_local) T.clear(G_last_local) @@ -235,18 +234,10 @@ def tilelang_chunk_o_bwd_dqkwg( T.clear(dw_fragment) for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): - T.copy( - V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV], - V_shared) - T.copy( - dO[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV], dO_shared) - T.copy( - h[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK, - i_v * block_DV:(i_v + 1) * block_DV], h_shared) - T.copy( - dh[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK, - i_v * block_DV:(i_v + 1) * block_DV], dh_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) + T.copy(dO[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dO_shared) + T.copy(h[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], h_shared) + T.copy(dh[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], dh_shared) if use_g: T.clear(dg_last_fragment_scalar) @@ -254,9 +245,7 @@ def tilelang_chunk_o_bwd_dqkwg( # for i_kv in T.Parallel(block_DK * block_DV): # dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] for i_kv in T.Parallel(block_DK * block_DV): - dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % - block_DV] * dh_shared[i_kv // block_DV, - i_kv % block_DV] + dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False) dg_last_local[0] += dg_last_fragment_scalar[0] @@ -265,22 +254,16 @@ def tilelang_chunk_o_bwd_dqkwg( T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True) if 
use_dw: - T.copy( - dv[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV], dv_shared) + T.copy(dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dv_shared) T.gemm(dv_shared, h_shared, dw_fragment, transpose_B=True) if use_dw: for i_s, i_k in T.Parallel(block_S, block_DK): dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k] - T.copy( - dw_fragment, dw[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) - - T.copy(Q[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK], - q_shared) - T.copy(K[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK], - k_shared) + T.copy(dw_fragment, dw[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + + T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], q_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], k_shared) T.copy(q_shared, q_fragment) T.copy(k_shared, k_fragment) @@ -294,8 +277,7 @@ def tilelang_chunk_o_bwd_dqkwg( dg_last_local[0] = dg_last_local[0] * T.exp(G[bb, bs * block_S + block_S - 1, bh]) for i_s, i_k in T.Parallel(block_S, block_DK): - dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, - bh]) * scale + dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, bh]) * scale T.clear(dg_fragment_reduce_tmp) for i_s, i_k in T.Parallel(block_S, block_DK): dg_fragment_reduce_tmp[i_s, i_k] = dq_fragment[i_s, i_k] * q_shared[i_s, i_k] @@ -305,8 +287,7 @@ def tilelang_chunk_o_bwd_dqkwg( for i_s, i_k in T.Parallel(block_S, block_DK): with T.If(G_last_local[0] - G[bb, bs * block_S + i_s, bh] <= 0): with T.Then(): - dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp( - G_last_local[0] - G[bb, bs * block_S + i_s, bh]) + dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp(G_last_local[0] - G[bb, bs * block_S + i_s, bh]) with T.Else(): dk_fragment[i_s, i_k] = 0 T.clear(dg_fragment_reduce_tmp) @@ -325,12 +306,11 @@ def tilelang_chunk_o_bwd_dqkwg( dg_last_local[1] = dg_last_fragment_scalar_2[0] for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 >= i_s2 and - G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): + with T.If(i_s1 >= i_s2 and G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): with T.Then(): - ds_fragment[i_s1, i_s2] = ds_fragment[ - i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - - G[bb, bs * block_S + i_s2, bh]) * scale + ds_fragment[i_s1, i_s2] = ( + ds_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) * scale + ) with T.Else(): ds_fragment[i_s1, i_s2] = 0 @@ -338,8 +318,7 @@ def tilelang_chunk_o_bwd_dqkwg( T.clear(ds_fragment_positive_transpose) T.gemm(q_shared, k_shared, ds_fragment_positive, transpose_B=True) for i_s1, i_s2 in T.Parallel(block_S, block_S): - ds_fragment_positive[ - i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2] + ds_fragment_positive[i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2] # FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass T.reduce_sum(ds_fragment_positive, dg_fragment, dim=1, clear=False) @@ -363,15 +342,10 @@ def tilelang_chunk_o_bwd_dqkwg( for i_s in T.Parallel(block_S): with T.If(i_s >= block_S - 1): # noqa: SIM117 with T.Then(): - dg_fragment_final[ - i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1] - 
- T.copy( - dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) - T.copy( - dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) + dg_fragment_final[i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1] + + T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) for i_s in T.Parallel(block_S): dg[bk, bb, bs * block_S + i_s, bh] = dg_fragment_final[i_s] @@ -387,12 +361,8 @@ def tilelang_chunk_o_bwd_dqkwg( for i_s, i_k in T.Parallel(block_S, block_DK): dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] + dk_fragment_2[i_s, i_k] * scale - T.copy( - dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) - T.copy( - dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh, - bk * block_DK:(bk + 1) * block_DK]) + T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) return kernel @@ -442,32 +412,53 @@ def run_test( threads=256, num_stages=0, ): - Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - dq_ref, dk_ref, dw_ref, dg_ref = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype), block_DK) + Q, K, V, h, G, dO, dh, dv, W = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dq_ref, dk_ref, dw_ref, dg_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK + ) dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output( - B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype), block_DK) + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK + ) # ref if use_g: - dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg( - Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) + dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) else: - dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg( - Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) + dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) # tilelang - kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, state_dtype, chunk_size, scale, use_g, use_dw, - block_DK, block_DV, threads, num_stages) + kernel = tilelang_chunk_o_bwd_dqkwg( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g, + use_dw, + block_DK, + block_DV, + threads, + num_stages, + ) dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, 
dv, W) if use_g: diff --git a/examples/gdn/example_chunk_scaled_dot_kkt.py b/examples/gdn/example_chunk_scaled_dot_kkt.py index d07a4776a23f4808547f99ce4b6616cf44b93bc5..af2b08e57d2dee0f06558d5a72d645f4f99be633 100644 --- a/examples/gdn/example_chunk_scaled_dot_kkt.py +++ b/examples/gdn/example_chunk_scaled_dot_kkt.py @@ -9,6 +9,7 @@ import sys # noqa: F401 # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd except ImportError: @@ -75,10 +76,10 @@ def tilelang_chunk_scaled_dot_kkt_fwd( @T.prim_func def kernel( - K: T.Tensor(K_shape, dtype=input_dtype), - Beta: T.Tensor(Beta_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=accum_dtype), - A: T.Tensor(output_shape, dtype=output_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=accum_dtype), + A: T.Tensor(output_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H @@ -93,10 +94,12 @@ def tilelang_chunk_scaled_dot_kkt_fwd( G_shared = T.alloc_shared((block_S,), dtype=accum_dtype, scope="shared") G_diff_local = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - }) + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + } + ) T.fill(A_fragment, 0) T.disable_warp_group_reg_alloc() @@ -104,9 +107,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd( Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) for i_s, i_k2 in T.Parallel(block_S, block_DK): Beta_K_fragment[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] T.gemm(Beta_K_fragment, K_shared, A_fragment, transpose_B=True) @@ -119,8 +120,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd( for i_s1, i_s2 in T.Parallel(block_S, block_S): with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2): with T.Then(): - A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp( - G_diff_local[i_s1, i_s2]) + A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]) with T.Else(): A_fragment[i_s1, i_s2] = 0 else: @@ -130,7 +130,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd( A_fragment[i_s1, i_s2] = 0 T.copy(A_fragment, A_shared) - T.copy(A_shared, A[bb, bs * block_S:(bs + 1) * block_S, bh, :]) + T.copy(A_shared, A[bb, bs * block_S : (bs + 1) * block_S, bh, :]) return kernel @@ -149,24 +149,21 @@ def run_test( threads, num_stages, ): - K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), - getattr(torch, output_dtype), getattr(torch, accum_dtype)) + K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype)) A_ref = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype)) A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype)) # reference if use_g: - A_ref = chunk_scaled_dot_kkt_fwd( - K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) + A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, 
G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) else: - A_ref = chunk_scaled_dot_kkt_fwd( - K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) + A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) # tilelang block_S = chunk_size - kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype, - accum_dtype, use_g, block_S, block_DK, threads, - num_stages) + kernel = tilelang_chunk_scaled_dot_kkt_fwd( + B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages + ) A_tilelang = kernel(K, Beta, G) try: @@ -192,7 +189,8 @@ def main(): use_g=True, block_DK=64, threads=128, - num_stages=2) + num_stages=2, + ) if __name__ == "__main__": diff --git a/examples/gdn/example_cumsum.py b/examples/gdn/example_cumsum.py index 9896c7ecf7689bf379e371298f06822502eb34d0..13547cd60a59781706335c7adb0ee0c5d05ccc1c 100644 --- a/examples/gdn/example_cumsum.py +++ b/examples/gdn/example_cumsum.py @@ -10,6 +10,7 @@ import sys # noqa: F401 # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.utils.cumsum import chunk_local_cumsum_scalar except ImportError: @@ -20,11 +21,8 @@ import torch @tilelang.jit( - out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }) + out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} +) def tilelang_chunk_local_cumsum_scalar( # task config B, @@ -42,35 +40,35 @@ def tilelang_chunk_local_cumsum_scalar( use_fragment=False, ): G_shape = (B, H, S) if head_first else (B, S, H) - assert chunk_size == 2**(chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" assert chunk_size == block_S, "chunk_size must be equal to block_S" @T.prim_func def kernel( - G: T.Tensor(G_shape, dtype=input_dtype), - G_new: T.Tensor(G_shape, dtype=output_dtype), + G: T.Tensor(G_shape, dtype=input_dtype), + G_new: T.Tensor(G_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared") if head_first: - T.copy(G[bb, bh, bs * block_S:(bs + 1) * block_S], G_shared) + T.copy(G[bb, bh, bs * block_S : (bs + 1) * block_S], G_shared) else: - T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared) + T.copy(G[bb, bs * block_S : (bs + 1) * block_S, bh], G_shared) if use_fragment: G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype, scope="shared") T.copy(G_shared, G_fragment) T.cumsum(G_fragment, dim=1, reverse=reverse) if head_first: - T.copy(G_fragment, G_new[bb, bh, bs * block_S:(bs + 1) * block_S]) + T.copy(G_fragment, G_new[bb, bh, bs * block_S : (bs + 1) * block_S]) else: - T.copy(G_fragment, G_new[bb, bs * block_S:(bs + 1) * block_S, bh]) + T.copy(G_fragment, G_new[bb, bs * block_S : (bs + 1) * block_S, bh]) else: T.cumsum(G_shared, dim=1, reverse=reverse) if head_first: - T.copy(G_shared, G_new[bb, bh, bs * block_S:(bs + 1) * block_S]) + T.copy(G_shared, G_new[bb, bh, bs * block_S : (bs + 1) * block_S]) else: - T.copy(G_shared, G_new[bb, bs * block_S:(bs + 1) * block_S, bh]) + T.copy(G_shared, G_new[bb, bs * block_S : (bs + 1) * block_S, 
bh]) return kernel @@ -113,11 +111,8 @@ def run_test( # reference cumsum G_new_ref = chunk_local_cumsum_scalar( - g=G, - chunk_size=chunk_size, - reverse=reverse, - head_first=head_first, - output_dtype=getattr(torch, output_dtype)) + g=G, chunk_size=chunk_size, reverse=reverse, head_first=head_first, output_dtype=getattr(torch, output_dtype) + ) # tilelang cumsum block_S = chunk_size @@ -162,7 +157,8 @@ def main(): input_dtype="float32", output_dtype="float32", threads=256, - use_fragment=False) + use_fragment=False, + ) if __name__ == "__main__": diff --git a/examples/gdn/example_wy_fast.py b/examples/gdn/example_wy_fast.py index 0a0983a82f70f10994e1a150b484e271e5c7915a..874e25c3b053fb5d09229076e666fbc467f9d936 100644 --- a/examples/gdn/example_wy_fast.py +++ b/examples/gdn/example_wy_fast.py @@ -9,6 +9,7 @@ import sys # noqa: F401 # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd except ImportError: @@ -73,13 +74,13 @@ def tilelang_recompute_w_u_fwd( @T.prim_func def kernel( - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - Beta: T.Tensor(Beta_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - A: T.Tensor(A_shape, dtype=output_dtype), - W: T.Tensor(K_shape, dtype=output_dtype), - U: T.Tensor(V_shape, dtype=output_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=output_dtype), + W: T.Tensor(K_shape, dtype=output_dtype), + U: T.Tensor(V_shape, dtype=output_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H @@ -95,49 +96,42 @@ def tilelang_recompute_w_u_fwd( W_Beta_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) U_Beta_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) - T.annotate_layout({ - K_shared: tilelang.layout.make_swizzled_layout(K_shared), - V_shared: tilelang.layout.make_swizzled_layout(V_shared), - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - W_shared: tilelang.layout.make_swizzled_layout(W_shared), - U_shared: tilelang.layout.make_swizzled_layout(U_shared), - W_Beta_shared: tilelang.layout.make_swizzled_layout(W_Beta_shared), - U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared), - }) + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + V_shared: tilelang.layout.make_swizzled_layout(V_shared), + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + W_shared: tilelang.layout.make_swizzled_layout(W_shared), + U_shared: tilelang.layout.make_swizzled_layout(U_shared), + W_Beta_shared: tilelang.layout.make_swizzled_layout(W_Beta_shared), + U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared), + } + ) T.disable_warp_group_reg_alloc() for i_s in T.Parallel(block_S): Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh]) - T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): - T.copy( - V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV], - V_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) for i_s, 
i_v2 in T.Parallel(block_S, block_DV): U_Beta_shared[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s] T.gemm(A_shared, U_Beta_shared, U_fragment, clear_accum=True) # First copy to smem, then copy to gmem to reduce U2RU instructions T.copy(U_fragment, U_shared) - T.copy( - U_shared, U[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV]) + T.copy(U_shared, U[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV]) for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) for i_s, i_k2 in T.Parallel(block_S, block_DK): - W_Beta_shared[i_s, - i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s] + W_Beta_shared[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s] T.gemm(A_shared, W_Beta_shared, W_fragment, clear_accum=True) # First copy to smem, then copy to gmem to reduce U2RU instructions T.copy(W_fragment, W_shared) - T.copy( - W_shared, W[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK]) + T.copy(W_shared, W[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) return kernel @@ -159,15 +153,8 @@ def run_test( num_stages, ): K, V, Beta, G, A = prepare_input( - B, - S, - H, - DK, - DV, - chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - gate_dtype=getattr(torch, gate_dtype)) + B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype) + ) W_ref, U_ref = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype)) W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype)) @@ -191,7 +178,8 @@ def run_test( block_DK=block_DK, block_DV=block_DV, threads=threads, - num_stages=num_stages) + num_stages=num_stages, + ) print(kernel.get_kernel_source()) W_tilelang, U_tilelang = kernel(K, V, Beta, G, A) @@ -224,7 +212,8 @@ def main(): block_DK=64, block_DV=32, threads=128, - num_stages=3) + num_stages=3, + ) if __name__ == "__main__": diff --git a/examples/gdn/example_wy_fast_bwd_split.py b/examples/gdn/example_wy_fast_bwd_split.py index 42a0040dd13d7e5a6e6e92ca45cb3bb43b41db92..5b0230e5c3e0ba2be6fa40c6f8fa988371aa717b 100644 --- a/examples/gdn/example_wy_fast_bwd_split.py +++ b/examples/gdn/example_wy_fast_bwd_split.py @@ -10,6 +10,7 @@ import tilelang.language as T # sys.path.insert(0, "/home/tzj/flash-linear-attention") try: import fla + print(fla.__file__) from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr except ImportError: @@ -93,10 +94,8 @@ def prepare_output( @tilelang.jit( out_idx=[-5, -4, -3, -2, -1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }) + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, +) def tilelang_wy_fast_bwd( # task config B, @@ -135,20 +134,20 @@ def tilelang_wy_fast_bwd( @T.prim_func def kernel( - # input - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - Beta: T.Tensor(Beta_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - A: T.Tensor(A_shape, dtype=input_dtype), - dw: T.Tensor(dw_shape, dtype=input_dtype), - du: T.Tensor(du_shape, dtype=input_dtype), - # output - 
dA: T.Tensor(dA_shape, dtype=input_dtype), - dk: T.Tensor(dk_shape, dtype=output_dtype), - dv: T.Tensor(dv_shape, dtype=output_dtype), - dbeta: T.Tensor(dbeta_shape, dtype=output_dtype), - dg: T.Tensor(dg_shape, dtype=gate_dtype), + # input + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=input_dtype), + dw: T.Tensor(dw_shape, dtype=input_dtype), + du: T.Tensor(du_shape, dtype=input_dtype), + # output + dA: T.Tensor(dA_shape, dtype=input_dtype), + dk: T.Tensor(dk_shape, dtype=output_dtype), + dv: T.Tensor(dv_shape, dtype=output_dtype), + dbeta: T.Tensor(dbeta_shape, dtype=output_dtype), + dg: T.Tensor(dg_shape, dtype=gate_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H @@ -187,7 +186,7 @@ def tilelang_wy_fast_bwd( T.clear(dbeta_fragment_v) T.clear(dg_fragment) - T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) for i_s in T.Parallel(block_S): Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] G_shared[i_s] = G[bb, bs * block_S + i_s, bh] @@ -195,51 +194,37 @@ def tilelang_wy_fast_bwd( # Update dk for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) for i_s, i_k2 in T.Parallel(block_S, block_DK): - K_shared_beta_g[i_s, - i_k2] = K_shared[i_s, - i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] - T.copy( - dw[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK], dw_shared) + K_shared_beta_g[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] + T.copy(dw[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dw_shared) T.gemm(dw_shared, K_shared_beta_g, dA_fragment, transpose_B=True) T.gemm(A_shared, dw_shared, dk_fragment_beta_g, clear_accum=True, transpose_A=True) for i_s, i_k2 in T.Parallel(block_S, block_DK): - dk_fragment[ - i_s, - i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] + dk_fragment[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] # for i_s, i_k2 in T.Parallel(block_S, block_DK): # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] for i_s, i_k2 in T.Parallel(block_S, block_DK): - dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[ - i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] + dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False) # for i_s, i_k2 in T.Parallel(block_S, block_DK): # dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] for i_s, i_k2 in T.Parallel(block_S, block_DK): - dg_fragment_reduce_tmp[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[ - i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] + dg_fragment_reduce_tmp[i_s, i_k2] = ( + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] + ) T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=1, clear=False) # correct dk - T.copy( 
- dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) # Update dv for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): - T.copy( - V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV], - V_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) for i_s, i_v2 in T.Parallel(block_S, block_DV): V_shared_beta[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s] - T.copy( - du[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV], du_shared) + T.copy(du[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], du_shared) T.gemm(du_shared, V_shared_beta, dA_fragment, transpose_B=True) T.gemm(A_shared, du_shared, dv_fragment_beta, clear_accum=True, transpose_A=True) for i_s, i_v2 in T.Parallel(block_S, block_DV): @@ -247,30 +232,22 @@ def tilelang_wy_fast_bwd( # for i_s, i_v2 in T.Parallel(block_S, block_DV): # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2] for i_s, i_v2 in T.Parallel(block_S, block_DV): - dbeta_fragment_reduce_tmpv[i_s, - i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s, - i_v2] + dbeta_fragment_reduce_tmpv[i_s, i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2] T.reduce_sum(dbeta_fragment_reduce_tmpv, dbeta_fragment_v, dim=1, clear=False) - T.copy( - dv_fragment, dv[bb, bs * block_S:(bs + 1) * block_S, bh, - i_v * block_DV:(i_v + 1) * block_DV]) + T.copy(dv_fragment, dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV]) # Temporary store dbeta, dg and dA for i_s in T.Parallel(block_S): dbeta[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s] + dbeta_fragment_v[i_s] dg[bb, bs * block_S + i_s, bh] = dg_fragment[i_s] # correct dA - T.copy(dA_fragment, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :]) + T.copy(dA_fragment, dA[bb, bs * block_S : (bs + 1) * block_S, bh, :]) return kernel -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True - }) +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) def tilelang_wy_fast_bwd_split( # task config B, @@ -308,20 +285,20 @@ def tilelang_wy_fast_bwd_split( @T.prim_func def kernel( - # input - K: T.Tensor(K_shape, dtype=input_dtype), - V: T.Tensor(V_shape, dtype=input_dtype), - Beta: T.Tensor(Beta_shape, dtype=input_dtype), - G: T.Tensor(G_shape, dtype=gate_dtype), - A: T.Tensor(A_shape, dtype=input_dtype), - dw: T.Tensor(dw_shape, dtype=input_dtype), - du: T.Tensor(du_shape, dtype=input_dtype), - dA: T.Tensor(dA_shape, dtype=input_dtype), - dk: T.Tensor(dk_shape, dtype=output_dtype), - dv: T.Tensor(dv_shape, dtype=output_dtype), - dbeta_k: T.Tensor(dbeta_shape, dtype=output_dtype), - dg_A_positive: T.Tensor(dA_shape, dtype=gate_dtype), - dg_A_negative: T.Tensor(dA_shape, dtype=gate_dtype), + # input + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=input_dtype), + dw: T.Tensor(dw_shape, dtype=input_dtype), + du: T.Tensor(du_shape, dtype=input_dtype), + dA: T.Tensor(dA_shape, dtype=input_dtype), + dk: T.Tensor(dk_shape, 
dtype=output_dtype), + dv: T.Tensor(dv_shape, dtype=output_dtype), + dbeta_k: T.Tensor(dbeta_shape, dtype=output_dtype), + dg_A_positive: T.Tensor(dA_shape, dtype=gate_dtype), + dg_A_negative: T.Tensor(dA_shape, dtype=gate_dtype), ): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): bb, bh = bbh // H, bbh % H @@ -350,7 +327,7 @@ def tilelang_wy_fast_bwd_split( T.clear(dA_A_fragment_1) T.clear(dA_A_fragment_2) - T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) for i_s in T.Parallel(block_S): Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] G_shared[i_s] = G[bb, bs * block_S + i_s, bh] @@ -361,7 +338,7 @@ def tilelang_wy_fast_bwd_split( # for i_s in T.Parallel(block_S): # dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh] # dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh] - T.copy(dA[bb, bs * block_S:(bs + 1) * block_S, bh, :], dA_shared) + T.copy(dA[bb, bs * block_S : (bs + 1) * block_S, bh, :], dA_shared) # T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :]) # Update dA @@ -385,8 +362,7 @@ def tilelang_wy_fast_bwd_split( for i_s1, i_s2 in T.Parallel(block_S, block_S): with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): with T.Then(): - dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - - G[bb, bs * block_S + i_s2, bh]) + dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) with T.Else(): dA_fragment[i_s1, i_s2] = 0 T.copy(dA_fragment, dA_shared) @@ -397,12 +373,8 @@ def tilelang_wy_fast_bwd_split( # Update dk using previous dk T.clear(A_fragment) for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): - T.copy( - K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], - K_shared) - T.copy( - dk[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK], dk_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + T.copy(dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dk_shared) T.copy(dk_shared, dk_fragment) for i_s, i_k2 in T.Parallel(block_S, block_DK): K_shared_beta[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] @@ -411,18 +383,14 @@ def tilelang_wy_fast_bwd_split( # for i_s, i_k2 in T.Parallel(block_S, block_DK): # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2] for i_s, i_k2 in T.Parallel(block_S, block_DK): - dbeta_fragment_reduce_tmpk[i_s, - i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s, - i_k2] + dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2] T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False) T.gemm(dA_shared, K_shared_beta, dk_fragment, transpose_A=True) for i_s, i_k2 in T.Parallel(block_S, block_DK): dk_shared_beta[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * Beta_shared[i_s] for i_s, i_k2 in T.Parallel(block_S, block_DK): dk_fragment[i_s, i_k2] = dk_fragment[i_s, i_k2] + dk_shared_beta[i_s, i_k2] - T.copy( - dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh, - i_k * block_DK:(i_k + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) # Update dg and dbeta T.copy(A_fragment, A_shared) @@ -460,19 +428,25 @@ def run_test( threads=128, num_stages=0, ): - K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, 
chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, - accum_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output(B, S, H, DK, DV, chunk_size, - getattr(torch, output_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + K, V, Beta, G, A, dw, du = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output( - B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) BS = chunk_size dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda() dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda() @@ -480,28 +454,55 @@ def run_test( dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() # ref - dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr( - K, V, G, Beta, A, dw, du, cu_seqlens=None) + dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr(K, V, G, Beta, A, dw, du, cu_seqlens=None) # tilelang - kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads, - num_stages) - dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel( - K, V, Beta, G, A, dw, du) + kernel = tilelang_wy_fast_bwd( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du) torch.cuda.synchronize() - kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, chunk_size, - block_DK, block_DV, threads, num_stages) - kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, - dg_tilelang_A_positive, dg_tilelang_A_negative) + kernel_split = tilelang_wy_fast_bwd_split( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + kernel_split( + K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative + ) torch.cuda.synchronize() dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang - dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum( - dim=-1) + dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1) from test_utils import assert_similar + assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False) assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False) assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False) diff --git a/examples/gdn/test_example_gdn_compilation.py 
b/examples/gdn/test_example_gdn_compilation.py index 75a62171f7240f24869a30f9eddd5c9e87759f8a..a51936ef8991c50fb70205852f63d889cbce4133 100644 --- a/examples/gdn/test_example_gdn_compilation.py +++ b/examples/gdn/test_example_gdn_compilation.py @@ -25,16 +25,10 @@ num_stages = 1 def test_example_wy_fast_compilation(): from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input + K, V, Beta, G, A = prepare_input( - B, - S, - H, - DK, - DV, - chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - gate_dtype=getattr(torch, gate_dtype)) + B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype) + ) # tilelang block_S = chunk_size kernel = tilelang_recompute_w_u_fwd( @@ -52,22 +46,31 @@ def test_example_wy_fast_compilation(): block_DK=block_DK, block_DV=block_DV, threads=threads, - num_stages=num_stages) + num_stages=num_stages, + ) print(kernel.get_kernel_source()) W_tilelang, U_tilelang = kernel(K, V, Beta, G, A) def test_example_wy_fast_bwd_split_compilation(): from example_wy_fast_bwd_split import tilelang_wy_fast_bwd, tilelang_wy_fast_bwd_split, prepare_input, prepare_output - K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, - accum_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + + K, V, Beta, G, A, dw, du = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output( - B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), - getattr(torch, state_dtype)) + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) BS = chunk_size dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda() dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda() @@ -75,68 +78,146 @@ def test_example_wy_fast_bwd_split_compilation(): dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() # tilelang - kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads, - num_stages) - dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel( - K, V, Beta, G, A, dw, du) + kernel = tilelang_wy_fast_bwd( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du) torch.cuda.synchronize() - kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, chunk_size, - block_DK, block_DV, threads, num_stages) - kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, - dg_tilelang_A_positive, dg_tilelang_A_negative) + kernel_split = tilelang_wy_fast_bwd_split( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + kernel_split( + K, V, Beta, G, A, dw, du, dA_tilelang, 
dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative + ) torch.cuda.synchronize() dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang - dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum( - dim=-1) + dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1) def test_example_chunk_o_compilation(): from example_chunk_o import tilelang_chunk_fwd_o, prepare_input - Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), - getattr(torch, output_dtype), getattr(torch, accum_dtype), - getattr(torch, gate_dtype)) + + Q, K, V, HIDDEN, G = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) scale = 1.0 / DK**0.5 block_S = chunk_size - kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV, - threads, num_stages) + kernel = tilelang_chunk_fwd_o( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + chunk_size, + scale, + use_g, + block_S, + block_DK, + block_DV, + threads, + num_stages, + ) O_tilelang = kernel(Q, K, V, HIDDEN, G) # noqa: F841 def test_example_chunk_o_bwd_compilation(): from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input - Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, - gate_dtype, state_dtype, chunk_size, 1.0, use_g, True, - block_DK, block_DV, threads, num_stages) - - dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, - W) # noqa: F841 + + Q, K, V, h, G, dO, dh, dv, W = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + kernel = tilelang_chunk_o_bwd_dqkwg( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + 1.0, + use_g, + True, + block_DK, + block_DV, + threads, + num_stages, + ) + + dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) # noqa: F841 if use_g: dg_tilelang = dg_tilelang.sum(dim=0) def test_example_chunk_scaled_dot_kkt_compilation(): from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input - K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), - getattr(torch, output_dtype), getattr(torch, accum_dtype)) + + K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype)) block_S = chunk_size - kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype, - accum_dtype, use_g, block_S, block_DK, threads, - num_stages) + kernel = tilelang_chunk_scaled_dot_kkt_fwd( + B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages + ) A_tilelang = kernel(K, Beta, G) # noqa: F841 def test_example_cumsum_compilation(): from example_cumsum import 
tilelang_chunk_local_cumsum_scalar, prepare_cumsum_input, prepare_cumsum_output + G = prepare_cumsum_input(B, S, H, getattr(torch, gate_dtype)) G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, gate_dtype)) block_S = chunk_size @@ -158,33 +239,79 @@ def test_example_cumsum_compilation(): def test_example_chunk_delta_h_compilation(): from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input - K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype)) - kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, chunk_size, - use_g, use_initial_state, store_final_state, - save_new_value, block_DK, block_DV, threads, - num_stages) - h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, - initial_state) # noqa: F841 + + K, W, U, G, initial_state = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) + kernel = tilelang_chunk_gated_delta_rule_fwd_h( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_g, + use_initial_state, + store_final_state, + save_new_value, + block_DK, + block_DV, + threads, + num_stages, + ) + h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) # noqa: F841 def test_example_chunk_delta_bwd_compilation(): from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input - Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size, - getattr(torch, input_dtype), - getattr(torch, output_dtype), - getattr(torch, accum_dtype), - getattr(torch, gate_dtype), - getattr(torch, state_dtype)) - kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype, - accum_dtype, gate_dtype, state_dtype, - chunk_size, 1.0, use_g, use_initial_state, - use_final_state_gradient, block_DV, threads, - num_stages) + + Q, K, W, G, h0, dht, dO, dv = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + kernel = tilelang_chunk_gated_delta_rule_bwd_dhu( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + 1.0, + use_g, + use_initial_state, + use_final_state_gradient, + block_DV, + threads, + num_stages, + ) dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) # noqa: F841 diff --git a/examples/gdn/test_utils.py b/examples/gdn/test_utils.py index 37f8d8e69f2bf388bc9b95fc571817489375d32c..3588551ce39ad1bee4267b727979336d73341561 100644 --- a/examples/gdn/test_utils.py +++ b/examples/gdn/test_utils.py @@ -9,7 +9,7 @@ def calc_sim(x, y, name="tensor"): x, y = x.data.double(), y.data.double() denominator = (x * x + y * y).sum() if denominator == 0: - print_red_warning(f'{name} all zero') + print_red_warning(f"{name} all zero") return 1 sim = 2 * (x * y).sum() / denominator return sim @@ -19,21 +19,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True): x_mask = torch.isfinite(x) y_mask = torch.isfinite(y) if not torch.all(x_mask == y_mask): - print_red_warning(f'{name} 
Error: isfinite mask mismatch') + print_red_warning(f"{name} Error: isfinite mask mismatch") if raise_assert: raise AssertionError - if not torch.isclose( - x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, - equal_nan=True).all(): - print_red_warning(f'{name} Error: nonfinite value mismatch') + if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all(): + print_red_warning(f"{name} Error: nonfinite value mismatch") if raise_assert: raise AssertionError x = x.masked_fill(~x_mask, 0) y = y.masked_fill(~y_mask, 0) sim = calc_sim(x, y, name) - diff = 1. - sim + diff = 1.0 - sim if not (0 <= diff <= eps): - print_red_warning(f'{name} Error: {diff}') + print_red_warning(f"{name} Error: {diff}") if raise_assert: raise AssertionError else: diff --git a/examples/gemm/example_gemm.py b/examples/gemm/example_gemm.py index f18cd388a79e0972b12cdecd89bac6fdd780d64e..2c234d122d07ff9d5faf4428e1ccc9a716272223 100644 --- a/examples/gemm/example_gemm.py +++ b/examples/gemm/example_gemm.py @@ -4,12 +4,11 @@ import tilelang.language as T @tilelang.jit(out_idx=[-1]) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def gemm( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) diff --git a/examples/gemm/example_gemm_autotune.py b/examples/gemm/example_gemm_autotune.py index 661ef1276d496dbcc38fa0a77036a611fecceea8..badc334025110b0078943968e0508e7ca3eaa0fc 100644 --- a/examples/gemm/example_gemm_autotune.py +++ b/examples/gemm/example_gemm_autotune.py @@ -90,7 +90,8 @@ def get_configs(M, N, K, with_roller=False, topk=20): num_stages, thread_num, enable_rasterization, - )) + ) + ) configs = [ { @@ -100,13 +101,13 @@ def get_configs(M, N, K, with_roller=False, topk=20): "num_stages": c[3], "thread_num": c[4], "enable_rasteration": c[5], # keep param name for backward-compat - } for c in _configs + } + for c in _configs ] return configs def get_best_config(M, N, K, with_roller=False): - def kernel( block_M=None, block_N=None, @@ -120,12 +121,11 @@ def get_best_config(M, N, K, with_roller=False): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -146,15 +146,18 @@ def get_best_config(M, N, K, with_roller=False): return main - autotuner = AutoTuner.from_kernel( - kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args( + autotuner = ( + AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller)) + .set_compile_args( out_idx=[-1], target="auto", - ).set_profile_args( + ) + .set_profile_args( supply_type=tl.TensorSupplyType.Integer, ref_prog=ref_program, skip_check=False, ) + ) return autotuner.run(warmup=3, rep=20) @@ -167,52 +170,20 @@ def get_heuristic_config() -> dict: sm_version = 
sm_major * 10 + sm_minor print(f"CUDA device capability: {sm_version}") if sm_version in {80}: - return { - "block_M": 128, - "block_N": 256, - "block_K": 32, - "num_stages": 2, - "thread_num": 128, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True} elif sm_version in {90}: - return { - "block_M": 128, - "block_N": 256, - "block_K": 64, - "num_stages": 3, - "thread_num": 256, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True} else: - return { - "block_M": 128, - "block_N": 256, - "block_K": 32, - "num_stages": 0, - "thread_num": 128, - "enable_rasteration": True - } + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True} @tl.jit(out_idx=[-1]) -def matmul(M, - N, - K, - block_M, - block_N, - block_K, - num_stages, - thread_num, - enable_rasteration, - dtype="float16", - accum_dtype="float"): - +def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"): @T.prim_func def gemm_autotune( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -236,11 +207,7 @@ def matmul(M, return gemm_autotune -def main(M: int = 4096, - N: int = 4096, - K: int = 4096, - use_autotune: bool = False, - with_roller: bool = False): +def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False, with_roller: bool = False): use_autotune = True if use_autotune: result = get_best_config(M, N, K, with_roller) @@ -266,15 +233,7 @@ if __name__ == "__main__": parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M") parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N") parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K") - parser.add_argument( - "--use_autotune", - action="store_true", - default=False, - help="Whether to use autotune for matmul configs") - parser.add_argument( - "--with_roller", - action="store_true", - default=False, - help="Whether to enable BitBLAS roller for search space") + parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs") + parser.add_argument("--with_roller", action="store_true", default=False, help="Whether to enable BitBLAS roller for search space") args = parser.parse_args() main(args.m, args.n, args.k, args.use_autotune, args.with_roller) diff --git a/examples/gemm/example_gemm_intrinsics.py b/examples/gemm/example_gemm_intrinsics.py index 5c014ce3a4b51d9b131382fa5af4773cd1b0f582..488e5bf6bc37ed9833192fd3a5fb4006b09647e7 100644 --- a/examples/gemm/example_gemm_intrinsics.py +++ b/examples/gemm/example_gemm_intrinsics.py @@ -4,7 +4,8 @@ import tilelang import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func @@ -99,12 +100,11 @@ def tl_matmul( @T.prim_func def gemm_intrinsics( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: 
T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -112,10 +112,12 @@ def tl_matmul( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -123,7 +125,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -133,7 +134,6 @@ def tl_matmul( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a(A_local, A_shared, ki) diff --git a/examples/gemm/example_gemm_persistent.py b/examples/gemm/example_gemm_persistent.py index a2a7122d39193b274b015c038e8615e3e21d17af..6fc0e5aac6365ec10e0d251a15d62c64a5045b76 100644 --- a/examples/gemm/example_gemm_persistent.py +++ b/examples/gemm/example_gemm_persistent.py @@ -5,22 +5,12 @@ import argparse @tilelang.jit(out_idx=[-1]) -def matmul_non_persistent(M, - N, - K, - block_M, - block_N, - block_K, - threads, - num_stages, - dtype="float16", - accum_dtype="float"): - +def matmul_non_persistent(M, N, K, block_M, block_N, block_K, threads, num_stages, dtype="float16", accum_dtype="float"): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -43,18 +33,9 @@ def matmul_non_persistent(M, @tilelang.jit(out_idx=[-1]) -def matmul_persistent(M, - N, - K, - block_M, - block_N, - block_K, - threads, - num_stages, - dtype="float16", - accum_dtype="float", - use_persistent_primitive=True): - +def matmul_persistent( + M, N, K, block_M, block_N, block_K, threads, num_stages, dtype="float16", accum_dtype="float", use_persistent_primitive=True +): sm_num = driver.get_num_sms() m_blocks = T.ceildiv(M, block_M) n_blocks = T.ceildiv(N, block_N) @@ -63,9 +44,9 @@ def matmul_persistent(M, @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(sm_num, threads=threads) as (block_id): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -90,9 +71,9 @@ def matmul_persistent(M, @T.prim_func def main_persistent_primitive( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(sm_num, threads=threads) as (block_id): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -100,8 
+81,7 @@ def matmul_persistent(M, C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), dtype) - for bx, by in T.Persistent( - [T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id): + for bx, by in T.Persistent([T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id): T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[bx * block_M, k * block_K], A_shared) @@ -128,18 +108,15 @@ def main(M=4096, N=4096, K=4096): num_stages = 3 persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) - persistent_profiler = persistent_kernel.get_profiler( - tensor_supply_type=tilelang.TensorSupplyType.Randn) + persistent_profiler = persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("Persistent GEMM: All check passed.") persistent_latency = persistent_profiler.do_bench(warmup=500) print(f"Persistent GEMM Latency: {persistent_latency} ms") print(f"Persistent GEMM TFlops: {total_flops / persistent_latency * 1e-9} TFlops") - non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, - num_stages) - non_persistent_profiler = non_persistent_kernel.get_profiler( - tensor_supply_type=tilelang.TensorSupplyType.Randn) + non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) + non_persistent_profiler = non_persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) non_persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("Non-Persistent GEMM: All check passed.") non_persistent_latency = non_persistent_profiler.do_bench(warmup=500) @@ -151,9 +128,9 @@ def main(M=4096, N=4096, K=4096): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--M', type=int, default=8192, help='M dimension') - parser.add_argument('--N', type=int, default=8192, help='N dimension') - parser.add_argument('--K', type=int, default=8192, help='K dimension') + parser.add_argument("--M", type=int, default=8192, help="M dimension") + parser.add_argument("--N", type=int, default=8192, help="N dimension") + parser.add_argument("--K", type=int, default=8192, help="K dimension") args = parser.parse_args() M, N, K = args.M, args.N, args.K main(M, N, K) diff --git a/examples/gemm/example_gemm_schedule.py b/examples/gemm/example_gemm_schedule.py index f4727412b74be1f7ef4a59bec24f40109e55e50a..d1eb11df56671bbdf4f162851789086236133b9f 100644 --- a/examples/gemm/example_gemm_schedule.py +++ b/examples/gemm/example_gemm_schedule.py @@ -4,12 +4,11 @@ import tilelang.language as T @tilelang.jit(out_idx=[-1]) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def gemm_schedule( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) diff --git a/examples/gemm_fp8/example_tilelang_gemm_amd.py b/examples/gemm_fp8/example_tilelang_gemm_amd.py index 0e6ace7571f55eba02c13930b87808eb9dd3c9db..4c58144e4385100655f0a67ab636ee553ee416e6 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_amd.py +++ b/examples/gemm_fp8/example_tilelang_gemm_amd.py @@ 
-17,10 +17,8 @@ def supply_prog(args): a_param, b_param = args M, K = a_param.shape N, _ = b_param.shape - a = (torch.randn(M, K, dtype=torch.float16, device='cuda') * - 0.01).to(dtype=torch.float8_e4m3fnuz) - b = (torch.randn(N, K, dtype=torch.float16, device='cuda') * - 0.01).to(dtype=torch.float8_e4m3fnuz) + a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) return [a, b] @@ -35,27 +33,24 @@ def get_configs(): valid_configs = [] - for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks, - num_stages, num_threads, k_packs, - gemm_types): - valid_configs.append({ - "block_M": m, - "block_N": n, - "block_K": k, - "num_stages": stages, - "num_threads": t, - "k_pack": kp, - "gemm_type": gemm_type, - }) + for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "block_K": k, + "num_stages": stages, + "num_threads": t, + "k_pack": kp, + "gemm_type": gemm_type, + } + ) return valid_configs @tilelang.autotune( - configs=get_configs(), - cache_input_tensors=True, - ref_prog=ref_program, - manual_check_prog=manual_check_prog, - supply_prog=supply_prog) + configs=get_configs(), cache_input_tensors=True, ref_prog=ref_program, manual_check_prog=manual_check_prog, supply_prog=supply_prog +) @tilelang.jit(out_idx=[-1]) def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type): dtype = "float8_e4m3fnuz" @@ -63,12 +58,11 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa @T.prim_func def gemm_fp8_rs( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), accum_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): A_local = T.alloc_fragment((block_M, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -77,24 +71,17 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_local) T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_local, - B_shared, - C_local, - transpose_B=True, - k_pack=k_pack, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(A_local, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow) T.copy(C_local, C[by * block_M, bx * block_N]) @T.prim_func def gemm_fp8_ss( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), accum_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -103,13 +90,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, 
num_threads, k_pa for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_shared, - B_shared, - C_local, - transpose_B=True, - k_pack=k_pack, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(A_shared, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow) T.copy(C_local, C[by * block_M, bx * block_N]) @@ -123,10 +104,8 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa def test_gemm_fp8(M, N, K): kernel = fp8_matmul(M, N, K) - a = (torch.randn(M, K, dtype=torch.float16, device='cuda') * - 0.01).to(dtype=torch.float8_e4m3fnuz) - b = (torch.randn(N, K, dtype=torch.float16, device='cuda') * - 0.01).to(dtype=torch.float8_e4m3fnuz) + a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) c = kernel(a, b) ref_c = ref_program(a, b) torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2) diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8.py b/examples/gemm_fp8/example_tilelang_gemm_fp8.py index a403ed068a5c2bea9b5d234fe1dc7c1e04c826a5..1ecd344bc35c701545113dd096f14555aaca769f 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8.py @@ -13,12 +13,11 @@ def calc_diff(x, y): @tilelang.jit(out_idx=[-1]) def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"): - @T.prim_func def gemm_fp8( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -41,8 +40,8 @@ def test_gemm_fp8(M, N, K, dtype): kernel = matmul(M, N, K, 128, 128, 64, dtype) - a = torch.randn(M, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype) - b = torch.randn(N, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype) + a = torch.randn(M, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype) + b = torch.randn(N, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype) c = kernel(a, b) @@ -57,8 +56,8 @@ def test_gemm_fp8(M, N, K, dtype): def main(): - test_gemm_fp8(1024, 1024, 1024, 'float8_e4m3') - test_gemm_fp8(1024, 1024, 1024, 'float8_e5m2') + test_gemm_fp8(1024, 1024, 1024, "float8_e4m3") + test_gemm_fp8(1024, 1024, 1024, "float8_e5m2") if __name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py index 1d9207aff2d3d33b83dccb1c106de027afbbd058..3af4c3d6da8b9483afd7fcaf7a974b362adbadbe 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py @@ -13,9 +13,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"): @T.prim_func def gemm_fp8_2xAcc( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), accum_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -59,14 +59,14 @@ def test_gemm_fp8(M, N, K, dtype): kernel = matmul(M, N, K, 128, 128, 64, dtype) - a = 
torch.rand(M, K, dtype=torch.float16, device='cuda') + a = torch.rand(M, K, dtype=torch.float16, device="cuda") a = (100 * (2 * a - 1)).to(dtype=torch_dtype) - b = torch.rand(N, K, dtype=torch.float16, device='cuda') + b = torch.rand(N, K, dtype=torch.float16, device="cuda") b = (100 * (2 * b - 1)).to(dtype=torch_dtype) c = kernel(a, b) - ref_c = (a.float() @ b.float().T) + ref_c = a.float() @ b.float().T diff = calc_diff(c, ref_c) print(f"diff: {diff}") @@ -74,8 +74,8 @@ def test_gemm_fp8(M, N, K, dtype): def main(): - test_gemm_fp8(1024, 1024, 8192, 'float8_e4m3') - test_gemm_fp8(1024, 1024, 8192, 'float8_e5m2') + test_gemm_fp8(1024, 1024, 8192, "float8_e4m3") + test_gemm_fp8(1024, 1024, 8192, "float8_e5m2") if __name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py index 0e2c437e3de34e9e9b6437fecda902b006d0da62..6e2d41be83cc687db93244c092faffef5a756f17 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -5,7 +5,8 @@ from tvm import DataType import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func from tilelang.utils.tensor import map_torch_type @@ -110,12 +111,11 @@ def tl_matmul( @T.prim_func def gemm_fp8_intrinsic( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -123,10 +123,12 @@ def tl_matmul( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -134,7 +136,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -144,7 +145,6 @@ def tl_matmul( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py index 4628a9975284f7592016d0b49af8e79d8c5279ff..5cb42e328dc40743ff1d62f50eb5340242faaab6 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py @@ -26,9 +26,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), 
T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -121,6 +121,4 @@ for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]: profiler = jit_kernel.get_profiler() latency = profiler.do_bench() print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms") - print( - f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS" - ) + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS") diff --git a/examples/gemm_sm100/gemm_mma.py b/examples/gemm_sm100/gemm_mma.py index a58e5a7c005a0d8ff6d7d2316d10243860ede981..be43f4ec40eefeb43cdafb439977473e5679e24c 100644 --- a/examples/gemm_sm100/gemm_mma.py +++ b/examples/gemm_sm100/gemm_mma.py @@ -5,12 +5,11 @@ import tilelang.language as T # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -62,7 +61,8 @@ jit_kernel = tilelang.compile( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, +) print(jit_kernel.get_kernel_source()) # 3. Test the kernel in Python with PyTorch data import torch diff --git a/examples/gemm_sm100/gemm_tcgen5mma.py b/examples/gemm_sm100/gemm_tcgen5mma.py index 9008c7ef5208aa6ea9d0145ceac5e43b4a534cbb..88614f561044b565bdc1a44dae0acfef01f52419 100644 --- a/examples/gemm_sm100/gemm_tcgen5mma.py +++ b/examples/gemm_sm100/gemm_tcgen5mma.py @@ -25,9 +25,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -40,15 +40,7 @@ def matmul( for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_shared, - B_shared, - C_tmem, - trans_A, - trans_B, - mbar=mbar, - wg_wait=-1, - clear_accum=k == 0) + T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0) T.mbarrier_wait_parity(mbar, k % 2) T.copy(C_tmem, C_local) @@ -66,8 +58,7 @@ in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float" num_stages = 2 threads = 256 -func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, - accum_dtype, num_stages, threads) +func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) jit_kernel = tilelang.compile( func, out_idx=[2], @@ -75,7 +66,8 @@ jit_kernel = tilelang.compile( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, +) print(jit_kernel.get_kernel_source()) @@ -88,4 +80,4 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) profiler = jit_kernel.get_profiler() latency = profiler.do_bench() print(f"Latency: 
{latency} ms") -print(f"Flops: {2 * M * N * K / (latency/1e3) / 1e12} TFLOPS") +print(f"Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS") diff --git a/examples/gemm_sp/example_custom_compress.py b/examples/gemm_sp/example_custom_compress.py index 5125aed073a3b74f0d150eac81a8fc053c6af686..fe3b1523344fac338ab872f56f10c9e3f47e6e58 100644 --- a/examples/gemm_sp/example_custom_compress.py +++ b/examples/gemm_sp/example_custom_compress.py @@ -17,77 +17,76 @@ torch.manual_seed(42) DEFAULT_CONFIG = { # take best config from autotune script "4090": { - 'float': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 64, - 'num_stages': 1, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True + "float": { + "block_M": 128, + "block_N": 64, + "block_K": 64, + "num_stages": 1, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + "float16": { + "block_M": 256, + "block_N": 128, + "block_K": 64, + "num_stages": 2, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, }, - 'float16': { - 'block_M': 256, - 'block_N': 128, - 'block_K': 64, - 'num_stages': 2, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True - } }, "h20": { - 'float': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 128, - 'num_stages': 3, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True + "float": { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + "float16": { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, }, - 'float16': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 128, - 'num_stages': 3, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True - } - } + }, } ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} @tilelang.jit(out_idx=[-1]) -def matmul_sp_fp16_custom_compress(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, - thread_num, policy, enable_rasterization, use_cutlass_layout): +def matmul_sp_fp16_custom_compress( + M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization, use_cutlass_layout +): e_factor, e_dtype = (16, "int16") @T.prim_func def gemm_sp_fp16_custom_compress( - A_sparse: T.Tensor((M, K // 2), 'float16'), - E: T.Tensor((M, K // e_factor), e_dtype), - B: T.Tensor((K, N), 'float16'), - C: T.Tensor((M, N), accum_dtype), + A_sparse: T.Tensor((M, K // 2), "float16"), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), "float16"), + C: T.Tensor((M, N), accum_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K // 2), 'float16') + A_shared = T.alloc_shared((block_M, block_K // 2), "float16") E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) - B_shared = T.alloc_shared((block_K, block_N), 'float16') + B_shared = T.alloc_shared((block_K, block_N), "float16") C_shared = T.alloc_shared((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) if use_cutlass_layout: - T.annotate_layout({ - E: - make_cutlass_metadata_layout( - E, mma_dtype="float16", arch="8.0", block_k=block_K), - E_shared: - make_cutlass_metadata_layout( - 
E_shared, mma_dtype="float16", arch="8.0", block_k=block_K), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="8.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="8.0", block_k=block_K), + } + ) T.clear(C_local) T.disable_warp_group_reg_alloc() T.use_swizzle(panel_size=10, enable=enable_rasterization) @@ -108,8 +107,7 @@ def torch_compress(dense): A naive compression function, where each 4-bit meta matches 4 elements in original matrix in row major layout. """ if dense.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor") + raise RuntimeError(f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor") m, k = dense.shape @@ -131,9 +129,7 @@ def torch_compress(dense): if m % 32 != 0: raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 32") if k % (4 * quadbits_per_meta_elem) != 0: - raise RuntimeError( - f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" - ) + raise RuntimeError(f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}") if dense.dtype != torch.float: ksparse = 4 @@ -194,19 +190,13 @@ def torch_compress(dense): sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) else: - sparse = dense_2.gather(-1, - idxs0.unsqueeze(-1) // 2).view( - m, k // 2) # type: ignore[possibly-undefined] + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined] meta_4 = idxs0 | (idxs1 << 2) meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) if quadbits_per_meta_elem == 4: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12)) + meta = meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12) elif quadbits_per_meta_elem == 8: meta = ( meta_n[:, :, 0] @@ -216,7 +206,8 @@ def torch_compress(dense): | (meta_n[:, :, 4] << 16) | (meta_n[:, :, 5] << 20) | (meta_n[:, :, 6] << 24) - | (meta_n[:, :, 7] << 28)) + | (meta_n[:, :, 7] << 28) + ) return (sparse, meta) @@ -234,9 +225,11 @@ def decode_metadata(meta: torch.Tensor) -> torch.Tensor: @tilelang.jit( - out_idx=[1, 2], pass_configs={ + out_idx=[1, 2], + pass_configs={ tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True, - }) + }, +) def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): e_factor, e_dtype = ARCH_INFO["8.0"] e_K = K // e_factor @@ -249,23 +242,21 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): @T.prim_func def kernel( - A: T.Tensor((M, K), dtype), - A_sp: T.Tensor((M, K // 2), dtype), - E: T.Tensor((M, e_K), e_dtype), + A: T.Tensor((M, K), dtype), + A_sp: T.Tensor((M, K // 2), dtype), + E: T.Tensor((M, e_K), e_dtype), ): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype) E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) if use_cutlass_layout: - T.annotate_layout({ - E: - make_cutlass_metadata_layout( - E, mma_dtype="float16", arch="8.0", block_k=block_K), - E_shared: - make_cutlass_metadata_layout( - E_shared, mma_dtype="float16", arch="8.0", block_k=block_K), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, 
mma_dtype="float16", arch="8.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="8.0", block_k=block_K), + } + ) T.clear(A_sp_shared) T.clear(E_shared) # TODO: alloc_var seems buggy here @@ -295,8 +286,7 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): non_zero_elt_log_idx[1] = 3 for i in T.serial(elem): val = non_zero_elt_log_idx[i] - E_shared[tm, a_k // e_factor] |= T.shift_left( - val, 4 * (g_i % (e_factor // group)) + 2 * i) + E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i) T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2]) T.copy(E_shared, E[bx * block_M, by * block_K // e_factor]) @@ -304,41 +294,27 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): def main(): - parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument( - "--use_cutlass_layout", action='store_true', help="Use cutlass layout for E tensor") - parser.add_argument( - "--use_torch_compressor", action='store_true', help="Use torch sparse for reference") - parser.add_argument( - "--accum_dtype", - type=str, - default="float", - choices=["float", "float16"], - help="Accumulation datatype") + parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor") + parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference") + parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype") parser.add_argument("--cfg", type=str, choices=["4090"], default="4090") args = parser.parse_args() kernel = matmul_sp_fp16_custom_compress( - args.m, - args.n, - args.k, - args.accum_dtype, - **DEFAULT_CONFIG[args.cfg][args.accum_dtype], - use_cutlass_layout=args.use_cutlass_layout) + args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype], use_cutlass_layout=args.use_cutlass_layout + ) - a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half) - b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half) + a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half) + b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half) if args.use_torch_compressor: assert not args.use_cutlass_layout, "torch sparse must be used with naive layout" a_sparse, e = torch_compress(a) else: - a_sparse, e = compress_kernel( - args.m, args.k, 32, 32, "float16", use_cutlass_layout=args.use_cutlass_layout)( - a) + a_sparse, e = compress_kernel(args.m, args.k, 32, 32, "float16", use_cutlass_layout=args.use_cutlass_layout)(a) c = kernel(a_sparse, e, b) @@ -346,9 +322,7 @@ def main(): assert not c.isnan().any(), "Reference result contains NaNs, please report an issue" torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3) - print( - f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}" - ) + print(f"Precision check passed. 
Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}") latency = do_bench(lambda: kernel(a_sparse, e, b)) ref_latency = do_bench(lambda: a @ b) @@ -356,8 +330,8 @@ def main(): total_flops = 2 * args.m * args.n * args.k tflops = total_flops / latency / 1e9 ref_tflops = total_flops / ref_latency / 1e9 - print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s") - print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s") + print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s") + print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s") if __name__ == "__main__": diff --git a/examples/gemm_sp/example_gemm_sp.py b/examples/gemm_sp/example_gemm_sp.py index 91682a9e459064088156c20d9cf77ee54bac4a68..828ca43a28c35c137e7f9f382bcc83c61845c051 100644 --- a/examples/gemm_sp/example_gemm_sp.py +++ b/examples/gemm_sp/example_gemm_sp.py @@ -16,80 +16,77 @@ arch = nvcc.get_target_compute_version() DEFAULT_CONFIG = { # take best config from autotune script "4090": { - 'float': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 64, - 'num_stages': 1, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True + "float": { + "block_M": 128, + "block_N": 64, + "block_K": 64, + "num_stages": 1, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + "float16": { + "block_M": 256, + "block_N": 128, + "block_K": 64, + "num_stages": 2, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, }, - 'float16': { - 'block_M': 256, - 'block_N': 128, - 'block_K': 64, - 'num_stages': 2, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True - } }, "h20": { - 'float': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 128, - 'num_stages': 3, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True + "float": { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + "float16": { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, }, - 'float16': { - 'block_M': 128, - 'block_N': 64, - 'block_K': 128, - 'num_stages': 3, - 'thread_num': 128, - 'policy': T.GemmWarpPolicy.Square, - 'enable_rasterization': True - } - } + }, } ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} @tilelang.jit(out_idx=[-1]) -def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, - enable_rasterization): +def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization): e_factor, e_dtype = ARCH_INFO[arch] @T.prim_func def gemm_sp_fp16( - A_sparse: T.Tensor((M, K // 2), 'float16'), - E: T.Tensor((M, K // e_factor), e_dtype), - B: T.Tensor((K, N), 'float16'), - C: T.Tensor((M, N), accum_dtype), + A_sparse: T.Tensor((M, K // 2), "float16"), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), "float16"), + C: T.Tensor((M, N), accum_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K // 2), 'float16') + A_shared = T.alloc_shared((block_M, block_K // 2), "float16") E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) - B_shared = 
T.alloc_shared((block_K, block_N), 'float16') + B_shared = T.alloc_shared((block_K, block_N), "float16") C_shared = T.alloc_shared((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) T.disable_warp_group_reg_alloc() T.use_swizzle(panel_size=10, enable=enable_rasterization) - T.annotate_layout({ - E: - make_cutlass_metadata_layout( - E, mma_dtype="float16", block_k=block_K, arch=arch), - E_shared: - make_cutlass_metadata_layout( - E_shared, mma_dtype="float16", block_k=block_K, arch=arch), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype="float16", block_k=block_K, arch=arch), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", block_k=block_K, arch=arch), + } + ) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) T.copy(E[by * block_M, k * block_K // e_factor], E_shared) @@ -107,25 +104,15 @@ def main(): parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument( - "--accum_dtype", - type=str, - default="float", - choices=["float", "float16"], - help="Accumulation datatype") + parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype") parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090") args = parser.parse_args() - kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, - **DEFAULT_CONFIG[args.cfg][args.accum_dtype]) + kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype]) - a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half) - b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half) + a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half) + b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half) - a_sparse, e = compress( - a, - transposed=False, - block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]['block_K'], - arch=arch) + a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]["block_K"], arch=arch) c = kernel(a_sparse, e, b) ref_c = a @ b @@ -140,8 +127,8 @@ def main(): total_flops = 2 * args.m * args.n * args.k tflops = total_flops / latency / 1e9 ref_tflops = total_flops / ref_latency / 1e9 - print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s") - print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s") + print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s") + print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s") if __name__ == "__main__": diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk.py b/examples/gemm_splitk/example_tilelang_gemm_splitk.py index c9666971180615813ae3ab36513e07126b32184b..320a699c5ffdd7d6997a08f09de2c7ab88c5d1f7 100644 --- a/examples/gemm_splitk/example_tilelang_gemm_splitk.py +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk.py @@ -3,27 +3,16 @@ import tilelang.language as T @tilelang.jit -def matmul(M, - N, - K, - block_M, - block_N, - block_K, - split_k, - dtype="float16", - accum_dtype="float", - out_dtype="float32"): - +def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float", out_dtype="float32"): splitK = K // split_k 
@T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py b/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py index 145d622edf1e497ad1664157dc3e638ecea5aa86..dfd8471018b4bee6812a52ba1c22a8629009c007 100644 --- a/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py @@ -3,27 +3,16 @@ import tilelang.language as T @tilelang.jit -def matmul(M, - N, - K, - block_M, - block_N, - block_K, - split_k, - dtype="float16", - accum_dtype="float", - out_dtype="float32"): - +def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float", out_dtype="float32"): splitK = K // split_k @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), out_dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) diff --git a/examples/gemm_streamk/example_tilelang_gemm_streamk.py b/examples/gemm_streamk/example_tilelang_gemm_streamk.py index 31cf40647c28872be65f588f5cd532f7f676e001..2d83586e51d658774818ff69823a0ddbf570721a 100644 --- a/examples/gemm_streamk/example_tilelang_gemm_streamk.py +++ b/examples/gemm_streamk/example_tilelang_gemm_streamk.py @@ -39,7 +39,7 @@ total_tiles = num_block_m * num_block_n # Two-tile SK + DP streamk_tiles = total_tiles % streamk_programs -if (total_tiles - streamk_tiles > streamk_programs): # (total_tiles // total_programs > 1) +if total_tiles - streamk_tiles > streamk_programs: # (total_tiles // total_programs > 1) streamk_tiles += streamk_programs blocking_tiles = total_tiles - streamk_tiles @@ -135,7 +135,6 @@ def tl_matmul_streamk( C: T.Tensor, C_local: T.LocalBuffer, ): - for p in T.serial(sm_patition_factor): tile_id = pid + streamk_tiles + p * total_sm pid_m = tile_id // T.ceildiv(N, block_N) @@ -150,12 +149,11 @@ def tl_matmul_streamk( @T.prim_func def main( - A: T.Tensor(A_shape, dtypeAB), - B: T.Tensor(B_shape, dtypeAB), - C: T.Tensor((M, N), dtypeC), + A: T.Tensor(A_shape, dtypeAB), + B: T.Tensor(B_shape, dtypeAB), + C: T.Tensor((M, N), dtypeC), ): with T.Kernel(streamk_programs, threads=threads) as pid: - A_shared = T.alloc_shared(A_shared_shape, dtypeAB) B_shared = T.alloc_shared(B_shared_shape, dtypeAB) A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB) diff --git a/examples/gemv/example_gemv.py b/examples/gemv/example_gemv.py index 3772dc6bd5841ac3c455206c1d48f5ce4a74bd5f..00cbac06704f3e0a6ce72cd7f19294f30357d70e 100644 --- a/examples/gemv/example_gemv.py +++ 
b/examples/gemv/example_gemv.py @@ -20,12 +20,11 @@ def naive_gemv( dtype: str = "float16", accum_dtype: str = "float", ): - @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N)) as bn: tn = T.get_thread_binding(0) # tn = threadIdx.x @@ -38,8 +37,7 @@ def naive_gemv( A_shared[tk] = A[bk * BLOCK_K + tk] B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk] for tk in T.serial(BLOCK_K): - C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, - tk].astype(accum_dtype) + C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, tk].astype(accum_dtype) C[bn * BLOCK_N + tn] = C_reg[0] return main @@ -54,12 +52,11 @@ def naive_splitk_gemv( dtype: str = "float16", accum_dtype: str = "float", ): - @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, BLOCK_K)) as bn: tn = T.get_thread_binding(0) @@ -95,9 +92,9 @@ def splitk_gemv( @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: tn = T.get_thread_binding(0) @@ -136,9 +133,9 @@ def splitk_gemv_vectorized( @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: tn = T.get_thread_binding(0) @@ -177,9 +174,9 @@ def splitk_gemv_vectorized_tvm( @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: tn = T.get_thread_binding(0) @@ -197,9 +194,9 @@ def splitk_gemv_vectorized_tvm( C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) C_reduced = T.alloc_local((1,), accum_dtype) with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -209,7 +206,8 @@ def splitk_gemv_vectorized_tvm( C_reduced[0], tk, dtype="handle", - )) + ) + ) C[bn * BLOCK_N + tn] = C_reduced[0] @@ -218,10 +216,8 @@ def splitk_gemv_vectorized_tvm( def get_block_template_configs(): iter_params = dict( - block_M=[2, 4, 8, 32, 64, 128], - block_N=[2, 4, 8, 32, 64, 128], - num_stages=[0, 1, 2, 3, 4], - threads=[32, 64, 128, 256]) + block_M=[2, 4, 8, 32, 64, 128], block_N=[2, 4, 8, 32, 64, 128], num_stages=[0, 1, 2, 3, 4], threads=[32, 64, 128, 256] + ) return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] @@ -237,18 +233,9 @@ def get_block_template_configs(): }, out_idx=[2], ) -def gemv_alloc_reducer(M, - N, - block_M=128, - block_N=128, - num_stages=2, - threads=256, - dtype: str = "float16", - accum_dtype: str = "float"): - +def 
gemv_alloc_reducer(M, N, block_M=128, block_N=128, num_stages=2, threads=256, dtype: str = "float16", accum_dtype: str = "float"): @T.prim_func - def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, - dtype)): # type: ignore + def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, dtype)): # type: ignore with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m: o_reducer = T.alloc_reducer(block_M, accum_dtype, replication="all") T.clear(o_reducer) @@ -295,9 +282,9 @@ def get_autotuned_kernel( @T.prim_func def main( - A: T.Tensor((K,), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: tn = T.get_thread_binding(0) @@ -315,9 +302,9 @@ def get_autotuned_kernel( C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) C_reduced = T.alloc_local((1,), accum_dtype) with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -327,7 +314,8 @@ def get_autotuned_kernel( C_reduced[0], tk, dtype="handle", - )) + ) + ) C[bn * BLOCK_N + tn] = C_reduced[0] @@ -355,8 +343,7 @@ def main(do_bench: bool = True): check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K, do_bench=do_bench) check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K, do_bench=do_bench) check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K, do_bench=do_bench) - check_correctness_and_bench( - gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K, do_bench=do_bench) + check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K, do_bench=do_bench) print("Test passed!") diff --git a/examples/grouped_gemm/example_grouped_gemm_bwd.py b/examples/grouped_gemm/example_grouped_gemm_bwd.py index ac8da7e2c34eca885669417fad43da64289886e0..b1af5360cf8291ae2f51b4b53237ff2e7ee48168 100644 --- a/examples/grouped_gemm/example_grouped_gemm_bwd.py +++ b/examples/grouped_gemm/example_grouped_gemm_bwd.py @@ -5,21 +5,8 @@ import tilelang import tilelang.language as T -@tilelang.jit( - out_idx=[2], pass_configs={ - "tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True - }) -def grouped_gemm_fwd(batch_sum, - batch_count, - K, - N, - block_M, - block_N, - block_K, - num_stages=2, - threads=128, - dtype="float16"): +@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"): """ args: a (torch.Tensor): Input tensor of shape (M, K). 
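# --- Editor's aside (illustrative sketch, not part of the patch) -----------------------------
# The grouped-GEMM forward kernel in this hunk walks a padded row space in which every batch
# starts on a block_M-aligned boundary, then maps each row-block back to its batch using the
# batch_offsets / batch_padded_offsets tables (see the scan loop and the actual_rows formula
# that follow).  The pure-Python sketch below mirrors those formulas on the host; the helper
# name map_block_to_batch and the toy sizes are illustrative.
import math

def map_block_to_batch(bx, block_M, batch_sizes, batch_offsets, batch_padded_offsets):
    m_start_padded = bx * block_M
    # last batch whose padded offset is <= m_start_padded, as in the kernel's scan loop
    idx = max(i for i, off in enumerate(batch_padded_offsets) if m_start_padded >= off)
    m_start = m_start_padded - batch_padded_offsets[idx] + batch_offsets[idx]
    actual_rows = max(0, min(block_M, batch_sizes[idx] + batch_padded_offsets[idx] - m_start_padded))
    return idx, m_start, actual_rows

block_M = 64
batch_sizes = [64, 128]
batch_offsets = [0, 64]
# padded offsets follow this example's ceil((size + 1) / padding_M) * padding_M rule
batch_padded_offsets = [0, math.ceil((batch_sizes[0] + 1) / block_M) * block_M]  # [0, 128]

for bx in range(4):
    print(bx, map_block_to_batch(bx, block_M, batch_sizes, batch_offsets, batch_padded_offsets))
# ---------------------------------------------------------------------------------------------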
@@ -29,17 +16,14 @@ def grouped_gemm_fwd(batch_sum, @T.prim_func def kernel( - A: T.Tensor([batch_sum, K], dtype), # type: ignore - B: T.Tensor([batch_count, K, N], dtype), # type: ignore - C: T.Tensor([batch_sum, N], dtype), # type: ignore - batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore - batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore - batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore + A: T.Tensor([batch_sum, K], dtype), # type: ignore + B: T.Tensor([batch_count, K, N], dtype), # type: ignore + C: T.Tensor([batch_sum, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore + batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore + batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore ): - - with T.Kernel( - T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N), threads=threads) as (bx, by): A_shared = T.alloc_shared([block_M, block_K], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype) C_local = T.alloc_fragment([block_M, block_N], accum_dtype) @@ -49,23 +33,17 @@ def grouped_gemm_fwd(batch_sum, m_start_padded = bx * block_M for i in range(batch_count): - in_cur_batch_idx = (m_start_padded >= batch_padded_offsets[i]) + in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i] cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0]) cur_batch_size[0] = batch_sizes[cur_batch_idx[0]] - m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[ - cur_batch_idx[0]] - actual_rows = T.max( - 0, - T.min(block_M, - cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) + m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]] + actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[m_start:m_start + block_M, k * block_K:(k + 1) * block_K], A_shared) - T.copy( - B[cur_batch_idx[0], k * block_K:(k + 1) * block_K, - by * block_N:(by + 1) * block_N], B_shared) + T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared) + T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared) T.gemm(A_shared, B_shared, C_local) for i, j in T.Parallel(block_M, block_N): @@ -76,7 +54,6 @@ def grouped_gemm_fwd(batch_sum, class _GroupedGEMM(torch.autograd.Function): - @staticmethod def forward(ctx, a, b, batch_sizes): block_M = 64 @@ -99,15 +76,11 @@ class _GroupedGEMM(torch.autograd.Function): for i in range(batch_count - 1): batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes[i]) for i in range(batch_count - 1): - batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + - math.ceil((batch_sizes[i] + 1) / padding_M) * - padding_M) + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes[i] + 1) / padding_M) * padding_M) batch_offsets = torch.tensor(batch_offsets_list, device=a.device, dtype=torch.int32) - batch_padded_offsets = torch.tensor( - batch_padded_offsets_list, device=a.device, dtype=torch.int32) + batch_padded_offsets = torch.tensor(batch_padded_offsets_list, device=a.device, dtype=torch.int32) - kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, 
block_K, - num_stages, threads) + kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages, threads) o = kernel(a, b, batch_sizes, batch_offsets, batch_padded_offsets) ctx.save_for_backward(a, b, batch_sizes, batch_offsets) @@ -135,8 +108,7 @@ class _GroupedGEMM(torch.autograd.Function): return x A, B, batch_sizes = [maybe_contiguous(x) for x in (A, B, batch_sizes)] - kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K, - num_stages, threads) + kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K, num_stages, threads) dB = kernel(A, grad_output, batch_sizes, batch_offsets) return None, dB, None @@ -172,9 +144,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): for i in range(batch_count - 1): batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i]) for i in range(batch_count - 1): - batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + - math.ceil((batch_sizes_list[i] + 1) / padding_M) * - padding_M) + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i] + 1) / padding_M) * padding_M) A = torch.randn(batch_sum, K, device=device, dtype=dtype) B = torch.randn(batch_count, K, M, device=device, dtype=dtype) C = torch.empty(batch_sum, M, device=device, dtype=dtype) @@ -187,21 +157,8 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets -@tilelang.jit( - out_idx=[2], pass_configs={ - "tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True - }) -def grouped_gemm_bwd(batch_sum, - batch_count, - M, - N, - block_M, - block_N, - block_K, - num_stages=2, - threads=128, - dtype="float16"): +@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"): """ args: a (torch.Tensor): Input tensor of shape (M, K). 
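# --- Editor's aside (illustrative sketch, not part of the patch) -----------------------------
# grouped_gemm_bwd in this hunk produces the per-group weight gradient, dB[g] = A_g^T @ dO_g
# over that group's rows (A is [batch_sum, M], the output gradient is [batch_sum, N], and the
# result is [batch_count, M, N]).  A PyTorch reference for that contraction; the helper name
# and the shapes below are illustrative.
import torch

def grouped_gemm_dB_reference(A, dO, batch_sizes):
    dB = torch.empty(len(batch_sizes), A.shape[1], dO.shape[1], dtype=A.dtype, device=A.device)
    offset = 0
    for g, size in enumerate(batch_sizes):
        dB[g] = A[offset:offset + size].T @ dO[offset:offset + size]
        offset += size
    return dB

A = torch.randn(64 + 128, 32)
dO = torch.randn(64 + 128, 16)
dB = grouped_gemm_dB_reference(A, dO, [64, 128])  # -> shape (2, 32, 16)
# ---------------------------------------------------------------------------------------------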
@@ -211,16 +168,13 @@ def grouped_gemm_bwd(batch_sum, @T.prim_func def kernel( - A: T.Tensor([batch_sum, M], dtype), # type: ignore - B: T.Tensor([batch_sum, N], dtype), # type: ignore - C: T.Tensor([batch_count, M, N], dtype), # type: ignore - batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore - batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore + A: T.Tensor([batch_sum, M], dtype), # type: ignore + B: T.Tensor([batch_sum, N], dtype), # type: ignore + C: T.Tensor([batch_count, M, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore + batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore ): - - with T.Kernel( - T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count, - threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count, threads=threads) as (bx, by, bz): A_shared = T.alloc_shared([block_K, block_M], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype) C_local = T.alloc_fragment([block_M, block_N], accum_dtype) @@ -228,13 +182,9 @@ def grouped_gemm_bwd(batch_sum, T.clear(C_local) for k in T.Pipelined(T.ceildiv(batch_sizes[bz], block_K), num_stages=num_stages): for i, j in T.Parallel(block_K, block_M): - A_shared[i, j] = T.if_then_else( - i < batch_sizes[bz], A[batch_offsets[bz] + k * block_K + i, - bx * block_M + j], 0) + A_shared[i, j] = T.if_then_else(i < batch_sizes[bz], A[batch_offsets[bz] + k * block_K + i, bx * block_M + j], 0) for i, j in T.Parallel(block_K, block_N): - B_shared[i, j] = T.if_then_else( - i < batch_sizes[bz], B[batch_offsets[bz] + k * block_K + i, - by * block_N + j], 0) + B_shared[i, j] = T.if_then_else(i < batch_sizes[bz], B[batch_offsets[bz] + k * block_K + i, by * block_N + j], 0) T.gemm(A_shared, B_shared, C_local, transpose_A=True) T.copy(C_local, C[bz, bx * block_M, by * block_N]) @@ -242,23 +192,12 @@ def grouped_gemm_bwd(batch_sum, return kernel -def run_tilelang_grouped_gemm(batch_sizes_list, - K, - M, - block_M, - block_N, - block_K, - trans_b, - num_stages=2, - threads=128, - profile=False): - +def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False): padding_M = block_M device = torch.device("cuda") dtype = torch.float16 - A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs( - batch_sizes_list, K, M, False, padding_M, device, dtype) + A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, False, padding_M, device, dtype) A.requires_grad_(False) B.requires_grad_(True) @@ -273,10 +212,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list, O.backward(dO, retain_graph=True) dB, B.grad = B.grad.clone(), None - if ( - torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and \ - torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2) - ): + if torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2): print("✅ Tilelang and Torch match") else: print("❌ Tilelang and Torch mismatch") @@ -284,12 +220,11 @@ def run_tilelang_grouped_gemm(batch_sizes_list, if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - '--batch_sizes', type=str, default="64, 128", help='comma-separated batch sizes') - parser.add_argument('--K', type=int, default=8192, help='reduce dim') - parser.add_argument('--M', type=int, default=8192, help='output dim') - parser.add_argument('--trans_b', action="store_true", help="transpose B") - 
parser.add_argument('--profile', action="store_true", help="profile") + parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes") + parser.add_argument("--K", type=int, default=8192, help="reduce dim") + parser.add_argument("--M", type=int, default=8192, help="output dim") + parser.add_argument("--trans_b", action="store_true", help="transpose B") + parser.add_argument("--profile", action="store_true", help="profile") args = parser.parse_args() batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")] @@ -301,14 +236,4 @@ if __name__ == "__main__": num_stages = 2 threads = 256 - run_tilelang_grouped_gemm( - batch_sizes_list, - K, - M, - block_M, - block_N, - block_K, - trans_b, - num_stages, - threads, - profile=args.profile) + run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile) diff --git a/examples/grouped_gemm/example_grouped_gemm_fwd.py b/examples/grouped_gemm/example_grouped_gemm_fwd.py index 9b58e3a21c6df034b3d5f91e8f7ec66b5afa4548..8f7710512dbf7ec1c44bdbafc69467a4f3604d4e 100644 --- a/examples/grouped_gemm/example_grouped_gemm_fwd.py +++ b/examples/grouped_gemm/example_grouped_gemm_fwd.py @@ -18,8 +18,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False): torch.Tensor: Resulting tensor after grouped matrix multiplication. """ assert a.shape[0] == sum(batch_sizes), "Sum of batch_sizes must equal the first dimension of a" - assert b.shape[0] == len( - batch_sizes), "The first dimension of b must match the length of batch_sizes" + assert b.shape[0] == len(batch_sizes), "The first dimension of b must match the length of batch_sizes" # Initialize output tensor output = torch.empty((sum(batch_sizes), b.shape[2]), device=a.device, dtype=a.dtype) @@ -38,15 +37,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False): @tilelang.jit(out_idx=[2]) -def grouped_gemm(batch_sizes_list, - K, - N, - block_M, - block_N, - block_K, - num_stages=2, - threads=128, - dtype="float16"): +def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"): """ args: a (torch.Tensor): Input tensor of shape (M, K). 
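# --- Editor's aside (usage sketch, not part of the patch) ------------------------------------
# Condensed from this example's own driver path: build the JIT kernel for a fixed tuple of
# batch sizes, construct the offset tensors, run it, and compare against the torch_gmm
# reference defined in this file.  Assumes the definitions from
# examples/grouped_gemm/example_grouped_gemm_fwd.py are in scope and a CUDA device is
# available; the problem and tile sizes here are illustrative, not the example's defaults.
import torch

batch_sizes_list = [64, 128]
K = M = 1024
block_M = block_N = 64
block_K = 32

kernel = grouped_gemm(tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages=2, threads=128)
A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(
    batch_sizes_list, K, M, False, block_M, torch.device("cuda"), torch.float16
)
out = kernel(A, B, batch_sizes, batch_offsets, batch_padded_offsets)
ref = torch_gmm(A, B, batch_sizes, batch_offsets)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
# ---------------------------------------------------------------------------------------------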
@@ -59,14 +50,13 @@ def grouped_gemm(batch_sizes_list, @T.prim_func def kernel( - A: T.Tensor([batch_sum, K], dtype), # type: ignore - B: T.Tensor([batch_count, K, N], dtype), # type: ignore - C: T.Tensor([batch_sum, N], dtype), # type: ignore - batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore - batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore - batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore + A: T.Tensor([batch_sum, K], dtype), # type: ignore + B: T.Tensor([batch_count, K, N], dtype), # type: ignore + C: T.Tensor([batch_sum, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore + batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore + batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore ): - with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by): A_shared = T.alloc_shared([block_M, block_K], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype) @@ -77,23 +67,17 @@ def grouped_gemm(batch_sizes_list, m_start_padded = bx * block_M for i in range(batch_count): - in_cur_batch_idx = (m_start_padded >= batch_padded_offsets[i]) + in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i] cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0]) cur_batch_size[0] = batch_sizes[cur_batch_idx[0]] - m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[ - cur_batch_idx[0]] - actual_rows = T.max( - 0, - T.min(block_M, - cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) + m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]] + actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[m_start:m_start + block_M, k * block_K:(k + 1) * block_K], A_shared) - T.copy( - B[cur_batch_idx[0], k * block_K:(k + 1) * block_K, - by * block_N:(by + 1) * block_N], B_shared) + T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared) + T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared) T.gemm(A_shared, B_shared, C_local) for i, j in T.Parallel(block_M, block_N): @@ -111,8 +95,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): for i in range(batch_count - 1): batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i]) for i in range(batch_count - 1): - batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + - math.ceil((batch_sizes_list[i]) / padding_M) * padding_M) + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i]) / padding_M) * padding_M) A = torch.randn(batch_sum, K, device=device, dtype=dtype) B = torch.randn(batch_count, K, M, device=device, dtype=dtype) C = torch.empty(batch_sum, M, device=device, dtype=dtype) @@ -125,27 +108,16 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets -def run_tilelang_grouped_gemm(batch_sizes_list, - K, - M, - block_M, - block_N, - block_K, - trans_b, - num_stages=2, - threads=128, - profile=False): +def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False): padding_M = block_M batch_sum = sum(batch_sizes_list) - kernel = 
grouped_gemm( - tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads) + kernel = grouped_gemm(tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads) # print(kernel.get_kernel_source()) device = torch.device("cuda") dtype = torch.float16 - A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs( - batch_sizes_list, K, M, trans_b, padding_M, device, dtype) + A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype) out = kernel(A, B, batch_sizes, batch_offsets, batch_padded_offsets) ref_output = torch_gmm(A, B, batch_sizes, batch_offsets, trans_b) # print(out) @@ -157,8 +129,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list, if profile: profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) - latency = profiler.do_bench( - warmup=500, input_tensors=[A, B, batch_sizes, batch_offsets, batch_padded_offsets]) + latency = profiler.do_bench(warmup=500, input_tensors=[A, B, batch_sizes, batch_offsets, batch_padded_offsets]) print(f"Latency: {latency} ms") print(f"TFlops: {batch_sum * K * M * 2 / latency * 1e-9} TFlops") @@ -173,12 +144,11 @@ def test_grouped_gemm(): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - '--batch_sizes', type=str, default="64, 128", help='comma-separated batch sizes') - parser.add_argument('--K', type=int, default=8192, help='reduce dim') - parser.add_argument('--M', type=int, default=8192, help='output dim') - parser.add_argument('--trans_b', action="store_true", help="transpose B") - parser.add_argument('--profile', action="store_true", help="profile") + parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes") + parser.add_argument("--K", type=int, default=8192, help="reduce dim") + parser.add_argument("--M", type=int, default=8192, help="output dim") + parser.add_argument("--trans_b", action="store_true", help="transpose B") + parser.add_argument("--profile", action="store_true", help="profile") args = parser.parse_args() batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")] @@ -190,14 +160,4 @@ if __name__ == "__main__": num_stages = 2 threads = 256 - run_tilelang_grouped_gemm( - batch_sizes_list, - K, - M, - block_M, - block_N, - block_K, - trans_b, - num_stages, - threads, - profile=args.profile) + run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile) diff --git a/examples/hadamard_transform/example_hadamard.py b/examples/hadamard_transform/example_hadamard.py index 531d4689183e308b2694fe0e8b0028202799b276..64eb9bbdb54484303dda2e191bad33c23c53c828 100644 --- a/examples/hadamard_transform/example_hadamard.py +++ b/examples/hadamard_transform/example_hadamard.py @@ -17,7 +17,7 @@ def is_pow_of_2(n): def hadamard(b, n, dtype): assert is_pow_of_2(n), "n must be a power of 2" assert 2 <= n <= 32768, "n must be in [2, 32768]" - elem_size = {'float32': 4, 'float16': 2, 'bfloat16': 2}[dtype] + elem_size = {"float32": 4, "float16": 2, "bfloat16": 2}[dtype] logN = int(math.log2(n)) threads = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256][logN] @@ -40,23 +40,21 @@ def hadamard(b, n, dtype): # print(f'{exchange_round=}') @T.macro - def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype), - round: int): + def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), 
dtype), round: int): tx = T.get_thread_binding(0) for i in T.serial(round): tx_stride = 1 << i another_tx = tx ^ tx_stride - sign = ( - tx >> i - ) & 1 # get i-th lowest bit of tx, which determines the operation type for shared[tx, :] + sign = (tx >> i) & 1 # get i-th lowest bit of tx, which determines the operation type for shared[tx, :] for j in T.Pipelined(thread_elem, num_stages=1): buf[j] = T.tvm_warp_shuffle( - 0xffffffff, # mask of all threads + 0xFFFFFFFF, # mask of all threads local[j], another_tx % warp_size, warp_size, - warp_size) + warp_size, + ) local[j] = T.if_then_else(sign == 0, local[j] + buf[j], buf[j] - local[j]) @T.prim_func @@ -78,10 +76,8 @@ def hadamard(b, n, dtype): for j in T.serial(chunknum): chunkbase = j * chunksize for k in T.serial(chunksize // 2): - local[chunkbase + - k] = local[chunkbase + k] + local[chunkbase + k + chunksize // 2] - local[chunkbase + k + chunksize // - 2] = local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2] + local[chunkbase + k] = local[chunkbase + k] + local[chunkbase + k + chunksize // 2] + local[chunkbase + k + chunksize // 2] = local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2] # 3. Hadamard inside warp, n<=512 # In warp level, we rely on warp shuffle to exchange data inside each warp, without using shared memory @@ -131,28 +127,27 @@ def ref_program(x: torch.Tensor): assert x.ndim == 2 dim = x.shape[-1] assert is_pow_of_2(dim) - return F.linear( - x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device)) + return F.linear(x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device)) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=64, help='Batch size') - parser.add_argument('--dim', type=int, default=32768, help='Dimension') + parser.add_argument("--batch", type=int, default=64, help="Batch size") + parser.add_argument("--dim", type=int, default=32768, help="Dimension") args = parser.parse_args() B, D = args.batch, args.dim - x = torch.randn((B, D), device='cuda') - kernel = hadamard(B, D, 'float32') + x = torch.randn((B, D), device="cuda") + kernel = hadamard(B, D, "float32") y = kernel(x) y_ref = ref_program(x) torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2) - print('All tests passed.') + print("All tests passed.") profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) latency = profiler.do_bench(warmup=100) print("Tile-lang: {:.2f} ms".format(latency)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/lazy_jit/lazyjit.en.ipynb b/examples/lazy_jit/lazyjit.en.ipynb index acb318c1438a4266347698ca86961d20c9c3306c..196ddfc4668acd3498eedda4b303e0b349f11145 100644 --- a/examples/lazy_jit/lazyjit.en.ipynb +++ b/examples/lazy_jit/lazyjit.en.ipynb @@ -9,6 +9,7 @@ "source": [ "import sys\n", "from pathlib import Path\n", + "\n", "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n", "import tilelang\n", "import torch\n", @@ -61,7 +62,7 @@ " out_dtype: T.dtype = T.float32,\n", " block_M: int = 128,\n", " block_N: int = 128,\n", - " block_K: int = 32\n", + " block_K: int = 32,\n", "):\n", " M, K = A.shape\n", " K, N = B.shape\n", @@ -94,8 +95,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, 
dtype=torch.float16, device=\"cuda\")\n", "C = gemm(A, B)\n", "\n", "# check output is correct\n", @@ -118,8 +119,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 1024, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n", "C = gemm(A, B, block_M=64, block_N=64)" ] }, @@ -218,8 +219,8 @@ "source": [ "@tilelang.lazy_jit\n", "def gemm_dyn_K(\n", - " A: T.Tensor[[int, T.dyn['K']], T.float16], # noqa: F821\n", - " B: T.Tensor[[T.dyn['K'], int], T.float16], # noqa: F821\n", + " A: T.Tensor[[int, T.dyn[\"K\"]], T.float16], # noqa: F821\n", + " B: T.Tensor[[T.dyn[\"K\"], int], T.float16], # noqa: F821\n", "):\n", " M, K = A.shape\n", " K, N = B.shape\n", @@ -265,8 +266,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_dyn_K(A, B)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" @@ -295,18 +296,17 @@ "source": [ "from typing import Any\n", "\n", + "\n", "@tilelang.lazy_jit\n", - "def as_contingious(\n", - " A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]\n", - "):\n", + "def as_contingious(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n", " M, N = A.shape\n", " B = T.empty((M, N), A.dtype)\n", " block_M = 128\n", " block_N = 128\n", " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", " T.copy(\n", - " A[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N],\n", - " B[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N]\n", + " A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", " )\n", " return B" ] @@ -318,7 +318,7 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 1024, device='cuda')\n", + "A = torch.randn(1024, 1024, device=\"cuda\")\n", "B = as_contingious(A[::2, ::2])\n", "B_ref = A[::2, ::2].contiguous()\n", "torch.testing.assert_close(B, B_ref)" @@ -370,8 +370,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_ptr(A, B, 1024, 256, 512)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" @@ -416,8 +416,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" @@ -496,18 +496,20 @@ "source": [ "from itertools import product\n", "\n", + "\n", "def get_configs():\n", " return [\n", " {\n", - " 'A': 
T.Tensor((1024, 1024), T.float32),\n", - " 'B': T.Tensor((1024, 1024), T.float32),\n", - " 'block_M': block_M,\n", - " 'block_N': block_N,\n", - " 'block_K': block_K,\n", + " \"A\": T.Tensor((1024, 1024), T.float32),\n", + " \"B\": T.Tensor((1024, 1024), T.float32),\n", + " \"block_M\": block_M,\n", + " \"block_N\": block_N,\n", + " \"block_K\": block_K,\n", " }\n", " for block_M, block_N, block_K in product([32, 64], repeat=3)\n", " ]\n", "\n", + "\n", "gemm.par_compile(get_configs())" ] }, @@ -579,7 +581,8 @@ "source": [ "@T.macro\n", "def macro_with_ref(x: T.Ref):\n", - " x = 1 # noqa: F841\n", + " x = 1 # noqa: F841\n", + "\n", "\n", "@T.prim_func\n", "def foo(x: T.Tensor((2,))):\n", @@ -591,6 +594,7 @@ " idx = T.alloc_var(T.int32, 0)\n", " macro_with_ref(x[idx])\n", "\n", + "\n", "foo" ] }, @@ -616,7 +620,7 @@ " A: T.Tensor[[T.dyn], Any],\n", " fn,\n", "):\n", - " N, = A.shape\n", + " (N,) = A.shape\n", " B = T.empty((N,), dtype=A.dtype)\n", " block_N = 128\n", " with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n", @@ -624,6 +628,8 @@ " idx = bx * block_N + i\n", " B[idx] = fn(A[idx])\n", " return B\n", + "\n", + "\n", "@T.macro\n", "def add_one(x):\n", " return x + 1" @@ -636,7 +642,7 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, device='cuda')\n", + "A = torch.randn(1024, device=\"cuda\")\n", "B = element_wise(A, add_one)\n", "B_ref = A + 1\n", "torch.testing.assert_close(B, B_ref)" @@ -670,10 +676,11 @@ " var = var * 3 + 1\n", " n31(x * 3 + 1, var)\n", "\n", + "\n", "@tilelang.lazy_jit\n", "def foo(A: T.Tensor[[1], T.int32], n: int):\n", " with T.Kernel(1) as _:\n", - " n31(n, A[0])\n" + " n31(n, A[0])" ] }, { @@ -694,7 +701,7 @@ } ], "source": [ - "A = torch.tensor([100], dtype=torch.int32, device='cuda')\n", + "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n", "foo(A, 5)\n", "A" ] @@ -745,12 +752,15 @@ "def sincos(x):\n", " return T.sin(x), T.cos(x)\n", "\n", + "\n", "@T.prim_func\n", "def foo():\n", " with T.Kernel(32) as x:\n", " s, c = sincos(x)\n", - " a = s + c # noqa: F841\n", - " b = s - c # noqa: F841\n", + " a = s + c # noqa: F841\n", + " b = s - c # noqa: F841\n", + "\n", + "\n", "foo" ] } diff --git a/examples/lazy_jit/lazyjit.zh.ipynb b/examples/lazy_jit/lazyjit.zh.ipynb index fb9b71b72885daaf26831325f110aea8c47fc1fc..d6db4c76e01b842ad562f19537ca6d3f9b19bdb3 100644 --- a/examples/lazy_jit/lazyjit.zh.ipynb +++ b/examples/lazy_jit/lazyjit.zh.ipynb @@ -9,6 +9,7 @@ "source": [ "import sys\n", "from pathlib import Path\n", + "\n", "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n", "import tilelang\n", "import torch\n", @@ -61,7 +62,7 @@ " out_dtype: T.dtype = T.float32,\n", " block_M: int = 128,\n", " block_N: int = 128,\n", - " block_K: int = 32\n", + " block_K: int = 32,\n", "):\n", " M, K = A.shape\n", " K, N = B.shape\n", @@ -94,8 +95,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm(A, B)\n", "\n", "# check output is correct\n", @@ -118,8 +119,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 1024, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 
1024, dtype=torch.float16, device=\"cuda\")\n", "C = gemm(A, B, block_M=64, block_N=64)" ] }, @@ -218,8 +219,8 @@ "source": [ "@tilelang.lazy_jit\n", "def gemm_dyn_K(\n", - " A: T.Tensor[[int, T.dyn['K']], T.float16], # noqa: F821\n", - " B: T.Tensor[[T.dyn['K'], int], T.float16], # noqa: F821\n", + " A: T.Tensor[[int, T.dyn[\"K\"]], T.float16], # noqa: F821\n", + " B: T.Tensor[[T.dyn[\"K\"], int], T.float16], # noqa: F821\n", "):\n", " M, K = A.shape\n", " K, N = B.shape\n", @@ -265,8 +266,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_dyn_K(A, B)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" @@ -295,18 +296,17 @@ "source": [ "from typing import Any\n", "\n", + "\n", "@tilelang.lazy_jit\n", - "def as_contingious(\n", - " A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]\n", - "):\n", + "def as_contingious(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n", " M, N = A.shape\n", " B = T.empty((M, N), A.dtype)\n", " block_M = 128\n", " block_N = 128\n", " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", " T.copy(\n", - " A[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N],\n", - " B[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N]\n", + " A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", " )\n", " return B" ] @@ -318,7 +318,7 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 1024, device='cuda')\n", + "A = torch.randn(1024, 1024, device=\"cuda\")\n", "B = as_contingious(A[::2, ::2])\n", "B_ref = A[::2, ::2].contiguous()\n", "torch.testing.assert_close(B, B_ref)" @@ -370,8 +370,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_ptr(A, B, 1024, 256, 512)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" @@ -416,8 +416,8 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n", - "B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n", + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" @@ -496,18 +496,20 @@ "source": [ "from itertools import product\n", "\n", + "\n", "def get_configs():\n", " return [\n", " {\n", - " 'A': T.Tensor((1024, 1024), T.float32),\n", - " 'B': T.Tensor((1024, 1024), T.float32),\n", - " 'block_M': block_M,\n", - " 'block_N': block_N,\n", - " 'block_K': block_K,\n", + " \"A\": T.Tensor((1024, 1024), T.float32),\n", + " \"B\": T.Tensor((1024, 1024), T.float32),\n", + " \"block_M\": block_M,\n", + " \"block_N\": block_N,\n", + " \"block_K\": block_K,\n", " }\n", " for block_M, block_N, block_K in 
product([32, 64], repeat=3)\n", " ]\n", "\n", + "\n", "gemm.par_compile(get_configs())" ] }, @@ -579,7 +581,8 @@ "source": [ "@T.macro\n", "def macro_with_ref(x: T.Ref):\n", - " x = 1 # noqa: F841\n", + " x = 1 # noqa: F841\n", + "\n", "\n", "@T.prim_func\n", "def foo(x: T.Tensor((2,))):\n", @@ -591,6 +594,7 @@ " idx = T.alloc_var(T.int32, 0)\n", " macro_with_ref(x[idx])\n", "\n", + "\n", "foo" ] }, @@ -616,7 +620,7 @@ " A: T.Tensor[[T.dyn], Any],\n", " fn,\n", "):\n", - " N, = A.shape\n", + " (N,) = A.shape\n", " B = T.empty((N,), dtype=A.dtype)\n", " block_N = 128\n", " with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n", @@ -624,6 +628,8 @@ " idx = bx * block_N + i\n", " B[idx] = fn(A[idx])\n", " return B\n", + "\n", + "\n", "@T.macro\n", "def add_one(x):\n", " return x + 1" @@ -636,7 +642,7 @@ "metadata": {}, "outputs": [], "source": [ - "A = torch.randn(1024, device='cuda')\n", + "A = torch.randn(1024, device=\"cuda\")\n", "B = element_wise(A, add_one)\n", "B_ref = A + 1\n", "torch.testing.assert_close(B, B_ref)" @@ -670,10 +676,11 @@ " var = var * 3 + 1\n", " n31(x * 3 + 1, var)\n", "\n", + "\n", "@tilelang.lazy_jit\n", "def foo(A: T.Tensor[[1], T.int32], n: int):\n", " with T.Kernel(1) as _:\n", - " n31(n, A[0])\n" + " n31(n, A[0])" ] }, { @@ -694,7 +701,7 @@ } ], "source": [ - "A = torch.tensor([100], dtype=torch.int32, device='cuda')\n", + "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n", "foo(A, 5)\n", "A" ] @@ -745,12 +752,15 @@ "def sincos(x):\n", " return T.sin(x), T.cos(x)\n", "\n", + "\n", "@T.prim_func\n", "def foo():\n", " with T.Kernel(32) as x:\n", " s, c = sincos(x)\n", - " a = s + c # noqa: F841\n", - " b = s - c # noqa: F841\n", + " a = s + c # noqa: F841\n", + " b = s - c # noqa: F841\n", + "\n", + "\n", "foo" ] } diff --git a/examples/linear_attention/example_linear_attn_bwd.py b/examples/linear_attention/example_linear_attn_bwd.py index 568bcc55f0cb9cad82e3ddf88f93784294185d8e..7cbfc465aebc3d736e277a4b3446c56bc04ed9be 100644 --- a/examples/linear_attention/example_linear_attn_bwd.py +++ b/examples/linear_attention/example_linear_attn_bwd.py @@ -13,20 +13,20 @@ from typing import Optional, Tuple pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + } +) def tl_fused_chunk_bwd_kernel( B, S, H, DK, DV, - dtype: str = 'float16', + dtype: str = "float16", scale: float = None, ) -> torch.Tensor: - if scale is None: scale = DK**-0.5 - accum_dtype = 'float' + accum_dtype = "float" chunk_size = 64 BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA @@ -37,13 +37,13 @@ def tl_fused_chunk_bwd_kernel( @T.prim_func def fused_chunk_linear_attn_bwd( - Q: T.Tensor([B, S, H, DK], dtype), # type: ignore - K: T.Tensor([B, S, H, DK], dtype), # type: ignore - V: T.Tensor([B, S, H, DV], dtype), # type: ignore - dO: T.Tensor([B, S, H, DV], dtype), # type: ignore - dQ: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore - dK: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore - dV: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + dO: T.Tensor([B, S, H, DV], dtype), # type: ignore + dQ: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore + dK: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore + dV: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore ): with T.Kernel(NV, NK, B * H) as (i_v, 
i_k, i_bh): i_b = i_bh // H @@ -66,11 +66,13 @@ def tl_fused_chunk_bwd_kernel( dh = T.alloc_fragment([BK, BV], accum_dtype) dh_shared = T.alloc_shared([BK, BV], dtype) - T.annotate_layout({ - dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), - dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), - dv_shared: tilelang.layout.make_swizzled_layout(dv_shared) - }) + T.annotate_layout( + { + dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + } + ) T.use_swizzle(10) T.clear(h) @@ -78,10 +80,9 @@ def tl_fused_chunk_bwd_kernel( # Calculate dQ for i in T.Pipelined(0, NT): - T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) - T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) - T.copy(dO[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], - do) + T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) + T.copy(dO[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], do) T.gemm(do, v, ds, transpose_B=True, clear_accum=True) for row, col in T.Parallel(chunk_size, chunk_size): @@ -94,29 +95,19 @@ def tl_fused_chunk_bwd_kernel( for row, col in T.Parallel(chunk_size, BK): dq[row, col] *= scale T.copy(dq, dq_shared) - T.atomic_add( - dQ[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], - dq_shared) + T.atomic_add(dQ[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], dq_shared) # Calculate dK, dV (reversely) for i in T.Pipelined(1, NT + 1): start = NT - i for row, col in T.Parallel(chunk_size, BK): q[row, col] = Q[i_b, start * chunk_size + row, i_h, i_k * BK + col] * scale - T.copy( - K[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, - i_k * BK:(i_k + 1) * BK], k) - T.copy( - V[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, - i_v * BV:(i_v + 1) * BV], v) - T.copy( - dO[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, - i_v * BV:(i_v + 1) * BV], do) + T.copy(K[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) + T.copy(dO[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], do) # Calculate dk - T.gemm( - v, do, ds, transpose_B=True, clear_accum=True - ) # ds here actually means `s`, but we simply reuse the buffer `ds` + T.gemm(v, do, ds, transpose_B=True, clear_accum=True) # ds here actually means `s`, but we simply reuse the buffer `ds` for row, col in T.Parallel(chunk_size, chunk_size): ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0) T.gemm(ds_shared, q, dk, clear_accum=True) @@ -134,13 +125,9 @@ def tl_fused_chunk_bwd_kernel( T.gemm(q, do, dh, transpose_A=True) T.copy(dk, dk_shared) - T.atomic_add( - dK[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, - i_k * BK:(i_k + 1) * BK], dk_shared) + T.atomic_add(dK[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], dk_shared) T.copy(dv, dv_shared) - T.atomic_add( - dV[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, - i_v * BV:(i_v + 1) * BV], dv_shared) + T.atomic_add(dV[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], dv_shared) return 
fused_chunk_linear_attn_bwd @@ -155,34 +142,31 @@ def tl_fused_chunk_bwd(Q, K, V, dO): return dQ.to(torch.float16), dK.to(torch.float16), dV.to(torch.float16) -def ref_program(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]: +def ref_program(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]: q, k, v = q.float(), k.float(), v.float() if scale is None: - scale = q.shape[-1]**-0.5 + scale = q.shape[-1] ** -0.5 chunk_size = 64 - q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale - k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size) - v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size) + q = rearrange(q, "b (n c) h d -> b h n c d", c=chunk_size) * scale + k = rearrange(k, "b (n c) h d -> b h n c d", c=chunk_size) + v = rearrange(v, "b (n c) h d -> b h n c d", c=chunk_size) kv = k.transpose(-1, -2) @ v kv = kv.cumsum(2) h = kv[:, :, -1, :, :] kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) inter = q @ kv - intra = ((q @ k.transpose(-1, -2)).masked_fill_( - torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), - 0)) @ v + intra = ( + (q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) + ) @ v o = inter + intra - return rearrange(o, 'b h n c d -> b (n c) h d'), h + return rearrange(o, "b h n c d -> b (n c) h d"), h def main(B=1, S=1024, H=16, D=128): - q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) - k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) - v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) - do = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + do = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) # qk norm is necessary for linear attn q = l2norm_fwd(q)[0].requires_grad_(True) @@ -193,30 +177,27 @@ def main(B=1, S=1024, H=16, D=128): o_ref, _ = ref_program(q, k, v) o_ref.backward(do, retain_graph=True) - assert torch.allclose( - dq, q.grad, atol=1e-2, rtol=1e-2), f'dq max err: {(dq - q.grad).abs().max()}' - assert torch.allclose( - dk, k.grad, atol=1e-2, rtol=1e-2), f'dk max err: {(dk - k.grad).abs().max()}' - assert torch.allclose( - dv, v.grad, atol=1e-2, rtol=1e-2), f'dv max err: {(dv - v.grad).abs().max()}' - print('Passed all tests!✅') + assert torch.allclose(dq, q.grad, atol=1e-2, rtol=1e-2), f"dq max err: {(dq - q.grad).abs().max()}" + assert torch.allclose(dk, k.grad, atol=1e-2, rtol=1e-2), f"dk max err: {(dk - k.grad).abs().max()}" + assert torch.allclose(dv, v.grad, atol=1e-2, rtol=1e-2), f"dv max err: {(dv - v.grad).abs().max()}" + print("Passed all tests!✅") # Benchmark q.grad = k.grad = v.grad = None o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) - t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend='cupti') - t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend='cupti') - print(f'Triton latency: {t1:.3f} ms') - print(f'TileLang latency: {t2:.3f} ms') - print(f'Speedup: {t1/t2:.3f}x') + t1 = 
do_bench(lambda: o_ref.backward(do, retain_graph=True), backend="cupti") + t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend="cupti") + print(f"Triton latency: {t1:.3f} ms") + print(f"TileLang latency: {t2:.3f} ms") + print(f"Speedup: {t1 / t2:.3f}x") -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--B', type=int, default=8, help='Batch size') - parser.add_argument('--S', type=int, default=1024, help='Seq len') - parser.add_argument('--H', type=int, default=32, help='Num heads') - parser.add_argument('--D', type=int, default=128, help='Head dim') + parser.add_argument("--B", type=int, default=8, help="Batch size") + parser.add_argument("--S", type=int, default=1024, help="Seq len") + parser.add_argument("--H", type=int, default=32, help="Num heads") + parser.add_argument("--D", type=int, default=128, help="Head dim") args = parser.parse_args() main(args.B, args.S, args.H, args.D) diff --git a/examples/linear_attention/example_linear_attn_fwd.py b/examples/linear_attention/example_linear_attn_fwd.py index 03900a7e649b0f2113a103147abc91aef3b1c6b1..3d28f92b0524aaa7565d9e6a47d92ef55912605e 100644 --- a/examples/linear_attention/example_linear_attn_fwd.py +++ b/examples/linear_attention/example_linear_attn_fwd.py @@ -14,20 +14,20 @@ from typing import Optional, Tuple pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, +) def tl_fused_chunk_fwd_kernel( B, S, H, DK, DV, - dtype: str = 'float16', + dtype: str = "float16", scale: float = None, ) -> torch.Tensor: - if scale is None: scale = DK**-0.5 - accum_dtype = 'float' + accum_dtype = "float" chunk_size = 64 BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA @@ -38,11 +38,12 @@ def tl_fused_chunk_fwd_kernel( @T.prim_func def fused_chunk_linear_attn_fwd( - Q: T.Tensor([B, S, H, DK], dtype), # type: ignore - K: T.Tensor([B, S, H, DK], dtype), # type: ignore - V: T.Tensor([B, S, H, DV], dtype), # type: ignore - O: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore - final_state: T.Tensor([B, H, DK, DV], accum_dtype)): # type: ignore + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + O: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore + final_state: T.Tensor([B, H, DK, DV], accum_dtype), + ): # type: ignore with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H i_h = i_bh % H @@ -65,8 +66,8 @@ def tl_fused_chunk_fwd_kernel( for i in T.Pipelined(0, NT): for row, col in T.Parallel(chunk_size, BK): q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale - T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) - T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) + T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) T.gemm(q, k, s, clear_accum=True, transpose_B=True) for row, col in T.Parallel(chunk_size, chunk_size): @@ -77,12 +78,10 @@ def tl_fused_chunk_fwd_kernel( T.gemm(k, v, h, transpose_A=True) T.gemm(q, h_shared, o) T.copy(o, o_shared) - T.atomic_add( - O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], - o_shared) + T.atomic_add(O[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * 
BV], o_shared) # Output final state - T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV]) + T.copy(h, final_state[i_b, i_h, i_k * BK : (i_k + 1) * BK, i_v * BV : (i_v + 1) * BV]) return fused_chunk_linear_attn_fwd @@ -91,38 +90,35 @@ def tl_fused_chunk_fwd(q, k, v): B, S, H, D = q.shape kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D) print(kernel.get_kernel_source()) - o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32) + o = torch.zeros((B, S, H, D), device="cuda", dtype=torch.float32) h = kernel(q, k, v, o) return o, h -def ref_program(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]: +def ref_program(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]: q, k, v = q.float(), k.float(), v.float() if scale is None: - scale = q.shape[-1]**-0.5 + scale = q.shape[-1] ** -0.5 chunk_size = 64 - q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale - k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size) - v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size) + q = rearrange(q, "b (n c) h d -> b h n c d", c=chunk_size) * scale + k = rearrange(k, "b (n c) h d -> b h n c d", c=chunk_size) + v = rearrange(v, "b (n c) h d -> b h n c d", c=chunk_size) kv = k.transpose(-1, -2) @ v kv = kv.cumsum(2) h = kv[:, :, -1, :, :] kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) inter = q @ kv - intra = ((q @ k.transpose(-1, -2)).masked_fill_( - torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), - 0)) @ v + intra = ( + (q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) + ) @ v o = inter + intra - return rearrange(o, 'b h n c d -> b (n c) h d'), h + return rearrange(o, "b h n c d -> b (n c) h d"), h def main(B=1, S=512, H=16, D=128): - q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) - k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) - v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) # qk norm is necessary for linear attn q, _ = l2norm_fwd(q) @@ -131,25 +127,23 @@ def main(B=1, S=512, H=16, D=128): o, h = tl_fused_chunk_fwd(q, k, v) o_ref, h_ref = ref_program(q, k, v) - assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f'o max err: {(o - o_ref).abs().max()}' - assert torch.allclose(h, h_ref, atol=1e-2, rtol=1e-2), f'h max err: {(h - h_ref).abs().max()}' - print('Passed all tests!✅') + assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f"o max err: {(o - o_ref).abs().max()}" + assert torch.allclose(h, h_ref, atol=1e-2, rtol=1e-2), f"h max err: {(h - h_ref).abs().max()}" + print("Passed all tests!✅") - t1 = do_bench( - lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False), - backend='cupti') - t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend='cupti') - print(f'Triton latency: {t1:.3f} ms') - print(f'TileLang latency: {t2:.3f} ms') - print(f'Speedup: {t1/t2:.3f}x') + t1 = do_bench(lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False), backend="cupti") + t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend="cupti") + print(f"Triton latency: 
{t1:.3f} ms") + print(f"TileLang latency: {t2:.3f} ms") + print(f"Speedup: {t1 / t2:.3f}x") -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--B', type=int, default=8, help='Batch size') - parser.add_argument('--S', type=int, default=1024, help='Seq len') - parser.add_argument('--H', type=int, default=32, help='Num heads') - parser.add_argument('--D', type=int, default=128, help='Head dim') + parser.add_argument("--B", type=int, default=8, help="Batch size") + parser.add_argument("--S", type=int, default=1024, help="Seq len") + parser.add_argument("--H", type=int, default=32, help="Num heads") + parser.add_argument("--D", type=int, default=128, help="Head dim") args = parser.parse_args() main(args.B, args.S, args.H, args.D) diff --git a/examples/linear_attention/example_mamba_chunk_scan.py b/examples/linear_attention/example_mamba_chunk_scan.py index add49052db785a707ac4fd0ff27b2853b517eeda..53b6cf9fb89f7b58f5ffb568f3461c9465141f23 100644 --- a/examples/linear_attention/example_mamba_chunk_scan.py +++ b/examples/linear_attention/example_mamba_chunk_scan.py @@ -9,6 +9,7 @@ import itertools def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D): from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd + out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D) return out @@ -43,14 +44,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] decay = torch.exp(dt_segment_sum) scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s") - causal_mask = torch.tril( - torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) + causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) scores_decay = scores_decay.masked_fill(~causal_mask, 0) - out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), - rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)) + out = torch.einsum( + "bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks) + ) state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) - out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange( - C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out + out_prev = ( + torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out + ) out = out + out_prev out = rearrange(out, "b c l h p -> b (c l) h p") if D is not None: @@ -61,12 +63,7 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): def get_configs(): - iter_params = dict( - block_M=[64, 128, 256], - block_N=[32, 64], - block_K=[64, 128, 256], - block_Dstate=[128], - num_stages=[1, 2, 3, 4, 5]) + iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5]) return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] @@ -77,19 +74,21 @@ def get_configs(): tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }, ) -def chunk_scan_fwd(batch, - seqlen, - chunk_size, - ngroups, - nheads, - headdim, - dstate, - block_M=64, - block_N=64, - block_K=64, - block_Dstate=128, - num_stages=2, - threads=128): +def chunk_scan_fwd( + batch, + seqlen, + chunk_size, + ngroups, + nheads, + headdim, + dstate, + block_M=64, + block_N=64, + block_K=64, + block_Dstate=128, + 
num_stages=2, + threads=128, +): dtype = "float16" accum_dtype = "float" nchunks = T.ceildiv(seqlen, chunk_size) @@ -97,20 +96,20 @@ def chunk_scan_fwd(batch, @T.prim_func def main( - cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore - x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore - dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore - dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore - C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore - prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore - D: T.Tensor((nheads), dtype), # type: ignore - Output: T.Tensor((batch, seqlen, nheads, headdim), dtype) # type: ignore + cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore + x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore + dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore + prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore + D: T.Tensor((nheads), dtype), # type: ignore + Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore ): - with T.Kernel( - nheads, - T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), - batch * nchunks, - threads=threads) as (bz, bx, by): + with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as ( + bz, + bx, + by, + ): acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) acc_o_shared = T.alloc_shared((block_M, block_N), dtype) cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn") @@ -136,27 +135,32 @@ def chunk_scan_fwd(batch, m_idx = bx // T.ceildiv(headdim, block_N) n_idx = bx % T.ceildiv(headdim, block_N) - T.annotate_layout({ - acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared), - cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), - x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared) - }) + T.annotate_layout( + { + acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared), + cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), + x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared), + } + ) T.no_set_max_nreg() - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M], - dA_cs_m_shared) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared) T.copy(dA_cs_m_shared, dA_cs_m_local) T.clear(acc_o) for i in T.Parallel(block_M): scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) T.copy( - C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared) - T.copy( - prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N, - 0:block_Dstate], prev_state_shared) + C[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz // (nheads // ngroups), + 0:block_Dstate, + ], + C_shared, + ) + T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared) T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) for i, j in T.Parallel(block_M, block_N): acc_o[i, j] *= 
scale_m_local[i] @@ -165,34 +169,47 @@ def chunk_scan_fwd(batch, for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy( - cb[batch_idx, chunk_idx, bz // (nheads // ngroups), - m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K], - cb_shared) + cb[ + batch_idx, + chunk_idx, + bz // (nheads // ngroups), + m_idx * block_M : (m_idx + 1) * block_M, + k * block_K : (k + 1) * block_K, + ], + cb_shared, + ) T.copy(cb_shared, cb_local) - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], - dA_cs_k_shared) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared) T.copy(dA_cs_k_shared, dA_cs_k_local) for i, j in T.Parallel(block_M, block_K): - cb_local[i, - j] = cb_local[i, - j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) - T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared) + cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) T.copy(dt_shared, dt_local) for i, j in T.Parallel(block_M, block_K): cb_local[i, j] *= dt_local[j] for i, j in T.Parallel(block_M, block_K): - cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, - cb_local[i, j], 0) + cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0) T.copy( - x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + - (k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared) + x[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_shared, + ) T.gemm(cb_local, x_shared, acc_o) D_local[0] = D[bz] T.copy( - x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N], - x_residual_shared) + x[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_residual_shared, + ) T.copy(x_residual_shared, x_residual_local) for i, j in T.Parallel(block_M, block_N): acc_o[i, j] += x_residual_local[i, j] * D_local[0] @@ -200,27 +217,40 @@ def chunk_scan_fwd(batch, T.copy(acc_o, acc_o_shared) T.copy( acc_o_shared, - Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N]) + Output[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + ) return main if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=80, help='heads') - parser.add_argument('--groups', type=int, default=1, help='groups') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--chunk_size', type=int, default=256, help='chunk size') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--dstate', type=int, default=128, help='dstate') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=80, help="heads") + 
parser.add_argument("--groups", type=int, default=1, help="groups") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--chunk_size", type=int, default=256, help="chunk size") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--dstate", type=int, default=128, help="dstate") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate + batch, heads, groups, seq_len, chunk_size, dim, dstate = ( + args.batch, + args.heads, + args.groups, + args.seq_len, + args.chunk_size, + args.dim, + args.dstate, + ) total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate - if (not args.tune): + if not args.tune: kernel = chunk_scan_fwd( batch, seq_len, @@ -234,7 +264,8 @@ if __name__ == "__main__": block_K=64, block_Dstate=128, num_stages=2, - threads=128) + threads=128, + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("All checks pass.") diff --git a/examples/linear_attention/example_mamba_chunk_state.py b/examples/linear_attention/example_mamba_chunk_state.py index ad3df0df81643e113a4dd09f0bea274c7c0016da..6aefde7bb839c62f438aead9fa1eaf269128035a 100644 --- a/examples/linear_attention/example_mamba_chunk_state.py +++ b/examples/linear_attention/example_mamba_chunk_state.py @@ -10,6 +10,7 @@ import itertools def chunk_state_triton(B, x, dt, dA_cumsum): from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd + return _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=False) @@ -41,46 +42,33 @@ def ref_program(B, x, dt, dA_cumsum): x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size) B = rearrange(B, "b (c l) ... 
-> b c l ...", l=chunk_size) decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)) - return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), - dt.to(x.dtype), x) + return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x) def get_configs(): - iter_params = dict( - block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5]) + iter_params = dict(block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5]) return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[4]) -def chunk_state_fwd(batch, - seqlen, - chunk_size, - ngroups, - nheads, - headdim, - dstate, - block_M=64, - block_N=64, - block_K=64, - num_stages=2, - threads=128): +def chunk_state_fwd( + batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M=64, block_N=64, block_K=64, num_stages=2, threads=128 +): dtype = "float16" accum_dtype = "float" nchunks = T.ceildiv(seqlen, chunk_size) p = 1.44269504 @T.prim_func - def main(B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), x: T.Tensor( - (batch, seqlen, nheads, headdim), dtype), dt: T.Tensor( - (batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor( - (batch, nheads, nchunks, chunk_size), dtype), Output: T.Tensor( - (batch, nchunks, nheads, headdim, dstate), dtype)): - with T.Kernel( - nheads, - T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N), - batch * nchunks, - threads=threads) as (bz, bx, by): + def main( + B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), + x: T.Tensor((batch, seqlen, nheads, headdim), dtype), + dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), + dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), + Output: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), + ): + with T.Kernel(nheads, T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N), batch * nchunks, threads=threads) as (bz, bx, by): x_shared = T.alloc_shared((block_K, block_M), dtype) x_local = T.alloc_fragment((block_K, block_M), dtype) xt_local = T.alloc_fragment((block_M, block_K), dtype) @@ -101,20 +89,24 @@ def chunk_state_fwd(batch, m_idx = bx // T.ceildiv(dstate, block_N) n_idx = bx % T.ceildiv(dstate, block_N) - T.annotate_layout({ - x_shared: tilelang.layout.make_swizzled_layout(x_shared), - acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared) - }) + T.annotate_layout( + {x_shared: tilelang.layout.make_swizzled_layout(x_shared), acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared)} + ) dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1] T.clear(acc_o) for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy( - x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + - (k + 1) * block_K, bz, m_idx * block_M:(m_idx + 1) * block_M], x_shared) - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], - dA_cumsum_shared) - T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared) + x[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz, + m_idx * block_M : (m_idx + 1) * block_M, + ], + x_shared, + ) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cumsum_shared) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) T.copy(dA_cumsum_shared, 
dA_cumsum_local) T.copy(dt_shared, dt_local) for i in T.Parallel(block_K): @@ -123,47 +115,50 @@ def chunk_state_fwd(batch, for i, j in T.Parallel(block_M, block_K): xt_local[i, j] = x_local[j, i] * scale[j] T.copy( - B[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + - (k + 1) * block_K, bz // (nheads // ngroups), - n_idx * block_N:(n_idx + 1) * block_N], B_shared) + B[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz // (nheads // ngroups), + n_idx * block_N : (n_idx + 1) * block_N, + ], + B_shared, + ) T.gemm(xt_local, B_shared, acc_o) T.copy(acc_o, acc_o_shared) T.copy( acc_o_shared, - Output[batch_idx, chunk_idx, bz, m_idx * block_M:(m_idx + 1) * block_M, - n_idx * block_N:(n_idx + 1) * block_N]) + Output[batch_idx, chunk_idx, bz, m_idx * block_M : (m_idx + 1) * block_M, n_idx * block_N : (n_idx + 1) * block_N], + ) return main if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=80, help='heads') - parser.add_argument('--groups', type=int, default=1, help='groups') - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--chunk_size', type=int, default=256, help='chunk size') - parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--dstate', type=int, default=128, help='dstate') - parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=80, help="heads") + parser.add_argument("--groups", type=int, default=1, help="groups") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--chunk_size", type=int, default=256, help="chunk size") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--dstate", type=int, default=128, help="dstate") + parser.add_argument("--tune", action="store_true", help="tune configs") args = parser.parse_args() - batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate + batch, heads, groups, seq_len, chunk_size, dim, dstate = ( + args.batch, + args.heads, + args.groups, + args.seq_len, + args.chunk_size, + args.dim, + args.dstate, + ) total_flops = 2 * batch * seq_len * heads * dim * dstate - if (not args.tune): + if not args.tune: kernel = chunk_state_fwd( - batch, - seq_len, - chunk_size, - groups, - heads, - dim, - dstate, - block_M=64, - block_N=128, - block_K=64, - num_stages=4, - threads=128) + batch, seq_len, chunk_size, groups, heads, dim, dstate, block_M=64, block_N=128, block_K=64, num_stages=4, threads=128 + ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("All checks pass.") diff --git a/examples/linear_attention/example_retention_fwd.py b/examples/linear_attention/example_retention_fwd.py index 59445419a11a99fbb1a00fb6b03b0a1c94d4e9fd..ccb11fe1b2a462484dafbff864113d01322f5e6e 100644 --- a/examples/linear_attention/example_retention_fwd.py +++ b/examples/linear_attention/example_retention_fwd.py @@ -13,13 +13,12 @@ def chunk_retention_fwd_kernel( H, DK, DV, - dtype: str = 'float16', + dtype: str = "float16", scale: float = None, ) -> torch.Tensor: - if scale is None: 
scale = DK**-0.5 - accum_dtype = 'float' + accum_dtype = "float" chunk_size = 64 BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA @@ -30,16 +29,16 @@ def chunk_retention_fwd_kernel( @T.prim_func def chunk_retention_fwd( - Q: T.Tensor([B, S, H, DK], dtype), # type: ignore - K: T.Tensor([B, S, H, DK], dtype), # type: ignore - V: T.Tensor([B, S, H, DV], dtype), # type: ignore - O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore ): with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H i_h = i_bh % H - log_decay = T.alloc_var('float32') - log_decay = T.log2(1 - T.exp2(-5. - 1. * i_h)) # Head-specific log decay + log_decay = T.alloc_var("float32") + log_decay = T.log2(1 - T.exp2(-5.0 - 1.0 * i_h)) # Head-specific log decay q = T.alloc_shared([chunk_size, BK], dtype) k = T.alloc_shared([chunk_size, BK], dtype) @@ -56,14 +55,12 @@ def chunk_retention_fwd_kernel( for i in T.Pipelined(0, NT): for row, col in T.Parallel(chunk_size, BK): q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale - T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) - T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) + T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) T.gemm(q, k, s, clear_accum=True, transpose_B=True) for row, col in T.Parallel(chunk_size, chunk_size): - s_shared[row, - col] = T.if_then_else(row >= col, s[row, col] * T.exp2( - (row - col) * log_decay), 0) + s_shared[row, col] = T.if_then_else(row >= col, s[row, col] * T.exp2((row - col) * log_decay), 0) T.copy(h, h_shared) T.gemm(q, h_shared, o, clear_accum=True) @@ -75,9 +72,7 @@ def chunk_retention_fwd_kernel( v[row, col] = v[row, col] * T.exp2((chunk_size - row - 1) * log_decay) for row, col in T.Parallel(BK, BV): h[row, col] = T.exp2(chunk_size * log_decay) * h[row, col] - T.copy( - o, O[i_k, i_b, i * chunk_size:(i + 1) * chunk_size, i_h, - i_v * BV:(i_v + 1) * BV]) + T.copy(o, O[i_k, i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV]) T.gemm(k, v, h, transpose_A=True) return chunk_retention_fwd @@ -89,24 +84,24 @@ def postprocess(o): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--B', type=int, default=8, help='Batch size') - parser.add_argument('--S', type=int, default=4096, help='Seq len') - parser.add_argument('--H', type=int, default=32, help='Num heads') - parser.add_argument('--D', type=int, default=128, help='Head dim') + parser.add_argument("--B", type=int, default=8, help="Batch size") + parser.add_argument("--S", type=int, default=4096, help="Seq len") + parser.add_argument("--H", type=int, default=32, help="Num heads") + parser.add_argument("--D", type=int, default=128, help="Head dim") args = parser.parse_args() B, S, H, D = args.B, args.S, args.H, args.D total_flops = 2.0 * B * S * S * H * D # causal - q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) - k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) - v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + k = torch.randn((B, S, H, D), device="cuda", 
dtype=torch.float16) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) kernel = chunk_retention_fwd_kernel(B, S, H, D, D) t = do_bench(lambda: postprocess(kernel(q, k, v)), warmup=25, rep=100) - print(f'Tilelang latency: {t:.3f} ms') - print(f'Tilelang TFLOPs: {total_flops/t * 1e-9}') + print(f"Tilelang latency: {t:.3f} ms") + print(f"Tilelang TFLOPs: {total_flops / t * 1e-9}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/minference/example_vertical_slash_sparse_attn.py b/examples/minference/example_vertical_slash_sparse_attn.py index 48df3e09169563a1ecfb9b581b1aac3227babadb..6600bb5ed66eb8f38c6e9c8ab383759d3783c110 100644 --- a/examples/minference/example_vertical_slash_sparse_attn.py +++ b/examples/minference/example_vertical_slash_sparse_attn.py @@ -15,12 +15,11 @@ from tilelang.profiler import do_bench @tilelang.jit(out_idx=[3]) def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_size): - block_M = 64 block_N = 64 num_stages = 2 threads = 128 - scale = (1.0 / dim)**0.5 * 1.44269504 + scale = (1.0 / dim) ** 0.5 * 1.44269504 shape = [batch, heads, seq_len, dim] seq_blocks = (seq_len + block_M - 1) // block_M @@ -30,15 +29,13 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz offset_shape = count_shape + [slash_size] index_shape = count_shape + [vertical_size] - vertical_size_round, slash_size_round = tilelang.next_power_of_2( - vertical_size), tilelang.next_power_of_2(slash_size) + vertical_size_round, slash_size_round = tilelang.next_power_of_2(vertical_size), tilelang.next_power_of_2(slash_size) dtype = "float16" accum_dtype = "float" int_dtype = "int32" def kernel_func(block_M, block_N, num_stages, threads): - @T.macro def Prefetch( K: T.Tensor(shape, dtype), @@ -53,32 +50,30 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz ): with T.attr("default", "async_scope", 1): for i, j in T.Parallel(block_N, dim): - K_shared[i, j] = T.if_then_else(k + i < column_count, - K[bz, by, column_index[k + i], j], 0) + K_shared[i, j] = T.if_then_else(k + i < column_count, K[bz, by, column_index[k + i], j], 0) with T.attr("default", "async_scope", 1): for i, j in T.Parallel(block_N, dim): - V_shared[i, j] = T.if_then_else(k + i < column_count, - V[bz, by, column_index[k + i], j], 0) + V_shared[i, j] = T.if_then_else(k + i < column_count, V[bz, by, column_index[k + i], j], 0) T.ptx_commit_group() @T.macro def Compute( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - k: T.int32, - column_count: T.int32, - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), - V_shared: T.SharedBuffer([block_N, dim], dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - count: T.int32, + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + k: T.int32, + column_count: T.int32, + Q_shared: T.SharedBuffer([block_M, dim], dtype), + 
K_shared: T.SharedBuffer([block_N, dim], dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + count: T.int32, ): T.ptx_wait_group(count) for i, j in T.Parallel(block_M, block_N): @@ -108,17 +103,16 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz @T.prim_func def vs_sparse_flashattn_ws( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - Output: T.Tensor(shape, dtype), - BlockCount: T.Tensor(count_shape, int_dtype), - BlockOffset: T.Tensor(offset_shape, int_dtype), - ColumnCount: T.Tensor(count_shape, int_dtype), - ColumnIndex: T.Tensor(index_shape, int_dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), + BlockCount: T.Tensor(count_shape, int_dtype), + BlockOffset: T.Tensor(offset_shape, int_dtype), + ColumnCount: T.Tensor(count_shape, int_dtype), + ColumnIndex: T.Tensor(index_shape, int_dtype), ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bc, by, bz): - bx = T.ceildiv(seq_len, block_M) - 1 - bc Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([2, block_N, dim], dtype) @@ -143,9 +137,11 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz T.create_list_of_mbarrier([128] * 9) - T.annotate_layout({ - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + } + ) block_count[0] = BlockCount[bz, by, bx] column_count[0] = ColumnCount[bz, by, bx] @@ -162,15 +158,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz if tid >= 128: T.annotate_producer_reg_dealloc() - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.mbarrier_arrive(mbarrier=8) for bi in T.serial(block_count[0]): k = block_offset[bi] T.mbarrier_wait_parity(mbarrier=bi % 2 + 4, parity=(((bi & 3) >> 1) ^ 1)) - T.copy(K[bz, by, k:k + block_N, :], K_shared[bi % 2, :, :]) + T.copy(K[bz, by, k : k + block_N, :], K_shared[bi % 2, :, :]) T.mbarrier_arrive(mbarrier=bi % 2) T.mbarrier_wait_parity(mbarrier=bi % 2 + 6, parity=(((bi & 3) >> 1) ^ 1)) - T.copy(V[bz, by, k:k + block_N, :], V_shared[bi % 2, :, :]) + T.copy(V[bz, by, k : k + block_N, :], V_shared[bi % 2, :, :]) T.mbarrier_arrive(mbarrier=bi % 2 + 2) else: T.annotate_consumer_reg_alloc() @@ -181,16 +177,10 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz for bi in T.serial(block_count[0]): k = block_offset[bi] for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, -T.infinity(acc_s.dtype)) T.mbarrier_wait_parity(mbarrier=bi % 2, parity=((bi & 3) >> 1)) - T.gemm( - Q_shared, - K_shared[bi % 2, :, :], - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared[bi % 2, :, :], acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.mbarrier_arrive(mbarrier=bi % 2 + 4) T.copy(scores_max, scores_max_prev) @@ -200,20 +190,15 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in 
T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - - scores_max[i] * scale) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): acc_o[i, j] = acc_o[i, j] * scores_scale[i] T.copy(acc_s, acc_s_cast) - T.mbarrier_wait_parity(mbarrier=bi % 2 + 2, parity=(((bi & 3) >> 1))) - T.gemm( - acc_s_cast, - V_shared[bi % 2, :, :], - acc_o, - policy=T.GemmWarpPolicy.FullRow) + T.mbarrier_wait_parity(mbarrier=bi % 2 + 2, parity=((bi & 3) >> 1)) + T.gemm(acc_s_cast, V_shared[bi % 2, :, :], acc_o, policy=T.GemmWarpPolicy.FullRow) T.mbarrier_arrive(mbarrier=bi % 2 + 6) @@ -223,38 +208,85 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] if column_count[0] != 0: - Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0, bz, - by) + Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0, bz, by) for bi in T.serial(T.ceildiv(column_count[0], block_N) - 1): k = bi * block_N if bi % 2 == 0: - Prefetch(K, V, K_shared_2, V_shared_2, column_index, - column_count[0], k + block_N, bz, by) - - Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, - column_count[0], Q_shared, K_shared_1, V_shared_1, - scores_scale, scores_sum, logsum, 1) + Prefetch(K, V, K_shared_2, V_shared_2, column_index, column_count[0], k + block_N, bz, by) + + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + k, + column_count[0], + Q_shared, + K_shared_1, + V_shared_1, + scores_scale, + scores_sum, + logsum, + 1, + ) else: - Prefetch(K, V, K_shared_1, V_shared_1, column_index, - column_count[0], k + block_N, bz, by) - - Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, - column_count[0], Q_shared, K_shared_2, V_shared_2, - scores_scale, scores_sum, logsum, 1) + Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], k + block_N, bz, by) + + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + k, + column_count[0], + Q_shared, + K_shared_2, + V_shared_2, + scores_scale, + scores_sum, + logsum, + 1, + ) if T.ceildiv(column_count[0], block_N) % 2 == 0: - Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, - T.ceildiv(column_count[0], block_N) * block_N - block_N, - column_count[0], Q_shared, K_shared_2, V_shared_2, scores_scale, - scores_sum, logsum, 0) + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + T.ceildiv(column_count[0], block_N) * block_N - block_N, + column_count[0], + Q_shared, + K_shared_2, + V_shared_2, + scores_scale, + scores_sum, + logsum, + 0, + ) else: - Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, - T.ceildiv(column_count[0], block_N) * block_N - block_N, - column_count[0], Q_shared, K_shared_1, V_shared_1, scores_scale, - scores_sum, logsum, 0) + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + T.ceildiv(column_count[0], block_N) * block_N - block_N, + column_count[0], + Q_shared, + K_shared_1, + V_shared_1, + scores_scale, + scores_sum, + logsum, + 0, + ) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return vs_sparse_flashattn_ws @@ -470,11 +502,8 @@ def 
vertical_slash_sparse_attention( import os current_dir = os.path.dirname(os.path.abspath(__file__)) - sources = [ - os.path.join(current_dir, 'ops', 'kernels.cpp'), - os.path.join(current_dir, 'ops', 'vertical_slash_index.cu') - ] - ops = load(name='convert', sources=sources, verbose=False) + sources = [os.path.join(current_dir, "ops", "kernels.cpp"), os.path.join(current_dir, "ops", "vertical_slash_index.cu")] + ops = load(name="convert", sources=sources, verbose=False) convert_vertical_slash_indexes = ops.convert_vertical_slash_indexes batch_size, num_heads, context_size, head_dim = query.shape pad = (block_size_M - context_size) & (block_size_M - 1) @@ -485,15 +514,13 @@ def vertical_slash_sparse_attention( value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0]) if head_dim not in [16, 32, 64, 128, 256, 512]: - target_dim = 2**math.ceil(math.log2(head_dim)) - head_dim + target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim query = torch.nn.functional.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0]) key = torch.nn.functional.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) value = torch.nn.functional.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0]) - v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort( - dim=-1, descending=False)[0] - s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort( - dim=-1, descending=True)[0] + v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] + s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] seqlens = torch.tensor([context_size] * query.shape[0], dtype=torch.int32, device=query.device) sm_scale = head_dim**-0.5 @@ -506,8 +533,7 @@ def vertical_slash_sparse_attention( block_size_N, ) - tl_kernel = _tl_vs_sparse_flashattn(batch_size, num_heads, context_size, head_dim, - v_idx.shape[2], s_idx.shape[2]) + tl_kernel = _tl_vs_sparse_flashattn(batch_size, num_heads, context_size, head_dim, v_idx.shape[2], s_idx.shape[2]) def run(is_triton: bool = True): if is_triton: @@ -525,8 +551,7 @@ def vertical_slash_sparse_attention( block_size_N, ) else: - out = tl_kernel(query, key, value, block_count, block_offset, column_count, - column_index) + out = tl_kernel(query, key, value, block_count, block_offset, column_count, column_index) return out[..., :context_size, :head_dim] return run @@ -536,8 +561,7 @@ def sum_all_diagonal_matrix(mat: torch.tensor): b, h, n, m = mat.shape zero_mat = torch.zeros((b, h, n, n)).to(mat.device) # Zero matrix used for padding mat_padded = torch.cat((zero_mat, mat, zero_mat), -1) # pads the matrix on left and right - mat_strided = mat_padded.as_strided( - (1, 1, n, n + m), (1, n * (2 * n + m), 2 * n + m + 1, 1)) # Change the strides + mat_strided = mat_padded.as_strided((1, 1, n, n + m), (1, n * (2 * n + m), 2 * n + m + 1, 1)) # Change the strides sum_diags = torch.sum(mat_strided, 2) # Sums the resulting matrix's columns return sum_diags[:, :, 1:] @@ -559,24 +583,23 @@ def main(argv=None): vertical_size, slash_size = args.vertical_size, args.slash_size torch.manual_seed(0) - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) 
+ v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) q_len = SEQ_LEN vertical_size, slash_size = min(q_len, vertical_size), min(q_len, slash_size) last_q = 64 - qk = torch.einsum('bhmk, bhnk -> bhmn', q[:, :, -last_q:, :], k) + qk = torch.einsum("bhmk, bhnk -> bhmn", q[:, :, -last_q:, :], k) arange = torch.arange(last_q, device="cuda") - qk[:, :, :, -last_q:] = torch.where(arange[None, None, :, None] >= arange[None, None, None, :], - qk[:, :, :, -last_q:], -torch.inf) + qk[:, :, :, -last_q:] = torch.where(arange[None, None, :, None] >= arange[None, None, None, :], qk[:, :, :, -last_q:], -torch.inf) qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32) vertical = qk.sum(-2, keepdim=True) vertical[..., :30] = torch.inf vertical_topk = torch.topk(vertical, vertical_size, -1).indices - slash = sum_all_diagonal_matrix(qk)[..., :-last_q + 1] + slash = sum_all_diagonal_matrix(qk)[..., : -last_q + 1] slash[..., -30:] = torch.inf slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices diff --git a/examples/norm/rms_norm.py b/examples/norm/rms_norm.py index 40d367c2d6708ce7e93d2f0db5f66fb2a38e7819..a7a06b9c64622c0794fddc136db76cec9ee32407 100644 --- a/examples/norm/rms_norm.py +++ b/examples/norm/rms_norm.py @@ -45,7 +45,7 @@ def rms_norm(M, N, blk_m): A_local = T.alloc_fragment((blk_m, N), dtype) A_powsum = T.alloc_fragment((blk_m,), dtype) - T.copy(A[bx * blk_m:(bx + 1) * blk_m, :], A_shared) + T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared) T.copy(A_shared, A_local) for i, j in T.Parallel(blk_m, N): A_pow_local[i, j] = A_local[i, j] * A_local[i, j] @@ -54,7 +54,7 @@ def rms_norm(M, N, blk_m): A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) for i, j in T.Parallel(blk_m, N): A_local[i, j] *= A_powsum[i] - T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :]) + T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :]) return main diff --git a/examples/norm/test_rms_norm.py b/examples/norm/test_rms_norm.py index a05f9b082a8111d7b77b9351456d5b8760323178..124a212f6ee64e0a809c785e1266f656d3f703ca 100644 --- a/examples/norm/test_rms_norm.py +++ b/examples/norm/test_rms_norm.py @@ -45,7 +45,7 @@ def rms_norm(M, N, blk_m): A_local = T.alloc_fragment((blk_m, N), dtype) A_powsum = T.alloc_fragment((blk_m,), dtype) - T.copy(A[bx * blk_m:(bx + 1) * blk_m, :], A_shared) + T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared) T.copy(A_shared, A_local) for i, j in T.Parallel(blk_m, N): A_pow_local[i, j] = A_local[i, j] * A_local[i, j] @@ -54,7 +54,7 @@ def rms_norm(M, N, blk_m): A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) for i, j in T.Parallel(blk_m, N): A_local[i, j] *= A_powsum[i] - T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :]) + T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :]) return main diff --git a/examples/online_softmax/online_softmax.py b/examples/online_softmax/online_softmax.py index 432482d063f271fd70411ef7ab72452211598688..32f1c001f221e0b917e4506bcbb202455d3bfb2c 100644 --- a/examples/online_softmax/online_softmax.py +++ b/examples/online_softmax/online_softmax.py @@ -20,8 +20,8 @@ def softmax_kernel( @T.prim_func def main( - X: T.Tensor([M, N], dtype), - Y: T.Tensor([M, N], dtype), + X: T.Tensor([M, N], dtype), + Y: T.Tensor([M, N], dtype), ): with T.Kernel(M, threads=128) as (i_m): x = T.alloc_fragment([BN], dtype) @@ -33,7 +33,7 @@ def softmax_kernel( T.fill(lse, -T.infinity(accum_dtype)) for i_n in T.Pipelined(0, NN): - T.copy(X[i_m, i_n * BN:(i_n + 1) * BN], x) + T.copy(X[i_m, i_n * BN : (i_n + 1) * BN], x) T.reduce_max(x, max_x, 
dim=0, clear=True) @@ -45,12 +45,12 @@ def softmax_kernel( lse[0] = max_x[0] * scale + T.log2(T.exp2(lse[0] - max_x[0] * scale) + sum_exp_x[0]) for i_n in T.Pipelined(0, NN): - T.copy(X[i_m, i_n * BN:(i_n + 1) * BN], x) + T.copy(X[i_m, i_n * BN : (i_n + 1) * BN], x) for j in T.Parallel(BN): y[j] = T.exp2(x[j] * scale - lse[0]) - T.copy(y, Y[i_m, i_n * BN:(i_n + 1) * BN]) + T.copy(y, Y[i_m, i_n * BN : (i_n + 1) * BN]) return main @@ -69,4 +69,4 @@ t1 = do_bench(lambda: X.softmax(dim=1), warmup=25, rep=100) t2 = do_bench(lambda: kernel(X), warmup=25, rep=100) print(f"torch latency: {t1:.3f} ms") print(f"TileLang latency: {t2:.3f} ms") -print(f"Speedup: {t1/t2:.3f}x") +print(f"Speedup: {t1 / t2:.3f}x") diff --git a/examples/plot_layout/fragment_mfma_load_a.py b/examples/plot_layout/fragment_mfma_load_a.py index 2c3b282a6ae8ac5b937d6bfeaefc0d3911e043db..a7e8f8909a3b950c45dc36d53cf00c6d2346f84a 100644 --- a/examples/plot_layout/fragment_mfma_load_a.py +++ b/examples/plot_layout/fragment_mfma_load_a.py @@ -11,10 +11,9 @@ from tilelang.intrinsics.mfma_layout import ( ) -def make_mfma_load_base_layout(dtype: str = "float16", - matrix: Literal["A", "B"] = "A", - k_dim: int = 16, - transposed: bool = False) -> T.Fragment: +def make_mfma_load_base_layout( + dtype: str = "float16", matrix: Literal["A", "B"] = "A", k_dim: int = 16, transposed: bool = False +) -> T.Fragment: """ Create a layout function for storing MFMA results into a fragment buffer. This layout is used in conjunction with `inverse_mfma_store_layout` to @@ -72,12 +71,10 @@ def make_mfma_load_base_layout(dtype: str = "float16", # so the b matrix expected a transposed basic layout transform_func: Callable = None if matrix == "A": - transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( - j, i) + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) micro_size_s, micro_size_r = micro_size_x, micro_size_k elif matrix == "B": - transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( - j, i) + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) micro_size_s, micro_size_r = micro_size_k, micro_size_y else: raise ValueError(f"Unsupported matrix {matrix}") @@ -120,14 +117,11 @@ print(base_layout) plot_layout(base_layout, name="base_layout") # warp layout 32x32 -warp_layout = base_layout.repeat([warp_rows, warp_cols], - repeat_on_thread=False, - lower_dim_first=False) +warp_layout = base_layout.repeat([warp_rows, warp_cols], repeat_on_thread=False, lower_dim_first=False) print(warp_layout) plot_layout(warp_layout, name="warp_layout") # block layout 64x32 -block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True, - lower_dim_first=True).replicate(block_cols) +block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True, lower_dim_first=True).replicate(block_cols) print(block_layout) plot_layout(block_layout, name="block_layout") diff --git a/examples/plot_layout/fragment_mma_load_a.py b/examples/plot_layout/fragment_mma_load_a.py index 988899448365892002536e950103fc1504d73cdc..17d1c6d51126eba26fab121453041d125b18da7e 100644 --- a/examples/plot_layout/fragment_mma_load_a.py +++ b/examples/plot_layout/fragment_mma_load_a.py @@ -5,9 +5,7 @@ from tvm.tir import IndexMap from tilelang.intrinsics.utils import get_mma_micro_size -def make_mma_load_base_layout(dtype: str = "float16", - matrix: Literal["A", "B"] = "A", - transposed: bool = False) -> 
T.Fragment: +def make_mma_load_base_layout(dtype: str = "float16", matrix: Literal["A", "B"] = "A", transposed: bool = False) -> T.Fragment: """ Create a layout function for storing MMA results into a fragment buffer. This layout is used in conjunction with `inverse_mma_store_layout` to @@ -36,6 +34,7 @@ def make_mma_load_base_layout(dtype: str = "float16", shared_16x16_to_mma_32x8_layout_sr_b, shared_16x32_to_mma_32x16_layout_sr_b, ) + assert matrix in ["A", "B"], "matrix should be either A or B" dtype_bits = DataType(dtype).bits # s represents spatial axis @@ -67,12 +66,10 @@ def make_mma_load_base_layout(dtype: str = "float16", # so the b matrix expected a transposed basic layout transform_func: Callable = None if matrix == "A": - transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( - j, i) + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) micro_size_s, micro_size_r = micro_size_x, micro_size_k elif matrix == "B": - transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( - j, i) + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) micro_size_s, micro_size_r = micro_size_k, micro_size_y else: raise ValueError(f"Unsupported matrix {matrix}") diff --git a/examples/quickstart.py b/examples/quickstart.py index 39ad348b5d8fa72318850c96c5f6cc8edf5b2955..4b765ce172df1e8774892de0911fd66a36bf3947 100644 --- a/examples/quickstart.py +++ b/examples/quickstart.py @@ -7,12 +7,11 @@ import tilelang.language as T # if not specified, it will be inferred from the input tensors during compile time @tilelang.jit def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def matmul_relu_kernel( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): diff --git a/examples/seer_attention/block_sparse_attn_tilelang.py b/examples/seer_attention/block_sparse_attn_tilelang.py index 219d3ee35f731dd7c8deab72f898a8ade24c33b8..f5f7fe7ba21ddceb8de2d0c8fc45e7645ba151cf 100644 --- a/examples/seer_attention/block_sparse_attn_tilelang.py +++ b/examples/seer_attention/block_sparse_attn_tilelang.py @@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -30,15 +27,17 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F @tilelang.jit( - out_idx=[4], pass_configs={ + out_idx=[4], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, +) def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_causal): block_M = 64 block_N = 64 num_stages = 0 threads = 128 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, 
dim] kv_shape = [batch, heads, seq_kv, dim] block_mask_shape = [batch, heads, downsample_len, downsample_len] @@ -48,16 +47,15 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c block_mask_dtype = "int8" def kernel_func(block_M, block_N, num_stages, threads): - @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -83,19 +81,19 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -112,7 +110,7 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c logsum = T.alloc_fragment([block_M], accum_dtype) block_mask = T.alloc_local([downsample_len], block_mask_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -124,33 +122,25 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c for k in T.Pipelined(loop_range, num_stages=num_stages): if block_mask[k] != 0: - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: past_len = seq_kv - seq_q for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else( - bx * block_M + i + past_len >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i + past_len >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) - - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + Softmax(acc_s, 
acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main @@ -165,44 +155,40 @@ def test_topk_sparse_attention(): torch.manual_seed(0) # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.float16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.float16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) # Run tilelang kernel - kernel = blocksparse_flashattn( - BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) tilelang_output = kernel(q, k, v, block_mask.to(torch.int8)) # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) print("ref_output", ref_output) print("tilelang_output", tilelang_output) # Verify accuracy - assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), \ - "TileLang output doesn't match reference" + assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), "TileLang output doesn't match reference" print("Pass topk sparse attention test with qlen == klen") @@ -215,42 +201,40 @@ def test_topk_sparse_attention_qlen_lt_klen(): torch.manual_seed(0) # Create inputs. 
- q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.float16) + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16) sm_scale = 1.0 / (D_HEAD**0.5) downsample_factor = BLOCK downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension - x_ds = torch.randn( - BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.float16) + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.float16) # Force the first column to be high so that the first block is always selected. x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) - kernel = blocksparse_flashattn( - BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True) + kernel = blocksparse_flashattn(BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True) print(kernel.get_kernel_source()) tilelang_output = kernel(q, k, v, block_mask.to(torch.int8)) past_len = K_LEN - Q_LEN - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale - full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool() + full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool() full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) - causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN) + causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN) final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) - attn = attn.masked_fill(~final_mask, float('-inf')) + attn = attn.masked_fill(~final_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) print("ref_output", ref_output) print("tilelang_output", tilelang_output) diff --git a/examples/seer_attention/block_sparse_attn_triton.py b/examples/seer_attention/block_sparse_attn_triton.py index ed33cc1e2a4055bad1954318c51a163b3609e921..b4cc3cd00c854cdde0757af38dcf6302976cd49f 100644 --- a/examples/seer_attention/block_sparse_attn_triton.py +++ b/examples/seer_attention/block_sparse_attn_triton.py @@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], - False, - dtype=torch.bool, - device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: dense_mask[:, :, -2:, :] = True @@ -54,7 +51,6 @@ def _fwd_kernel_inner( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - mask_val = tl.load(block_mask_ptr + k_block_col_idx * 
stride_bmask_n) if mask_val == True: @@ -69,7 +65,7 @@ def _fwd_kernel_inner( qk *= sm_scale # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N - qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf')) + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf")) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] @@ -149,7 +145,7 @@ def _fwd_kernel( v_ptrs = V + off_v mask_ptrs = block_mask_ptr + start_m * stride_bmm - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) @@ -185,24 +181,12 @@ def _fwd_kernel( acc = acc * l_recip acc = acc.to(Out.dtype.element_ty) - off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[ - None, :] * stride_od + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) -def _forward(ctx, - q, - k, - v, - block_sparse_mask, - sm_scale, - BLOCK_M=64, - BLOCK_N=64, - num_warps=None, - num_stages=1, - out=None): - +def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None): assert q.shape[-1] == k.shape[-1] == v.shape[-1] assert k.shape[2] == v.shape[2] o = out if out is not None else torch.empty_like(q).contiguous() @@ -247,7 +231,6 @@ def _forward(ctx, class _sparse_attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, block_sparse_dense, sm_scale): # shape constraints @@ -271,9 +254,9 @@ def test_topk_sparse_attention(): torch.manual_seed(0) # Create inputs - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) @@ -281,9 +264,7 @@ def test_topk_sparse_attention(): downsample_len = math.ceil(SEQ_LEN / downsample_factor) print("downsample_len", downsample_len) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 print("x_ds.shape", x_ds.shape) block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -295,22 +276,21 @@ def test_topk_sparse_attention(): # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal # PyTorch reference implementation - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale - attn = attn.masked_fill(~full_mask, float('-inf')) + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = 
attn.masked_fill(~full_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) # print("ref_output", ref_output) # print("triton_output", triton_output) # Verify accuracy - assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ - "Triton output doesn't match reference" + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference" print("Pass topk sparse attention test with qlen == klen") @@ -322,16 +302,15 @@ def test_topk_sparse_attention_qlt_kl(): torch.manual_seed(0) # Create inputs. - q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) # softmax scale sm_scale = 1.0 / (D_HEAD**0.5) downsample_factor = BLOCK downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension - x_ds = torch.randn( - BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16) + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16) # Force the first column to be high so that the first block is always selected. x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) @@ -340,26 +319,25 @@ def test_topk_sparse_attention_qlt_kl(): past_len = K_LEN - Q_LEN - attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale - full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool() + full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool() full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) - causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN) + causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN) final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) - attn = attn.masked_fill(~final_mask, float('-inf')) + attn = attn.masked_fill(~final_mask, float("-inf")) attn = F.softmax(attn, dim=-1) - ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) # Verify accuracy. 
- assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ - "Triton output doesn't match reference when qlen < klen" + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen" print("Pass topk sparse attention test with qlen < klen") diff --git a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py index 8707c9430a086652640e4c4db2294dec05e9acb8..6c37dc09c866776f50e9d5c467ac957fe0c377c6 100644 --- a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py +++ b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py @@ -28,24 +28,22 @@ def matmul_sp( @T.prim_func def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // 8), 'uint8'), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // 8), "uint8"), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) - E_shared = T.alloc_shared((block_M, block_K // 8), 'uint8') + E_shared = T.alloc_shared((block_M, block_K // 8), "uint8") C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout({ - E: - make_cutlass_metadata_layout( - E, mma_dtype="float16", arch="9.0", block_k=block_K), - E_shared: - make_cutlass_metadata_layout( - E_shared, mma_dtype="float16", arch="9.0", block_k=block_K), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="9.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="9.0", block_k=block_K), + } + ) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(E[by * block_M, k * block_K // 8], E_shared) @@ -57,7 +55,7 @@ def matmul_sp( return main -def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'): +def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device="cpu"): if shape[-1] % 4 != 0: raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.") @@ -102,9 +100,9 @@ def run_gemm_sp( num_threads, ) - A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device='cuda') + A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device="cuda") A_sparse, E = compress_sm90(A, block_k=block_K, transposed=False) - B = torch.randn((K, N), device='cuda', dtype=torch.float16) + B = torch.randn((K, N), device="cuda", dtype=torch.float16) C_sp = kernel(A_sparse, E, B).half() C = torch.matmul(A, B) diff --git a/examples/topk/example_topk.py b/examples/topk/example_topk.py index 0ca19fb18d7b14fc69eaed7cc6721f8c3925645e..c0cf09bc0fb5f153d42771c72760ab356e22fb1d 100644 --- a/examples/topk/example_topk.py +++ b/examples/topk/example_topk.py @@ -26,9 +26,9 @@ def tl_topk( @T.prim_func def topk_kernel( - logits: T.Tensor([M, N], dtype), - topk_gates: T.Tensor([M, topk], dtype), - topk_indices: T.Tensor([M, topk], "int32"), + logits: T.Tensor([M, N], dtype), + topk_gates: T.Tensor([M, topk], dtype), + topk_indices: T.Tensor([M, topk], "int32"), ): with T.Kernel(T.ceildiv(M, blk_m), threads=threads) as bx: logits_frag = T.alloc_fragment([blk_m, N], dtype=dtype) @@ -43,15 +43,12 @@ def tl_topk( T.reduce_max(logits_frag, max_val, dim=1, 
clear=True) for i, j in T.Parallel(blk_m, N): - expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j, - expand_max_idx[i, j]) + expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j, expand_max_idx[i, j]) T.reduce_max(expand_max_idx, max_idx, dim=1, clear=True) for i, j in T.Parallel(blk_m, N): - - logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0, - logits_frag[i, j]) + logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0, logits_frag[i, j]) for i in T.Parallel(blk_m): topk_gates[bx * blk_m + i, k] = max_val[i] @@ -61,7 +58,6 @@ def tl_topk( def ref_program(logits, top_k): - top_k_gates, top_k_indices = logits.topk(top_k, dim=1) return top_k_gates, top_k_indices.to(torch.int32) diff --git a/examples/visual_layout_inference/visual_layout_inference.py b/examples/visual_layout_inference/visual_layout_inference.py index 3677d475434e97d5dd8a4d3000932625c51ffa09..dbb39f7898c396ac79c9741dabfb065d5068294a 100644 --- a/examples/visual_layout_inference/visual_layout_inference.py +++ b/examples/visual_layout_inference/visual_layout_inference.py @@ -7,15 +7,15 @@ import tilelang.language as T out_idx=[-1], pass_configs={ tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE: True, - tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg" - }) + tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg", + }, +) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def gemm( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -49,12 +49,12 @@ def main(): print("All check passed.") # print the layout visualization result and save figures to ./tmp. 
- ''' + """ C_local inferenced layout: Shape: [32, 32] -> [8] Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2 Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2] - ''' + """ if __name__ == "__main__": diff --git a/examples/warp_specialize/example_warp_specialize_flashmla.py b/examples/warp_specialize/example_warp_specialize_flashmla.py index 4a8f41ee4fbf846f5e655529729fc01b3e1951f9..4f4417e75db088254b2ce3741d7884cb4c67fd4a 100644 --- a/examples/warp_specialize/example_warp_specialize_flashmla.py +++ b/examples/warp_specialize/example_warp_specialize_flashmla.py @@ -9,7 +9,7 @@ import argparse @tilelang.jit(out_idx=[6]) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split): - scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) dtype = "float16" accum_dtype = "float" kv_group_num = heads // kv_head_num @@ -19,11 +19,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.macro def flash_attn( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid): # smem_sQ @@ -81,10 +81,12 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ cur_kv_head = hid // (kv_group_num // block_H) - T.annotate_layout({ - O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l), - O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r), - }) + T.annotate_layout( + { + O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l), + O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r), + } + ) # barriers_Q q_shared_ready_barrier = T.alloc_barrier(arrive_count=256) @@ -108,9 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ tx = T.get_thread_binding() - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.barrier_arrive(q_shared_ready_barrier) T.barrier_wait(q_shared_ready_barrier, 0) @@ -123,25 +125,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.fill(acc_o_l, 0) T.fill(logsum_0, 0) - T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l) + T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l) T.barrier_arrive(kv_shared_1_l_is_ready) - T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r) + T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r) T.barrier_arrive(kv_shared_1_r_is_ready) - T.copy(K_pe[bid, block_N:2 * block_N, cur_kv_head, :], 
K_pe_shared_1) + T.copy(K_pe[bid, block_N : 2 * block_N, cur_kv_head, :], K_pe_shared_1) T.barrier_arrive(kv_shared_1_pe_is_ready) for k in T.serial(loop_range): - T.barrier_wait(kv_shared_0_l_is_ready, k % 2) - T.gemm( - Q_shared_l, - KV_shared_0_l, - acc_s_0, - transpose_B=True, - clear_accum=True, - wg_wait=-1) + T.gemm(Q_shared_l, KV_shared_0_l, acc_s_0, transpose_B=True, clear_accum=True, wg_wait=-1) T.barrier_wait(kv_shared_0_r_is_ready, k % 2) T.gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True, wg_wait=-1) @@ -161,8 +156,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ for i, j in T.Parallel(block_H, block_N): acc_s_0[i, j] = T.exp2(acc_s_0[i, j] * scale - scores_max[i] * scale) for i in T.Parallel(block_H): - scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale - - scores_max[i] * scale) + scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale - scores_max[i] * scale) T.reduce_sum(acc_s_0, scores_sum_0, dim=1) @@ -182,9 +176,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.barrier_wait(scale_1_ready_barrier, k % 2) if k < loop_range - 1: - T.copy( - KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, - cur_kv_head, :h_dim], KV_shared_0_l) + T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :h_dim], KV_shared_0_l) T.barrier_arrive(kv_shared_0_l_is_ready) # Step 11. @@ -204,15 +196,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.gemm(SP1_shared, KV_shared_1_l, acc_o_l) if k < loop_range - 1: - - T.copy( - KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, - cur_kv_head, :h_dim], KV_shared_1_l) + T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :h_dim], KV_shared_1_l) T.barrier_arrive(kv_shared_1_l_is_ready) - T.copy( - K_pe[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head, :], - K_pe_shared_1) + T.copy(K_pe[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :], K_pe_shared_1) T.barrier_arrive(kv_shared_1_pe_is_ready) T.copy(logsum_0, logsum) @@ -221,8 +208,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ for i, j in T.Parallel(block_H, h_dim): acc_o_l[i, j] /= logsum[i] T.copy(acc_o_l, O_shared_l) - T.copy(O_shared_l, Output[bid, - hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim]) + T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim]) else: T.copy(Q_pe_shared, Q_pe_local_1) @@ -237,16 +223,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.barrier_arrive(kv_shared_0_pe_is_ready) for k in T.serial(loop_range): - # Step 2. T.barrier_wait(kv_shared_1_l_is_ready, k % 2) - T.gemm( - Q_shared_l, - KV_shared_1_l, - acc_s_1, - transpose_B=True, - clear_accum=True, - wg_wait=-1) + T.gemm(Q_shared_l, KV_shared_1_l, acc_s_1, transpose_B=True, clear_accum=True, wg_wait=-1) T.barrier_wait(kv_shared_1_r_is_ready, k % 2) T.gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True, wg_wait=-1) @@ -265,8 +244,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.copy(scores_max_1, scores_max) for i in T.Parallel(block_H): - scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale - - scores_max[i] * scale) + scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale - scores_max[i] * scale) # Step 8. 
for i, j in T.Parallel(block_H, block_N): @@ -279,8 +257,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ acc_o_r[i, j] = acc_o_r[i, j] * (scores_scale_0[i] * scores_scale_1[i]) for i in T.Parallel(block_H): - logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[ - i] + scores_sum_1[i] + logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[i] + scores_sum_1[i] T.barrier_arrive(scale_1_ready_barrier) @@ -291,9 +268,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.barrier_arrive(s_shared_ready_barrier) if k < loop_range - 1: - T.copy( - KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head, - h_dim:], KV_shared_1_r) + T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, h_dim:], KV_shared_1_r) T.barrier_arrive(kv_shared_1_r_is_ready) T.barrier_wait(p0_1_1_ready_barrier, k % 2) @@ -301,15 +276,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.gemm(SP0_shared, KV_shared_0_r, acc_o_r) if k < loop_range - 1: - - T.copy( - KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head, - h_dim:], KV_shared_0_r) + T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, h_dim:], KV_shared_0_r) T.barrier_arrive(kv_shared_0_r_is_ready) - T.copy( - K_pe[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head, :], - K_pe_shared_0) + T.copy(K_pe[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :], K_pe_shared_0) T.barrier_arrive(kv_shared_0_pe_is_ready) T.barrier_wait(lse_0_ready_barrier, 0) @@ -319,18 +289,17 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ for i, j in T.Parallel(block_H, h_dim): acc_o_r[i, j] /= logsum[i] T.copy(acc_o_r, O_shared_r) - T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, - h_dim:]) + T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:]) @T.prim_func def main_no_split( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), ): flash_attn(Q, Q_pe, KV, K_pe, Output) @@ -352,31 +321,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 - q = rearrange( - q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] - q_pe = rearrange( - q_pe, 'b (h g) d -> b g h d', - g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] - kv = rearrange(kv, 'b n h d -> b 
h n d') # [batch_size, groups, seqlen_kv, dim] + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] - k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim] + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] query = torch.concat([q, q_pe], dim=-1) key = torch.concat([kv, k_pe], dim=-1) - scores = einsum( - query, key, - 'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv] + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] - attention = F.softmax( - scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] - out = einsum(attention, kv, - 'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] - out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim] + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] return out @@ -399,12 +361,12 @@ def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--heads', type=int, default=128, help='q heads number') - parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') - parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) diff --git a/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py index b738a4b9c6fb814f7f2a2a3b058f93b5f83eb93f..5d438b5dedbdfa7f5ec34b741c6f143b730f57ee 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py @@ -8,7 +8,6 @@ tilelang.disable_cache() # @tilelang.jit @tilelang.jit(out_idx=[2]) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - num_stages = 2 mbarrier_list = [128, 128] * num_stages @@ -32,19 +31,13 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo for ko in range(T.ceildiv(K, block_K)): with T.ws(1): - T.mbarrier_wait_parity( - mbarrier=ko % num_stages + num_stages, - parity=((ko // num_stages) % num_stages) ^ 1) - T.copy(A[by * block_M:(by + 1) * block_M, ko * block_K:(ko + 1) * block_K], - A_shared[ko % num_stages, :, :]) - T.copy(B[ko * block_K:(ko + 1) 
* block_K, bx * block_N:(bx + 1) * block_N], - B_shared[ko % num_stages, :, :]) + T.mbarrier_wait_parity(mbarrier=ko % num_stages + num_stages, parity=((ko // num_stages) % num_stages) ^ 1) + T.copy(A[by * block_M : (by + 1) * block_M, ko * block_K : (ko + 1) * block_K], A_shared[ko % num_stages, :, :]) + T.copy(B[ko * block_K : (ko + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[ko % num_stages, :, :]) T.mbarrier_arrive(mbarrier=ko % num_stages) with T.ws(0): - T.mbarrier_wait_parity( - mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages) - T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :], - C_local) + T.mbarrier_wait_parity(mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages) + T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :], C_local) T.mbarrier_arrive(mbarrier=ko % num_stages + num_stages) with T.ws(0): diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py index 9ba9f6816048cbbafeb275842cb197828732d24f..03ddf81228338497025c964b99c7d2edbf5da835 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py @@ -5,20 +5,12 @@ import tilelang.language as T # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit @tilelang.jit(out_idx=[2]) -def matmul_warp_specialize_copy_0_gemm_1(M, - N, - K, - block_M, - block_N, - block_K, - dtype="float16", - accum_dtype="float"): - +def matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py index faaf48c6485fce0c1e04cf4e38de1deb2a244877..63aed2bed8d76fc0c77d5d7d604d996ec4e83ee0 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py @@ -5,20 +5,12 @@ import tilelang.language as T # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit @tilelang.jit(out_idx=[2]) -def matmul_warp_specialize_copy_1_gemm_0(M, - N, - K, - block_M, - block_N, - block_K, - dtype="float16", - accum_dtype="float"): - +def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py index c91274540f308d5b58badde74b14d0178d31df04..f24d76a22f4a4f5637d5752ba157edd1fabe6468 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py +++ 
b/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py @@ -5,26 +5,20 @@ import tilelang.language as T # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit @tilelang.jit( - out_idx=[2], pass_configs={ + out_idx=[2], + pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - }) -def matmul_warp_specialize_copy_1_gemm_0(M, - N, - K, - block_M, - block_N, - block_K, - dtype="float16", - accum_dtype="float"): - + }, +) +def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): warp_group_num = 2 threads = 128 * warp_group_num @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): diff --git a/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py b/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py index 3b1d867198b673916b2b67cd0e9a7508d6ad369c..f3f8a665becd4b715c25ac1869d9204f5745d218 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py @@ -6,7 +6,6 @@ import tilelang.language as T # @tilelang.jit @tilelang.jit(out_idx=[2]) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( A: T.Tensor[(M, K), dtype], diff --git a/format.sh b/format.sh index e820b5886eb1d7f8182b734a3d05f0bb24578081..3cc4390dbe2a3b33e928c1c52f546dbf2cdc21bf 100755 --- a/format.sh +++ b/format.sh @@ -9,7 +9,7 @@ # bash format.sh --all # # -# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase. +# Ruff (format) + Clang formatter (if installed). This script formats all changed files from the last mergebase. # You are encouraged to run this locally before pushing changes for review. 
# Cause the script to exit if a single command fails diff --git a/maint/gemm_v2/correctness_evaluation.py b/maint/gemm_v2/correctness_evaluation.py index 33a58129600848439c627552fcbc8212bbdbd30d..e7a822544e295b12ff726ada17adcb92f25239a9 100644 --- a/maint/gemm_v2/correctness_evaluation.py +++ b/maint/gemm_v2/correctness_evaluation.py @@ -28,9 +28,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") @@ -66,7 +66,8 @@ def _compile_and_check( tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, # tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False, - }) + }, + ) print(kernel.get_kernel_source()) @@ -151,9 +152,9 @@ def matmul_rs( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") @@ -238,9 +239,9 @@ def matmul_sr( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") @@ -326,9 +327,9 @@ def matmul_rr( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") @@ -394,37 +395,48 @@ M_VALUES = [64, 128, 256] N_VALUES = [16, 32, 64, 128, 256, 512] K_VALUES = [16, 32, 64, 128] K_VALUES_8Bit = [32, 64, 128] -FALSE_TRUE_CASES = ([ - pytest.param( - k, - "float16", - "float16", - "float16", - id=f"K{k}-float16-float16-float16", - ) for k in K_VALUES -] + [pytest.param( - k, - "int8", - "int32", - "int32", - id="K32-int8-int32-int32", -) for k in K_VALUES_8Bit] + [ - pytest.param( - k, - "float8_e5m2", - "float32", - "float32", - id="K32-float8_e5m2-float32-float32", - ) for k in K_VALUES_8Bit -] + [ - pytest.param( - k, - "float8_e4m3", - "float32", - "float32", - id="K32-float8_e4m3-float32-float32", - ) for k in K_VALUES_8Bit -]) +FALSE_TRUE_CASES = ( + [ + pytest.param( + k, + "float16", + "float16", + "float16", + id=f"K{k}-float16-float16-float16", + ) + for k in K_VALUES + ] + + [ + pytest.param( + k, + "int8", + "int32", + "int32", + id="K32-int8-int32-int32", + ) + for k in K_VALUES_8Bit + ] + + [ + pytest.param( + k, + "float8_e5m2", + "float32", + "float32", + id="K32-float8_e5m2-float32-float32", + ) + for k in K_VALUES_8Bit + ] + + [ + pytest.param( + k, + "float8_e4m3", + "float32", + "float32", + id="K32-float8_e4m3-float32-float32", + ) + for k in K_VALUES_8Bit + ] +) def 
_ensure_torch_dtypes(*dtype_names): diff --git a/maint/gemm_v2/correctness_evaluation_sm70.py b/maint/gemm_v2/correctness_evaluation_sm70.py index 128f4abce56aa4f9e76270702eec3c43ee181210..3b4503d4e08ee0c09c56840a31158e2d2192ac7b 100644 --- a/maint/gemm_v2/correctness_evaluation_sm70.py +++ b/maint/gemm_v2/correctness_evaluation_sm70.py @@ -28,9 +28,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") @@ -67,7 +67,8 @@ def _compile_and_check( tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, # tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False, - }) + }, + ) print(kernel.get_kernel_source()) @@ -150,9 +151,9 @@ def matmul_rs( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") @@ -213,14 +214,15 @@ def run_gemm_rs( M_VALUES = [64, 128] N_VALUES = [32, 64, 128] K_VALUES = [16, 32, 64] -FALSE_TRUE_CASES = ([ +FALSE_TRUE_CASES = [ pytest.param( k, "float16", "float16", "float16", id=f"K{k}-float16-float16-float16", - ) for k in K_VALUES + ) + for k in K_VALUES ] + [ pytest.param( k, @@ -228,8 +230,9 @@ FALSE_TRUE_CASES = ([ "float16", "float32", id=f"K{k}-float16-float16-float32", - ) for k in K_VALUES -]) + ) + for k in K_VALUES +] def _ensure_torch_dtypes(*dtype_names): diff --git a/maint/gemm_v2/correctness_evaluation_tcgen05.py b/maint/gemm_v2/correctness_evaluation_tcgen05.py index 1831ac8aa1fee2d304bf79c4ff0886f1610d5f33..4ce8691ec8a96f8f64b5421a6cea720f71a9d033 100644 --- a/maint/gemm_v2/correctness_evaluation_tcgen05.py +++ b/maint/gemm_v2/correctness_evaluation_tcgen05.py @@ -27,9 +27,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -42,15 +42,7 @@ def matmul( for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_shared, - B_shared, - C_tmem, - trans_A, - trans_B, - mbar=mbar, - wg_wait=-1, - clear_accum=k == 0) + T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0) T.mbarrier_wait_parity(mbar, k % 2) T.copy(C_tmem, C_local) @@ -74,7 +66,8 @@ def _compile_and_check( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) print(kernel.get_kernel_source()) @@ -138,14 +131,15 @@ M_VALUES = [32, 64, 128, 256] N_VALUES = [64, 128, 256, 512] K_VALUES = [16, 32, 64, 128] K_VALUES_8Bit = [32, 64, 128] -FALSE_TRUE_CASES = ([ +FALSE_TRUE_CASES = [ pytest.param( k, "float16", 
"float32", "float32", id=f"K{k}-float16-float-float", - ) for k in K_VALUES + ) + for k in K_VALUES ] + [ pytest.param( k, @@ -153,8 +147,9 @@ FALSE_TRUE_CASES = ([ "float32", "float32", id="K32-float8_e5m2-float32-float32", - ) for k in K_VALUES_8Bit -]) + ) + for k in K_VALUES_8Bit +] TRANS_CASES = [ pytest.param(False, True, id="nt"), diff --git a/maint/gemm_v2/latency.py b/maint/gemm_v2/latency.py index 07a502017da13161c8dc98a2fe7973f0f979c1ef..4dcb7cf9aaea88a7879c2e0894dfd2b068799aa7 100644 --- a/maint/gemm_v2/latency.py +++ b/maint/gemm_v2/latency.py @@ -14,12 +14,11 @@ use_v2 = args.use_v2 # if not specified, it will be inferred from the input tensors during compile time @tilelang.jit def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def matmul_relu_kernel( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): diff --git a/maint/gemm_v2/latency_gemm.py b/maint/gemm_v2/latency_gemm.py index 13392dec7f2d1510f6696c027fbd9ab554d1a7e4..a66167d4b24576fa7a16a2162f4ec727263b344f 100644 --- a/maint/gemm_v2/latency_gemm.py +++ b/maint/gemm_v2/latency_gemm.py @@ -14,12 +14,11 @@ use_v2 = args.use_v2 # if not specified, it will be inferred from the input tensors during compile time @tilelang.jit def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def matmul_relu_kernel( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): diff --git a/maint/gemm_v2/latency_mha_fwd_bhsd.py b/maint/gemm_v2/latency_mha_fwd_bhsd.py index 4126bb9d3c84e2141309a8d0f3efea533a9515c9..3fd560012f59fd34c2b18f10e78614df5ea5540d 100644 --- a/maint/gemm_v2/latency_mha_fwd_bhsd.py +++ b/maint/gemm_v2/latency_mha_fwd_bhsd.py @@ -8,13 +8,13 @@ import argparse from functools import partial parser = argparse.ArgumentParser() -parser.add_argument('--batch', type=int, default=128, help='batch size') -parser.add_argument('--heads', type=int, default=16, help='heads') -parser.add_argument('--seq_q', type=int, default=1024, help='query sequence length') -parser.add_argument('--seq_kv', type=int, default=1024, help='key/value sequence length') -parser.add_argument('--dim', type=int, default=256, help='dim') -parser.add_argument('--is_causal', action='store_true', help='causal') -parser.add_argument('--tune', action='store_true', help='tune configs') +parser.add_argument("--batch", type=int, default=128, help="batch size") +parser.add_argument("--heads", type=int, default=16, help="heads") +parser.add_argument("--seq_q", type=int, default=1024, help="query sequence length") +parser.add_argument("--seq_kv", type=int, default=1024, help="key/value sequence length") +parser.add_argument("--dim", type=int, default=256, help="dim") +parser.add_argument("--is_causal", action="store_true", help="causal") +parser.add_argument("--tune", action="store_true", help="tune configs") parser.add_argument("--use_v2", action="store_true") args = parser.parse_args() @@ -29,20 +29,13 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit( - 
out_idx=[3], pass_configs={ + out_idx=[3], + pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) -def flashattn(batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=0, - threads=128): - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + }, +) +def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] dtype = "float16" @@ -62,7 +55,7 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): q_idx = bx * block_M + i + past_len @@ -85,7 +78,7 @@ def flashattn(batch, by: T.int32, bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) # T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) if use_v2: T.gemm_v2(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @@ -94,13 +87,13 @@ def flashattn(batch, @T.macro def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -125,18 +118,18 @@ def flashattn(batch, @T.macro def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -152,43 +145,42 @@ def flashattn(batch, scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = ( - T.min( - T.ceildiv(seq_kv, block_N), T.ceildiv( - (bx + 1) * block_M + - past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + if 
is_causal + else T.ceildiv(seq_kv, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, - logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main def ref_program(Q, K, V, is_causal): dim = Q.size(-1) - scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) if is_causal: seq_q = Q.size(2) seq_kv = K.size(2) mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float('-inf')) + scores = scores.masked_fill(mask == 0, float("-inf")) attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) return output @@ -206,18 +198,8 @@ def main( if is_causal: total_flops *= 0.5 - if (not tune): - kernel = flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=64, - block_N=64, - num_stages=0, - threads=128) + if not tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128) print(kernel.get_kernel_source()) ref_program_processed = partial(ref_program, is_causal=is_causal) diff --git a/maint/host_checks/01_num_args_mismatch.py b/maint/host_checks/01_num_args_mismatch.py index 8ba366463c25ff66a48a940b2cb16ba48770820c..9528652eea985b6c1a57c80897329cd3c1eafef4 100644 --- a/maint/host_checks/01_num_args_mismatch.py +++ b/maint/host_checks/01_num_args_mismatch.py @@ -3,6 +3,7 @@ Note: The adapter-level wrapper expects only inputs (A, B) because C is marked as output. Calling with the wrong number of inputs raises a ValueError before host entry. """ + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/02_pointer_type_error.py b/maint/host_checks/02_pointer_type_error.py index fd3585405c38dffc67b9829ebe923d88f50b6002..188a4f8cc02254a47665f82b538a0210fa51d768 100644 --- a/maint/host_checks/02_pointer_type_error.py +++ b/maint/host_checks/02_pointer_type_error.py @@ -3,6 +3,7 @@ We pass an integer for A; wrapper forwards it to the host where a pointer is expected. Expected: error like "Expect buffer A_handle to be pointer or tensor" (exact name depends on kernel param). """ + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/03_ndim_mismatch.py b/maint/host_checks/03_ndim_mismatch.py index 994ce23e842e8eac4057ec05aa6a2c14848760b0..76637e8deda8a52165209c026661e45a2cf6da75 100644 --- a/maint/host_checks/03_ndim_mismatch.py +++ b/maint/host_checks/03_ndim_mismatch.py @@ -1,5 +1,5 @@ -"""Reproduce: ndim (rank) mismatch for A. 
-""" +"""Reproduce: ndim (rank) mismatch for A.""" + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/04_dtype_mismatch.py b/maint/host_checks/04_dtype_mismatch.py index 6e6a0503e9f9de6474b72c615fc1ae4141d3bf61..f3554c1d6ace76971b1b787d8febfb7ce286dc09 100644 --- a/maint/host_checks/04_dtype_mismatch.py +++ b/maint/host_checks/04_dtype_mismatch.py @@ -1,5 +1,5 @@ -"""Reproduce: dtype mismatch for A (float32 vs expected float16). -""" +"""Reproduce: dtype mismatch for A (float32 vs expected float16).""" + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/05_shape_mismatch.py b/maint/host_checks/05_shape_mismatch.py index 8b41ae36a209982b260b9d86f4ba064a4c3ce707..a48248176501d66c2a46207da7357d22b8519f7c 100644 --- a/maint/host_checks/05_shape_mismatch.py +++ b/maint/host_checks/05_shape_mismatch.py @@ -1,5 +1,5 @@ -"""Reproduce: shape constant/symbol mismatch on A. -""" +"""Reproduce: shape constant/symbol mismatch on A.""" + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/06_strides_mismatch.py b/maint/host_checks/06_strides_mismatch.py index 477d200bcbda324ec638c6712d277419764a70de..7e523cd64ee2f9e2c762f8813901db4b50938a74 100644 --- a/maint/host_checks/06_strides_mismatch.py +++ b/maint/host_checks/06_strides_mismatch.py @@ -1,5 +1,5 @@ -"""Reproduce: strides check failure (non-contiguous A via transpose). -""" +"""Reproduce: strides check failure (non-contiguous A via transpose).""" + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/07_device_type_mismatch.py b/maint/host_checks/07_device_type_mismatch.py index 67cb7718c15c10380012cdef3c32312146b13c12..af8e5efd5dfcf9514de445aec8eb8bedd676f1f0 100644 --- a/maint/host_checks/07_device_type_mismatch.py +++ b/maint/host_checks/07_device_type_mismatch.py @@ -1,5 +1,5 @@ -"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel. -""" +"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel.""" + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/08_device_id_mismatch.py b/maint/host_checks/08_device_id_mismatch.py index 6491096615f9b715d3db676dfd6a7f4aef457439..280aca1570d9200aaca9f21260c6270c760e8c5b 100644 --- a/maint/host_checks/08_device_id_mismatch.py +++ b/maint/host_checks/08_device_id_mismatch.py @@ -1,5 +1,5 @@ -"""Reproduce: device_id mismatch (requires >=2 CUDA devices). -""" +"""Reproduce: device_id mismatch (requires >=2 CUDA devices).""" + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/09_null_data_pointer.py b/maint/host_checks/09_null_data_pointer.py index 00bac67dd089906fd45fbbcecb515dbcfec005d0..09f5de1aff04438b3f24113c06979c0bd1c61dd5 100644 --- a/maint/host_checks/09_null_data_pointer.py +++ b/maint/host_checks/09_null_data_pointer.py @@ -7,6 +7,7 @@ or a host-side non-NULL pointer check. Note: Constructing a true DLTensor with NULL data in PyTorch is not typical; this script demonstrates passing None, which still reproduces the intended class of failure. """ + import torch from common import build_matmul_kernel diff --git a/maint/host_checks/10_scalar_type_mismatch.py b/maint/host_checks/10_scalar_type_mismatch.py index f1fcba274a14c2d41cfe91f2e4d0307bec3ea6f4..4f2c90b8d1df8a51eb0bdffa94e2223ecbcf47fb 100644 --- a/maint/host_checks/10_scalar_type_mismatch.py +++ b/maint/host_checks/10_scalar_type_mismatch.py @@ -1,5 +1,5 @@ -"""Reproduce: scalar parameter type mismatch (int/bool). 
-""" +"""Reproduce: scalar parameter type mismatch (int/bool).""" + from common import build_scalar_check_kernel diff --git a/maint/host_checks/common.py b/maint/host_checks/common.py index cdafc8bd8215b783d90130d33510d4dd8fd6da39..649527d4a637fe9e05517dcbd2f15e5eacf1436e 100644 --- a/maint/host_checks/common.py +++ b/maint/host_checks/common.py @@ -3,20 +3,12 @@ import tilelang.language as T import torch -def make_matmul_prim(M, - N, - K, - block_M=128, - block_N=128, - block_K=32, - dtype="float16", - accum_dtype="float"): - +def make_matmul_prim(M, N, K, block_M=128, block_N=128, block_K=32, dtype="float16", accum_dtype="float"): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -42,7 +34,6 @@ def build_matmul_kernel(M=1024, N=1024, K=1024, target="cuda"): def build_scalar_check_kernel(target="cuda"): - @T.prim_func def scalar_check(x: T.int32, flag: T.bool()): T.evaluate(0) diff --git a/maint/precision/compare_ops.py b/maint/precision/compare_ops.py index 7d0d67db734d033fb4919b505dff13728aa96826..985c3bd9652e4dc3866e88a5e1cb698000d232f0 100644 --- a/maint/precision/compare_ops.py +++ b/maint/precision/compare_ops.py @@ -37,7 +37,7 @@ OP_NAMES: Dict[int, str] = { 6: "sqrt", 7: "tanh", 8: "rsqrt", - 9: "inv_sqrt" + 9: "inv_sqrt", } # Block sizes for kernels @@ -49,8 +49,7 @@ TILELANG_THREADS = 128 def parse_arguments() -> argparse.Namespace: """Parse command line arguments.""" - parser = argparse.ArgumentParser( - description="Precision comparison tool for various CUDA implementations") + parser = argparse.ArgumentParser(description="Precision comparison tool for various CUDA implementations") parser.add_argument("--n", type=int, default=1000000, help="Number of elements to test") parser.add_argument("--low", type=float, default=-4.0, help="Lower bound for random values") parser.add_argument("--high", type=float, default=4.0, help="Upper bound for random values") @@ -67,7 +66,7 @@ def initialize_cuda() -> torch.nn.Module: return load( name="cuda_ops", sources=["cuda_ops.cu"], - extra_cuda_cflags=[] # No fast_math flags + extra_cuda_cflags=[], # No fast_math flags ) @@ -149,8 +148,7 @@ def triton_unary_kernel(x_ptr, out_ptr, n_elements, op_id: tl.constexpr, BLOCK_S @triton.jit -def triton_libdevice_unary_kernel(x_ptr, out_ptr, n_elements, op_id: tl.constexpr, - BLOCK_SIZE: tl.constexpr): +def triton_libdevice_unary_kernel(x_ptr, out_ptr, n_elements, op_id: tl.constexpr, BLOCK_SIZE: tl.constexpr): """LibDevice Triton kernel for unary operations.""" pid = tl.program_id(0) block_start = pid * BLOCK_SIZE @@ -188,13 +186,10 @@ def make_tilelang_unary_kernel(M: int, N: int, op_id: int, use_fastmath: bool = @T.prim_func def tilelang_unary_kernel( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), "float32"), + B: T.Tensor((M, N), "float32"), ): - with T.Kernel( - T.ceildiv(N, TILELANG_BLOCK_N), - T.ceildiv(M, TILELANG_BLOCK_M), - threads=TILELANG_THREADS) as (bx, by): + with T.Kernel(T.ceildiv(N, TILELANG_BLOCK_N), T.ceildiv(M, TILELANG_BLOCK_M), threads=TILELANG_THREADS) as (bx, by): for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N): row = by * TILELANG_BLOCK_M + i col = bx * TILELANG_BLOCK_N + j @@ -229,14 +224,11 @@ def make_tilelang_binary_kernel(M: 
int, N: int): @T.prim_func def tilelang_binary_kernel( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), - C: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), "float32"), + B: T.Tensor((M, N), "float32"), + C: T.Tensor((M, N), "float32"), ): - with T.Kernel( - T.ceildiv(N, TILELANG_BLOCK_N), - T.ceildiv(M, TILELANG_BLOCK_M), - threads=TILELANG_THREADS) as (bx, by): + with T.Kernel(T.ceildiv(N, TILELANG_BLOCK_N), T.ceildiv(M, TILELANG_BLOCK_M), threads=TILELANG_THREADS) as (bx, by): for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N): row = by * TILELANG_BLOCK_M + i col = bx * TILELANG_BLOCK_N + j @@ -247,10 +239,7 @@ def make_tilelang_binary_kernel(M: int, N: int): return tilelang_binary_kernel -def tilelang_op(x: torch.Tensor, - op_id: int, - y: Optional[torch.Tensor] = None, - use_fastmath: bool = False) -> torch.Tensor: +def tilelang_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None, use_fastmath: bool = False) -> torch.Tensor: """TileLang operation interface.""" assert x.is_cuda @@ -272,7 +261,8 @@ def tilelang_op(x: torch.Tensor, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: use_fastmath, - }) + }, + ) out = kernel(x, y) else: # Unary operation kernel_func = make_tilelang_unary_kernel(M, N, op_id, use_fastmath) @@ -282,7 +272,8 @@ def tilelang_op(x: torch.Tensor, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: use_fastmath, - }) + }, + ) out = kernel(x) # Restore original shape @@ -293,7 +284,7 @@ def triton_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) -> """Standard Triton operation interface.""" assert x.is_cuda out = torch.empty_like(x) - grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],) + grid = lambda meta: ((x.numel() + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],) if op_id == 0: # Division - binary operation assert y is not None, "Division operation requires second operand" @@ -304,13 +295,11 @@ def triton_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) -> return out -def triton_libdevice_op(x: torch.Tensor, - op_id: int, - y: Optional[torch.Tensor] = None) -> torch.Tensor: +def triton_libdevice_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) -> torch.Tensor: """LibDevice Triton operation interface.""" assert x.is_cuda out = torch.empty_like(x) - grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],) + grid = lambda meta: ((x.numel() + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],) if op_id == 0: # Division - binary operation assert y is not None, "Division operation requires second operand" @@ -321,9 +310,7 @@ def triton_libdevice_op(x: torch.Tensor, return out -def get_pytorch_reference(x: torch.Tensor, - op_id: int, - y: Optional[torch.Tensor] = None) -> torch.Tensor: +def get_pytorch_reference(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) -> torch.Tensor: """Get PyTorch reference implementation for the given operation.""" if op_id == 0: assert y is not None, "Division requires second operand" @@ -362,8 +349,10 @@ def summarize_error(tag: str, output: Optional[torch.Tensor], reference: torch.T abs_err = (output_double - reference_double).abs() rel_err = abs_err / (reference_double.abs().clamp_min(1e-30)) - print(f"{tag:<32} max abs: {abs_err.max():.3e}, mean abs: {abs_err.mean():.3e}, " - f"max rel: {rel_err.max():.3e}, mean rel: {rel_err.mean():.3e}") + print( + f"{tag:<32} max abs: {abs_err.max():.3e}, mean abs: {abs_err.mean():.3e}, " + f"max 
rel: {rel_err.max():.3e}, mean rel: {rel_err.mean():.3e}" + ) # Precision comparison function @@ -407,9 +396,7 @@ def compare(op_id: int, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> No results[name] = None # Print comparison header - print( - f"{'Implementation':<32} {'Max Abs Error':<19} {'Mean Abs Error':<20} {'Max Rel Error':<19} {'Mean Rel Error'}" - ) + print(f"{'Implementation':<32} {'Max Abs Error':<19} {'Mean Abs Error':<20} {'Max Rel Error':<19} {'Mean Rel Error'}") print("-" * 90) # Compare all implementations against double precision reference @@ -427,8 +414,7 @@ def compare(op_id: int, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> No summarize_error(tag, output, ref_double) -def generate_test_data(op_id: int, n: int, device: torch.device, low: float, - high: float) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: +def generate_test_data(op_id: int, n: int, device: torch.device, low: float, high: float) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Generate appropriate test data for each operation.""" if op_id == 0: # Division x = torch.empty(n, device=device).uniform_(low, high) @@ -450,9 +436,7 @@ def generate_test_data(op_id: int, n: int, device: torch.device, low: float, def main() -> None: """Main execution function.""" - print( - "Precision comparison between CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang" - ) + print("Precision comparison between CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang") print("=" * 90) for op_id in range(len(OP_NAMES)): diff --git a/maint/scripts/ci_performance.py b/maint/scripts/ci_performance.py index 998e7b650f955af72f35c055feec4e8a43b69f59..8a353c0a999fb8848159ed3b21ec298cb3a47185 100644 --- a/maint/scripts/ci_performance.py +++ b/maint/scripts/ci_performance.py @@ -10,39 +10,32 @@ env["TILELANG_CLEAR_CACHE"] = "1" def parse_output(output): data = {} - for line in output.split('\n'): + for line in output.split("\n"): line = line.strip() - if line.startswith('Latency:'): - match = re.search(r'Latency: ([\d.]+)', line) - data['latency'] = match.group(1) if match else 'N/A' - elif line.startswith('TFlops:'): - match = re.search(r'TFlops: ([\d.]+)', line) - data['best_tflops'] = match.group(1) if match else 'N/A' - elif line.startswith('Config:'): - data['config'] = line.split('Config: ')[-1] - elif line.startswith('Reference TFlops:'): - match = re.search(r'Reference TFlops: ([\d.]+)', line) - data['ref_tflops'] = match.group(1) if match else 'N/A' + if line.startswith("Latency:"): + match = re.search(r"Latency: ([\d.]+)", line) + data["latency"] = match.group(1) if match else "N/A" + elif line.startswith("TFlops:"): + match = re.search(r"TFlops: ([\d.]+)", line) + data["best_tflops"] = match.group(1) if match else "N/A" + elif line.startswith("Config:"): + data["config"] = line.split("Config: ")[-1] + elif line.startswith("Reference TFlops:"): + match = re.search(r"Reference TFlops: ([\d.]+)", line) + data["ref_tflops"] = match.group(1) if match else "N/A" return data -output_v1 = subprocess.run(['./tl/bin/python', './maint/scripts/performance.py'], - capture_output=True, - text=True, - env=env).stdout +output_v1 = subprocess.run(["./tl/bin/python", "./maint/scripts/performance.py"], capture_output=True, text=True, env=env).stdout data_v1 = parse_output(output_v1) -output_v2 = subprocess.run(['./tll/bin/python', './maint/scripts/performance.py'], - capture_output=True, - text=True, - env=env).stdout +output_v2 = subprocess.run(["./tll/bin/python", "./maint/scripts/performance.py"], 
capture_output=True, text=True, env=env).stdout data_v2 = parse_output(output_v2) -table = [[ - "original", data_v1['latency'], data_v1['best_tflops'], data_v1['ref_tflops'], data_v1['config'] -], [ - "current", data_v2['latency'], data_v2['best_tflops'], data_v2['ref_tflops'], data_v2['config'] -]] +table = [ + ["original", data_v1["latency"], data_v1["best_tflops"], data_v1["ref_tflops"], data_v1["config"]], + ["current", data_v2["latency"], data_v2["best_tflops"], data_v2["ref_tflops"], data_v2["config"]], +] headers = ["version", "Best Latency (s)", "Best TFlops", "Reference TFlops", "Best Config"] diff --git a/maint/scripts/performance.py b/maint/scripts/performance.py index 24c4a21e8d13e565e12963ced9c56387ed69ca58..849bcf362cf002036dd8dffbc34e287220c42415 100644 --- a/maint/scripts/performance.py +++ b/maint/scripts/performance.py @@ -8,19 +8,20 @@ def ref_program(A, B): def get_configs(): - configs = [{ - "block_M": 128, - "block_N": 128, - "block_K": 64, - "num_stages": 2, - "thread_num": 256, - "enable_rasteration": True, # keep param name for backward-compat - }] + configs = [ + { + "block_M": 128, + "block_N": 128, + "block_K": 64, + "num_stages": 2, + "thread_num": 256, + "enable_rasteration": True, # keep param name for backward-compat + } + ] return configs def run(M, N, K): - def kernel( block_M=None, block_N=None, @@ -34,12 +35,11 @@ def run(M, N, K): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -60,12 +60,16 @@ def run(M, N, K): return main - autotuner = AutoTuner.from_kernel( - kernel=kernel, configs=get_configs()).set_compile_args( + autotuner = ( + AutoTuner.from_kernel(kernel=kernel, configs=get_configs()) + .set_compile_args( out_idx=[-1], target="auto", - ).set_profile_args( - ref_prog=ref_program,) + ) + .set_profile_args( + ref_prog=ref_program, + ) + ) return autotuner.run(warmup=3, rep=20) diff --git a/pyproject.toml b/pyproject.toml index 22467134598aef3072f60344f105dba0b5d8b9a2..992eba55ce45664425daa1952651aacb86f7b7dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,10 +122,7 @@ tilelang = "tilelang" "tilelang/3rdparty/composable_kernel/include" = "3rdparty/composable_kernel/include" "tilelang/3rdparty/composable_kernel/library" = "3rdparty/composable_kernel/library" -[tool.yapf] -based_on_style = "yapf" -column_limit = 100 -indent_width = 4 + [tool.codespell] ignore-words = "docs/spelling_wordlist.txt" @@ -138,7 +135,7 @@ skip = [ [tool.ruff] target-version = "py39" -line-length = 100 +line-length = 140 output-format = "full" exclude = [ @@ -146,6 +143,14 @@ exclude = [ "examples/deepseek_v32/inference", ] +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" +docstring-code-format = false +docstring-code-line-length = "dynamic" + [tool.ruff.lint.per-file-ignores] # Do not upgrade type hint in testing and examples. # See https://github.com/tile-ai/tilelang/issues/1079 for more information. 
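For contributors who want to apply the Ruff configuration introduced above (the [tool.ruff] / [tool.ruff.format] sections and the 140-character line length in pyproject.toml) outside of the hook machinery, a small local helper could look like the sketch below. This is an illustration only and not part of the patch: the script name is invented, it assumes ruff is installed from requirements-lint.txt, and it only invokes the standard `ruff format` and `ruff check --fix` commands, written in the same subprocess style as maint/scripts/ci_performance.py.

# maint/scripts/format_local.py -- hypothetical helper, not included in this patch.
# Runs Ruff's formatter and then its autofixable lint pass; both commands pick up
# the [tool.ruff] / [tool.ruff.format] settings from pyproject.toml automatically.
import subprocess

for cmd in (["ruff", "format", "."], ["ruff", "check", "--fix", "."]):
    subprocess.run(cmd, check=True)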
diff --git a/requirements-lint.txt b/requirements-lint.txt index e64eee16059b9b0e48f1237666c3cc6cab8c54f6..54f03638b0b058d4be2b43ab3f9d66774eb95500 100644 --- a/requirements-lint.txt +++ b/requirements-lint.txt @@ -4,4 +4,3 @@ clang-format==21.1.2 clang-tidy==21.1.1 codespell[toml]==2.4.1 ruff==0.14.3 -yapf==0.43.0 diff --git a/testing/conftest.py b/testing/conftest.py index 9f49d40a9b50e14e41915811589d0011d3c2c910..4010e0d83ae84c641151d6dd56dbf40ee42e301f 100644 --- a/testing/conftest.py +++ b/testing/conftest.py @@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): "warnings", "error", } - if (sum( - len(terminalreporter.stats.get(k, [])) - for k in known_types.difference({"skipped", "deselected"})) == 0): + if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0: terminalreporter.write_sep( "!", - (f"Error: No tests were collected. " - f"{dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"), + (f"Error: No tests were collected. {dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"), ) pytest.exit("No tests were collected.", returncode=5) diff --git a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py index a01bd45963252f1e0beadab5ecdd6d3c3002bace..4007bebe313aa9254b0eeb3e55b4fcada5a546b8 100644 --- a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py +++ b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py @@ -4,7 +4,8 @@ from tilelang import tvm as tvm import tilelang.language as T from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout from tilelang.intrinsics.mfma_macro_generator import ( - MatrixCoreIntrinEmitter,) + MatrixCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func tilelang.testing.set_random_seed(0) @@ -22,7 +23,6 @@ def tl_matmul( b_transposed=True, k_pack=1, ): - micro_size_x = micro_size_y = micro_size_k = 16 if in_dtype in {"float8_e4m3fnuz", "int8"}: @@ -78,12 +78,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -91,10 +90,12 @@ def tl_matmul( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -102,7 +103,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=0): - # Load A into shared memory if a_transposed: T.copy(A[ko * block_K, by * block_M], A_shared) @@ -116,7 +116,6 @@ def tl_matmul( T.copy(B[ko * block_K, bx * block_N], B_shared) for ki in T.serial(0, (block_K // (k_pack * micro_size_k))): - # Load A into fragment mfma_emitter.ldmatrix_a( A_local, @@ -160,17 +159,8 @@ def tl_matmul( return main -def 
assert_tl_matmul_correctness(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype="float32", - a_transposed=False, - b_transposed=True, - k_pack=1): - matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, - k_pack) +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32", a_transposed=False, b_transposed=True, k_pack=1): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack) print(matmul) kernel = tilelang.compile(matmul) src_code = kernel.get_kernel_source() @@ -201,16 +191,13 @@ def assert_tl_matmul_correctness(M, if a_transposed and b_transposed: # Get Reference Result - ref_c = torch.matmul(A.T.to(torch.float32), - B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) elif a_transposed and not b_transposed: # Get Reference Result - ref_c = torch.matmul(A.Tto(torch.float32), - B.to(torch.float32)).to(getattr(torch, out_dtype)) + ref_c = torch.matmul(A.T.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype)) elif not a_transposed and b_transposed: # Get Reference Result - ref_c = torch.matmul(A.to(torch.float32), - B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) else: # Get Reference Result ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype)) @@ -228,16 +215,13 @@ def test_assert_tl_matmul(): assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32") assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32") assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2) - assert_tl_matmul_correctness( - 128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32") - assert_tl_matmul_correctness( - 128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2) + assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32") + assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2) assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3fnuz", "float16") assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32") assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2) assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False) - assert_tl_matmul_correctness( - 128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False, k_pack=2) + assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False, k_pack=2) if __name__ == "__main__": diff --git a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py index b215f0d45c879413330f0553af239813ce3808b0..393a77b78b63d20dc1b72ad8c471dd831823677f 100644 --- a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py +++ b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py @@ -23,7 +23,6 @@ def tl_matmul( b_preshuffle=False, b_g2l_load=False, ): - micro_size_x = micro_size_y = micro_size_k = 16 if in_dtype in {"float8_e4m3fnuz", "int8"}: @@ -53,18 +52,21 @@ def tl_matmul( A_shape = (K, M) if a_transposed else (M, K) if b_preshuffle: - B_shape = (N // micro_size_y, K // pack_size_k, micro_size_y, - pack_size_k) if b_transposed 
else (K // pack_size_k, N // micro_size_y, - pack_size_k, micro_size_y) + B_shape = ( + (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k) + if b_transposed + else (K // pack_size_k, N // micro_size_y, pack_size_k, micro_size_y) + ) else: B_shape = (N, K) if b_transposed else (K, N) A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K) if b_preshuffle: - B_shared_shape = (block_N // micro_size_y, block_K // pack_size_k, micro_size_y, - pack_size_k) if b_transposed else (block_K // pack_size_k, - block_N // micro_size_y, pack_size_k, - micro_size_y) + B_shared_shape = ( + (block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k) + if b_transposed + else (block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y) + ) else: B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N) @@ -94,21 +96,22 @@ def tl_matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + } + ) num_ko = K // block_K num_ki = block_K // (k_pack * micro_size_k) @@ -119,7 +122,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined(num_ko, num_stages=0): - # Load A into shared memory if a_transposed: T.copy(A[ko * block_K, by * block_M], A_shared) @@ -129,20 +131,13 @@ def tl_matmul( # Load B into shared memory if b_g2l_load is False: if b_transposed: - for j, k, jj, kk in T.Parallel(block_N // micro_size_y, - block_K // pack_size_k, micro_size_y, - pack_size_k): - B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j, - ko * block_K // pack_size_k + k, jj, kk] + for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k): + B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j, ko * block_K // pack_size_k + k, jj, kk] else: - for k, j, kk, jj in T.Parallel(block_K // pack_size_k, - block_N // micro_size_y, pack_size_k, - micro_size_y): - B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k, - bx * block_N // micro_size_y + j, kk, jj] + for k, j, kk, jj in T.Parallel(block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y): + B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k, bx * block_N // micro_size_y + j, kk, jj] for ki in T.serial(0, num_ki): - # Load A S2L mfma_emitter.ldmatrix_a( A_local, @@ -176,10 +171,10 @@ def tl_matmul( def shuffle_weight( - x: torch.Tensor, - layout=(16, 32), - k_pack=1, - is_transpose=False, + x: torch.Tensor, + layout=(16, 32), + k_pack=1, + is_transpose=False, ) -> torch.Tensor: IN, IK = layout BK = IK * k_pack @@ -194,19 +189,20 @@ def shuffle_weight( return x.contiguous() -def assert_tl_matmul_correctness(M, - N, - K, - in_dtype, - out_dtype, - accum_dtype="float32", - a_transposed=False, - b_transposed=True, - k_pack=1, - b_preshuffle=False, - 
b_g2l_load=False): - matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, - k_pack, b_preshuffle, b_g2l_load) +def assert_tl_matmul_correctness( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype="float32", + a_transposed=False, + b_transposed=True, + k_pack=1, + b_preshuffle=False, + b_g2l_load=False, +): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack, b_preshuffle, b_g2l_load) print(matmul) kernel = tilelang.compile(matmul) src_code = kernel.get_kernel_source() @@ -244,16 +240,13 @@ def assert_tl_matmul_correctness(M, if a_transposed and b_transposed: # Get Reference Result - ref_c = torch.matmul(A.T.to(torch.float32), - B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) elif a_transposed and not b_transposed: # Get Reference Result - ref_c = torch.matmul(A.Tto(torch.float32), - B.to(torch.float32)).to(getattr(torch, out_dtype)) + ref_c = torch.matmul(A.T.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype)) elif not a_transposed and b_transposed: # Get Reference Result - ref_c = torch.matmul(A.to(torch.float32), - B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) else: # Get Reference Result ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype)) @@ -266,40 +259,17 @@ def assert_tl_matmul_correctness(M, @tilelang.testing.requires_rocm def test_assert_tl_matmul(): - assert_tl_matmul_correctness( - 256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True) - assert_tl_matmul_correctness( - 256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True) - assert_tl_matmul_correctness( - 256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True) - - assert_tl_matmul_correctness( - 256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True) - assert_tl_matmul_correctness( - 256, - 256, - 512, - "int8", - "int32", - b_transposed=False, - accum_dtype="int32", - k_pack=2, - b_preshuffle=True) + assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True) + assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True) + assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True) + + assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True) + assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2, b_preshuffle=True) assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_preshuffle=True) - assert_tl_matmul_correctness( - 256, 256, 512, "float8_e4m3fnuz", "float32", b_transposed=False, b_preshuffle=True) - assert_tl_matmul_correctness( - 256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_preshuffle=True) - assert_tl_matmul_correctness( - 256, - 256, - 512, - "float8_e4m3fnuz", - "float32", - k_pack=2, - b_transposed=False, - b_preshuffle=True) + assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_transposed=False, b_preshuffle=True) + assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_preshuffle=True) + assert_tl_matmul_correctness(256, 256, 512, 
"float8_e4m3fnuz", "float32", k_pack=2, b_transposed=False, b_preshuffle=True) if __name__ == "__main__": diff --git a/testing/python/amd/test_tilelang_test_amd.py b/testing/python/amd/test_tilelang_test_amd.py index 456a3ae46a32e94a7ed1b6a4a1c3690da3831bbe..0666fd479c45de17fbc0540b890f426dcb5e5838 100644 --- a/testing/python/amd/test_tilelang_test_amd.py +++ b/testing/python/amd/test_tilelang_test_amd.py @@ -27,8 +27,7 @@ def matmul( vec_size = 4 * k_pack @T.prim_func - def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( - (M, N), out_dtype)): + def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) @@ -111,8 +110,7 @@ def test_gemm_bf16f32f32_nt(): run_gemm(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32) run_gemm(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32) run_gemm(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32) - run_gemm( - 1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32, k_pack=2) + run_gemm(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32, k_pack=2) @tilelang.testing.requires_rocm @@ -121,8 +119,7 @@ def test_gemm_bf16bf16f32(): run_gemm(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32) run_gemm(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32) run_gemm(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32) - run_gemm( - 1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32, k_pack=2) + run_gemm(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32, k_pack=2) def matmul_rs( @@ -149,9 +146,9 @@ def matmul_rs( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) diff --git a/testing/python/analysis/test_tilelang_fragment_loop_checker.py b/testing/python/analysis/test_tilelang_fragment_loop_checker.py index df88573f8c88e41d8586d16773213e9847f6c85a..85aa51895808c93be149fb2e7c774aee3541eec0 100644 --- a/testing/python/analysis/test_tilelang_fragment_loop_checker.py +++ b/testing/python/analysis/test_tilelang_fragment_loop_checker.py @@ -5,14 +5,12 @@ import pytest @tilelang.jit -def simple_invalid_loop(dtype: str = "bfloat16", - accum_dtype: str = "float32", - num_threads: int = 128): +def simple_invalid_loop(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): A = T.dynamic("A") @T.prim_func def main( - data: T.Tensor((128, A), dtype), # type: ignore + data: T.Tensor((128, A), dtype), # type: ignore ): with T.Kernel(128, threads=num_threads) as (tid,): data_frag = T.alloc_fragment([128], accum_dtype) @@ -28,14 +26,12 @@ def simple_invalid_loop(dtype: str = "bfloat16", @tilelang.jit -def nested_invalid_loop(dtype: str = "bfloat16", - accum_dtype: str = "float32", - num_threads: int = 128): +def nested_invalid_loop(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): 
A = T.dynamic("A") @T.prim_func def main( - data: T.Tensor((128, A), dtype), # type: ignore + data: T.Tensor((128, A), dtype), # type: ignore ): with T.Kernel(128, threads=num_threads) as (tid,): data_frag = T.alloc_fragment([128], accum_dtype) @@ -52,14 +48,12 @@ def nested_invalid_loop(dtype: str = "bfloat16", @tilelang.jit -def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16", - accum_dtype: str = "float32", - num_threads: int = 128): +def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): A = T.dynamic("A") @T.prim_func def main( - data: T.Tensor((128, A), dtype), # type: ignore + data: T.Tensor((128, A), dtype), # type: ignore ): with T.Kernel(128, threads=num_threads) as (tid,): data_frag = T.alloc_fragment([128], accum_dtype) @@ -75,14 +69,12 @@ def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16", @tilelang.jit -def valid_loop_not_use_loop_var(dtype: str = "bfloat16", - accum_dtype: str = "float32", - num_threads: int = 128): +def valid_loop_not_use_loop_var(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): A = T.dynamic("A") @T.prim_func def main( - data: T.Tensor((128, A), dtype), # type: ignore + data: T.Tensor((128, A), dtype), # type: ignore ): with T.Kernel(128, threads=num_threads) as (tid,): data_frag = T.alloc_fragment([128], accum_dtype) @@ -99,14 +91,12 @@ def valid_loop_not_use_loop_var(dtype: str = "bfloat16", @tilelang.jit -def valid_loop_not_frag(dtype: str = "bfloat16", - accum_dtype: str = "float32", - num_threads: int = 128): +def valid_loop_not_frag(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): A = T.dynamic("A") @T.prim_func def main( - data: T.Tensor((128, A), dtype), # type: ignore + data: T.Tensor((128, A), dtype), # type: ignore ): with T.Kernel(128, threads=num_threads) as (tid,): data_shared = T.alloc_shared([128], accum_dtype) @@ -122,14 +112,12 @@ def valid_loop_not_frag(dtype: str = "bfloat16", @tilelang.jit -def valid_loop_serial(dtype: str = "bfloat16", - accum_dtype: str = "float32", - num_threads: int = 128): +def valid_loop_serial(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): A = T.dynamic("A") @T.prim_func def main( - data: T.Tensor((128, A), dtype), # type: ignore + data: T.Tensor((128, A), dtype), # type: ignore ): with T.Kernel(128, threads=num_threads) as (tid,): data_shared = T.alloc_shared([128], accum_dtype) diff --git a/testing/python/analysis/test_tilelang_nested_loop_checker.py b/testing/python/analysis/test_tilelang_nested_loop_checker.py index d3c2ec20e6bd591faabb59e5d910d72cf5e6b7bc..e282c8e34983e369a34586ccb5c07fccb28ec8aa 100644 --- a/testing/python/analysis/test_tilelang_nested_loop_checker.py +++ b/testing/python/analysis/test_tilelang_nested_loop_checker.py @@ -30,11 +30,10 @@ Rule: @tilelang.jit(out_idx=[1]) def nested_continuous_parallels(length=256, block=16, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.Parallel(length // block): @@ -46,29 +45,26 @@ def nested_continuous_parallels(length=256, block=16, dtype="float32"): @tilelang.jit(out_idx=[1]) def nested_triple_continuous_parallels(length=256, block1=8, block2=2, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: 
T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.Parallel(length // block1 // block2): for j in T.Parallel(block1): for k in T.Parallel(block2): - B[i * block1 * block2 + j * block2 + - k] = A[i * block1 * block2 + j * block2 + k] + 1.0 + B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0 return main @tilelang.jit(out_idx=[1]) def nested_noncontinuous_parallels(length=256, block=16, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.Parallel(length // block): @@ -103,8 +99,9 @@ is OK. """ -def matmul_nested_pipelines(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, - out_dtype, accum_dtype, threads, order, stage, extra_pipeline_repeats): +def matmul_nested_pipelines( + M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, threads, order, stage, extra_pipeline_repeats +): A_shape = (K, M) if trans_A else (M, K) B_shape = (N, K) if trans_B else (K, N) A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) @@ -114,9 +111,9 @@ def matmul_nested_pipelines(M, N, K, block_M, block_N, block_K, trans_A, trans_B @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -180,7 +177,8 @@ def run_gemm_nested_pipelines( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler() def ref_program(A, B): @@ -193,8 +191,8 @@ def run_gemm_nested_pipelines( if in_dtype == "float32": # Convert float32 to tfloat32 because tfloat32 mma cannot truncate # float32 automatically, -0x1000 meas - A = ((A.view(torch.int32) - 0x1000)).view(torch.float32) - B = ((B.view(torch.int32) - 0x1000)).view(torch.float32) + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.__getattribute__(out_dtype)) return C @@ -218,11 +216,10 @@ is OK. 
@tilelang.jit(out_idx=[1]) def nested_continuous_serials(length=256, block=16, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.serial(length // block): @@ -234,11 +231,10 @@ def nested_continuous_serials(length=256, block=16, dtype="float32"): @tilelang.jit(out_idx=[1]) def nested_noncontinuous_serials(length=256, block=16, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.serial(length // block): @@ -277,11 +273,10 @@ Rule: @tilelang.jit(out_idx=[1]) def nested_continuous_sp(length=256, block=16, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.serial(length // block): @@ -293,11 +288,10 @@ def nested_continuous_sp(length=256, block=16, dtype="float32"): @tilelang.jit(out_idx=[1]) def nested_continuous_ps(length=256, block=16, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.Parallel(length // block): @@ -309,36 +303,32 @@ def nested_continuous_ps(length=256, block=16, dtype="float32"): @tilelang.jit(out_idx=[1]) def nested_continuous_psp(length=256, block1=8, block2=2, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.Parallel(length // block1 // block2): for j in T.serial(block1): for k in T.Parallel(block2): - B[i * block1 * block2 + j * block2 + - k] = A[i * block1 * block2 + j * block2 + k] + 1.0 + B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0 return main @tilelang.jit(out_idx=[1]) def nested_continuous_sps(length=256, block1=8, block2=2, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.serial(length // block1 // block2): for j in T.Parallel(block1): for k in T.serial(block2): - B[i * block1 * block2 + j * block2 + - k] = A[i * block1 * block2 + j * block2 + k] + 1.0 + B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0 return main @@ -399,9 +389,9 @@ def matmul_nested_pipa( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -444,9 +434,9 @@ def matmul_nested_papipa( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), 
T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -505,7 +495,8 @@ def run_gemm_mixed_pp( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler() def ref_program(A, B): @@ -514,8 +505,8 @@ def run_gemm_mixed_pp( if in_dtype == "float32": # Convert float32 to tfloat32 because tfloat32 mma cannot truncate # float32 automatically, -0x1000 meas - A = ((A.view(torch.int32) - 0x1000)).view(torch.float32) - B = ((B.view(torch.int32) - 0x1000)).view(torch.float32) + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.__getattribute__(out_dtype)) return C @@ -543,7 +534,8 @@ def run_gemm_mixed_pp( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) def test_mixed_pp(): @@ -576,9 +568,9 @@ def matmul_with_parallel( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -637,7 +629,8 @@ def run_gemm_tiled_op_with_parallel( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler() def ref_program(A, B): @@ -646,8 +639,8 @@ def run_gemm_tiled_op_with_parallel( if in_dtype == "float32": # Convert float32 to tfloat32 because tfloat32 mma cannot truncate # float32 automatically, -0x1000 meas - A = ((A.view(torch.int32) - 0x1000)).view(torch.float32) - B = ((B.view(torch.int32) - 0x1000)).view(torch.float32) + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.__getattribute__(out_dtype)) return C @@ -675,16 +668,16 @@ def run_gemm_tiled_op_with_parallel( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) @tilelang.jit(out_idx=[1]) def tir_op_with_parallel(length=256, block=16, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.Parallel(length // block): @@ -696,11 +689,10 @@ def tir_op_with_parallel(length=256, block=16, dtype="float32"): @tilelang.jit(out_idx=[1]) def customize_op_with_parallel(length=256, block=16, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.Parallel(length // block): diff --git a/testing/python/autotune/test_tilelang_autotune.py b/testing/python/autotune/test_tilelang_autotune.py index 85e2e48077a06425c3a47cf6b4306a6afca56369..3e6a05a2476d19c1fb423de791c68f1fa074fd01 100644 --- a/testing/python/autotune/test_tilelang_autotune.py +++ b/testing/python/autotune/test_tilelang_autotune.py @@ -48,6 +48,7 @@ def 
get_configs(M, N, K, with_roller=False): from tilelang.carver.template import MatmulTemplate from tilelang.carver.arch import CUDA from tilelang.carver.roller.rasterization import NoRasterization + arch = CUDA("cuda") topk = 20 @@ -84,7 +85,6 @@ def get_configs(M, N, K, with_roller=False): for config in configs: print(config) else: - block_M = [64] block_N = [64] block_K = [32] @@ -100,7 +100,8 @@ def get_configs(M, N, K, with_roller=False): num_stages, thread_num, enable_rasterization, - )) + ) + ) configs = [ { @@ -110,7 +111,8 @@ def get_configs(M, N, K, with_roller=False): "num_stages": c[3], "thread_num": c[4], "enable_rasteration": c[5], # keep param name for backward-compat - } for c in _configs + } + for c in _configs ] return configs @@ -190,9 +192,9 @@ def matmul(M, N, K, with_roller): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): """ The compiled TVM function for block-level matrix multiplication. @@ -206,9 +208,7 @@ def matmul(M, N, K, with_roller): """ # Bind x-dimension to block index in N, # y-dimension to block index in M. - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): # Allocate shared memory for A sub-block of shape (block_M, block_K) A_shared = T.alloc_shared((block_M, block_K), dtype) # Allocate shared memory for B sub-block of shape (block_N, block_K) @@ -247,12 +247,16 @@ def matmul(M, N, K, with_roller): return main - autotuner = AutoTuner.from_kernel( - kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args( + autotuner = ( + AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller)) + .set_compile_args( out_idx=[-1], target="auto", - ).set_profile_args( - ref_prog=ref_program,) + ) + .set_profile_args( + ref_prog=ref_program, + ) + ) return autotuner.run(warmup=3, rep=20) diff --git a/testing/python/autotune/test_tilelang_autotune_with_inputs.py b/testing/python/autotune/test_tilelang_autotune_with_inputs.py index 39efce6bf4af6d5696c81ba2e9af34c220fb38e0..8f9a6098ddb5af0f1cb9e7fa7b5427ed30b2968f 100644 --- a/testing/python/autotune/test_tilelang_autotune_with_inputs.py +++ b/testing/python/autotune/test_tilelang_autotune_with_inputs.py @@ -30,38 +30,23 @@ def ref_program(A, B): def get_configs(): - iter_params = dict( - block_M=[64], - block_N=[64], - block_K=[32], - num_stages=[0, 1], - thread_num=[128], - enable_rasterization=[False]) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] - - -@tilelang.autotune(configs=get_configs(),) -@tilelang.jit(out_idx=[-1]) -def matmul(M, - N, - K, - block_M=128, - block_N=128, - block_K=32, - num_stages=0, - thread_num=128, - enable_rasterization=False): + iter_params = dict(block_M=[64], block_N=[64], block_K=[32], num_stages=[0, 1], thread_num=[128], enable_rasterization=[False]) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit(out_idx=[-1]) +def matmul(M, N, K, block_M=128, block_N=128, block_K=32, num_stages=0, thread_num=128, enable_rasterization=False): dtype = "float16" accum_dtype = "float" @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: 
T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): """ The compiled TVM function for block-level matrix multiplication. @@ -76,7 +61,6 @@ def matmul(M, # Bind x-dimension to block index in N, # y-dimension to block index in M. with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - # Allocate shared memory for A sub-block of shape (block_M, block_K) A_shared = T.alloc_shared((block_M, block_K), dtype) # Allocate shared memory for B sub-block of shape (block_N, block_K) diff --git a/testing/python/cache/test_tilelang_cache_matmul.py b/testing/python/cache/test_tilelang_cache_matmul.py index 6e966a88af8c42dcc5a19ebcff7020c336c5707d..f38ed487e18a3e8450ad4aa998651bc69ce4d2d0 100644 --- a/testing/python/cache/test_tilelang_cache_matmul.py +++ b/testing/python/cache/test_tilelang_cache_matmul.py @@ -28,9 +28,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -63,6 +63,7 @@ def run_cache_matmul(): Reference PyTorch matrix multiplication for comparison. """ import torch + C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.half) # Assuming dtype="float16" in matmul return C diff --git a/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py b/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py index 46b17bf03a0114e762db34e863d77cf9904461a9..67d20b89790afc81e2b2b5a71795a408e700cdaf 100644 --- a/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py +++ b/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py @@ -29,9 +29,7 @@ class _cudaDeviceAttrNames: def test_driver_get_device_properties(): prop = get_cuda_device_properties() assert prop is not None, "Failed to get CUDA device properties" - assert isinstance( - prop, - torch.cuda._CudaDeviceProperties), ("Returned object is not of type _CudaDeviceProperties") + assert isinstance(prop, torch.cuda._CudaDeviceProperties), "Returned object is not of type _CudaDeviceProperties" def test_device_get_device_name(): @@ -48,8 +46,7 @@ def test_device_get_shared_memory_per_block(): def test_device_get_persisting_l2_cache_size(): tl_cache_size = get_persisting_l2_cache_max_size() - driver_cache_size = get_device_attribute( - _cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize) + driver_cache_size = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize) assert tl_cache_size == driver_cache_size, "Persisting L2 cache size values do not match" @@ -61,17 +58,14 @@ def test_device_get_num_sms(): def test_device_get_registers_per_block(): tl_regs_per_block = get_registers_per_block() - driver_regs_per_block = get_device_attribute( - _cudaDeviceAttrNames.cudaDevAttrMaxRegistersPerBlock) + driver_regs_per_block = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxRegistersPerBlock) assert tl_regs_per_block == driver_regs_per_block, "Registers per block values do not match" def test_device_get_max_dynamic_shared_size_bytes(): tl_dynamic_smem = get_max_dynamic_shared_size_bytes() - driver_dynamic_smem = get_device_attribute( - 
_cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor) - assert tl_dynamic_smem == driver_dynamic_smem, ( - "Max dynamic shared size bytes values do not match") + driver_dynamic_smem = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor) + assert tl_dynamic_smem == driver_dynamic_smem, "Max dynamic shared size bytes values do not match" if __name__ == "__main__": diff --git a/testing/python/carver/test_tilelang_carver_generate_hints.py b/testing/python/carver/test_tilelang_carver_generate_hints.py index 43cdb27e3ff7158294fb1034cc476588b183f390..313dc857ffd0011441171dc3361712749ca19df6 100644 --- a/testing/python/carver/test_tilelang_carver_generate_hints.py +++ b/testing/python/carver/test_tilelang_carver_generate_hints.py @@ -9,16 +9,13 @@ def run_general_matmul_emit_configs(M, N, K, topk: int = 20): arch = auto_infer_current_arch() def gemm(M, N, K): - A = te.placeholder((M, K), name='A', dtype='float16') - B = te.placeholder((N, K), name='B', dtype='float16') + A = te.placeholder((M, K), name="A", dtype="float16") + B = te.placeholder((N, K), name="B", dtype="float16") # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name='k') + k = te.reduce_axis((0, K), name="k") - C = te.compute( - (M, N), - lambda i, j: te.sum(A[i, k].astype('float16') * B[j, k].astype('float16'), axis=[k]), - name='C') + C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), name="C") return A, B, C @@ -29,8 +26,7 @@ def run_general_matmul_emit_configs(M, N, K, topk: int = 20): tensorized_func, tags = carver.utils.get_tensorized_func_and_tags(func, arch.target) print(tags) - policy = carver.TensorCorePolicy.from_prim_func( - func=tensorized_func, arch=arch, tags=tags, name="matmul_0") + policy = carver.TensorCorePolicy.from_prim_func(func=tensorized_func, arch=arch, tags=tags, name="matmul_0") hints = policy.emit_config(topk=topk) @@ -59,16 +55,13 @@ def run_general_matmul_matmul_emit_configs(M, N, K, topk: int = 20): arch = auto_infer_current_arch() def gemm(M, N, K): - A = te.placeholder((M, K), name='A', dtype='float16') - B = te.placeholder((N, K), name='B', dtype='float16') + A = te.placeholder((M, K), name="A", dtype="float16") + B = te.placeholder((N, K), name="B", dtype="float16") # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name='k') + k = te.reduce_axis((0, K), name="k") - C = te.compute( - (M, N), - lambda i, j: te.sum(A[i, k].astype('float16') * B[j, k].astype('float16'), axis=[k]), - name='C') + C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), name="C") return A, B, C diff --git a/testing/python/carver/test_tilelang_carver_recommend_hints.py b/testing/python/carver/test_tilelang_carver_recommend_hints.py index fee46761f21e84407d28749dcbd71f974c280053..4973c24d9984e377354d39c157d5d5aef0f43990 100644 --- a/testing/python/carver/test_tilelang_carver_recommend_hints.py +++ b/testing/python/carver/test_tilelang_carver_recommend_hints.py @@ -4,10 +4,7 @@ from tilelang.carver.arch import auto_infer_current_arch from typing import List -def run_general_reduction_recommend_hints(structure: str = "SSR", - shape: List[int] = None, - dtype: str = "float16", - topk: int = 20): +def run_general_reduction_recommend_hints(structure: str = "SSR", shape: List[int] = None, dtype: str = "float16", topk: int = 20): arch = auto_infer_current_arch() carve_template = carver.GeneralReductionTemplate( structure=structure, 
@@ -28,9 +25,7 @@ def test_general_reduction_recommend_hints(): run_general_reduction_recommend_hints("SRS", [1024, 1024, 1024], "float16") -def run_elementwise_recommend_hints(shape: List[int] = None, - dtype: str = "float16", - topk: int = 20): +def run_elementwise_recommend_hints(shape: List[int] = None, dtype: str = "float16", topk: int = 20): arch = auto_infer_current_arch() carve_template = carver.ElementwiseTemplate( shape=shape, @@ -81,11 +76,9 @@ def test_matmul_recommend_hints(): run_matmul_recommend_hints(1024, 1024, 1024, "float16", "float32", "float16") -def run_gemv_recommend_hints(N: int = 1024, - K: int = 1024, - in_dtype: str = "float16", - out_dtype: str = "float16", - accum_dtype: str = "float16"): +def run_gemv_recommend_hints( + N: int = 1024, K: int = 1024, in_dtype: str = "float16", out_dtype: str = "float16", accum_dtype: str = "float16" +): arch = auto_infer_current_arch() carve_template = carver.GEMVTemplate( N=N, diff --git a/testing/python/components/test_storage_rewrite_detect_inplace.py b/testing/python/components/test_storage_rewrite_detect_inplace.py index 1d60708fe00e0e482c2dbce7e8643bd2063c49a1..bd0a64d3910067f9d178e994691bc813bb04b522 100644 --- a/testing/python/components/test_storage_rewrite_detect_inplace.py +++ b/testing/python/components/test_storage_rewrite_detect_inplace.py @@ -23,7 +23,8 @@ def _compile_kernel_without_inplace(): @tilelang.jit( pass_configs={ tilelang.PassConfigKey.TL_STORAGE_REWRITE_DETECT_INPLACE: True, - },) + }, +) def _compile_kernel_with_inplace(): num_tokens = T.symbolic("num_tokens") diff --git a/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py b/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py index 499f3346bcca1bc9d927a7b5c13efce83aa1f01e..323f764586c3c5bc483b2131a0f6a358144bb6d3 100644 --- a/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py +++ b/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py @@ -26,9 +26,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -88,7 +88,8 @@ def run_gemm( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: disable_warp_specialized, - }) + }, + ) profiler = kernel.get_profiler() def ref_program(A, B): diff --git a/testing/python/cpu/test_tilelang_cpu_gemm.py b/testing/python/cpu/test_tilelang_cpu_gemm.py index 0129b37314e26aa33891922ee8cc4f3b94331147..4a878f3284f59f56e5f09e55451f02be5702d884 100644 --- a/testing/python/cpu/test_tilelang_cpu_gemm.py +++ b/testing/python/cpu/test_tilelang_cpu_gemm.py @@ -10,9 +10,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo @T.prim_func def matmul( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by): A_local = T.alloc_local((block_M, block_K), dtype) @@ -31,7 +31,6 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo # ) for ko in T.Pipelined(K // block_K, 
num_stages=num_stages): - T.copy(A[by * block_M, ko * block_K], A_local) # Or Copy with Parallel @@ -62,14 +61,13 @@ def test_matmul_codegen(): def test_matmul_compile(): - def matmul_jit_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): # a simple kernel just for jit test @T.prim_func def matmul( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by): A_local = T.alloc_local((block_M, block_K), dtype) diff --git a/testing/python/debug/test_device_assert.py b/testing/python/debug/test_device_assert.py index 1602c30c75a03393b8eb6eae986fbc8f98e7d76f..210b8966d7b4acced3a1adc62725a3e579e14900 100644 --- a/testing/python/debug/test_device_assert.py +++ b/testing/python/debug/test_device_assert.py @@ -7,7 +7,6 @@ import tilelang.language as T # TODO(dyq) It intentionally triggers a device-side assert so we can't include this in CI # Please run manually when you want to verify that device_assert actually traps on GPU. def _manual_device_assert_triggered(): - @T.prim_func def program(): with T.Kernel(threads=128): @@ -20,7 +19,6 @@ def _manual_device_assert_triggered(): def test_device_assert_no_trigger(): - @T.prim_func def program(): with T.Kernel(threads=128): diff --git a/testing/python/debug/test_tilelang_debug_print.py b/testing/python/debug/test_tilelang_debug_print.py index a1aa42edcdddf1ad25e4bf8c94740e5080f0908f..e26296613112b17d6adc3d263a5aaad44385316d 100644 --- a/testing/python/debug/test_tilelang_debug_print.py +++ b/testing/python/debug/test_tilelang_debug_print.py @@ -6,7 +6,6 @@ import tilelang.language as T def debug_print_buffer(M=16, N=16, dtype="float16"): - @T.prim_func def program(Q: T.Tensor((M, N), dtype)): with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz): @@ -19,24 +18,24 @@ def debug_print_buffer(M=16, N=16, dtype="float16"): def test_debug_print_buffer(): - debug_print_buffer(dtype='bool') - debug_print_buffer(dtype='int8') - debug_print_buffer(dtype='int16') - debug_print_buffer(dtype='int32') - debug_print_buffer(dtype='int64') - debug_print_buffer(dtype='uint8') - debug_print_buffer(dtype='uint16') - debug_print_buffer(dtype='uint32') - debug_print_buffer(dtype='uint64') - debug_print_buffer(dtype='float16') - debug_print_buffer(dtype='float32') - debug_print_buffer(dtype='float64') - debug_print_buffer(dtype='bfloat16') - debug_print_buffer(dtype='float8_e4m3') - debug_print_buffer(dtype='float8_e4m3fn') - debug_print_buffer(dtype='float8_e4m3fnuz') - debug_print_buffer(dtype='float8_e5m2') - debug_print_buffer(dtype='float8_e5m2fnuz') + debug_print_buffer(dtype="bool") + debug_print_buffer(dtype="int8") + debug_print_buffer(dtype="int16") + debug_print_buffer(dtype="int32") + debug_print_buffer(dtype="int64") + debug_print_buffer(dtype="uint8") + debug_print_buffer(dtype="uint16") + debug_print_buffer(dtype="uint32") + debug_print_buffer(dtype="uint64") + debug_print_buffer(dtype="float16") + debug_print_buffer(dtype="float32") + debug_print_buffer(dtype="float64") + debug_print_buffer(dtype="bfloat16") + debug_print_buffer(dtype="float8_e4m3") + debug_print_buffer(dtype="float8_e4m3fn") + debug_print_buffer(dtype="float8_e4m3fnuz") + debug_print_buffer(dtype="float8_e5m2") + debug_print_buffer(dtype="float8_e5m2fnuz") def debug_print_buffer_conditional(M=16, N=16): diff --git 
a/testing/python/dynamic/test_tilelang_dynamic_symbolic.py b/testing/python/dynamic/test_tilelang_dynamic_symbolic.py index 4b9dff7114157a56326ed16e007889127e69910a..8e50a27592cba85146c073f16b6e5dc5fa6067cd 100644 --- a/testing/python/dynamic/test_tilelang_dynamic_symbolic.py +++ b/testing/python/dynamic/test_tilelang_dynamic_symbolic.py @@ -5,7 +5,7 @@ import tilelang.testing from tvm import DataType import tilelang.language as T from tilelang.intrinsics.utils import get_swizzle_layout -from tilelang.intrinsics.mma_macro_generator import (TensorCoreIntrinEmitter) +from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter tilelang.testing.set_random_seed(0) @@ -96,12 +96,11 @@ def tl_matmul_macro( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -109,10 +108,12 @@ def tl_matmul_macro( B_local = T.alloc_local((warp_cols * local_size), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -120,7 +121,6 @@ def tl_matmul_macro( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -130,7 +130,6 @@ def tl_matmul_macro( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -207,8 +206,7 @@ def tl_matmul_block( B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) @T.prim_func - def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( - (M, N), out_dtype)): + def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) @@ -306,8 +304,7 @@ def tl_matmul_block_all_dynamic( B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) @T.prim_func - def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( - (M, N), out_dtype)): + def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) @@ -417,7 +414,7 @@ def assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( ) pass_configs = { tilelang.PassConfigKey.TL_DISABLE_DYNAMIC_TAIL_SPLIT: dynamic_alignment != 0, - tilelang.PassConfigKey.TL_DYNAMIC_ALIGNMENT: dynamic_alignment + tilelang.PassConfigKey.TL_DYNAMIC_ALIGNMENT: 
dynamic_alignment, } if M % 64 == 0 or N % 64 == 0 or K % 64 != 0: # workaround for hopper tma lower pass @@ -462,55 +459,31 @@ def test_assert_tl_matmul_macro(): def test_assert_tl_matmul_block(): - assert_tl_matmul_block_correctness(128, 128, 128, False, False, "float16", "float16", "float16", - 64, 64, 32) - assert_tl_matmul_block_correctness(67, 128, 128, False, False, "float16", "float16", "float16", - 64, 64, 32) - assert_tl_matmul_block_correctness(36, 128, 128, False, False, "float16", "float16", "float16", - 64, 64, 32) + assert_tl_matmul_block_correctness(128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) + assert_tl_matmul_block_correctness(67, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) + assert_tl_matmul_block_correctness(36, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) def test_assert_tl_matmul_block_all_dynamic(): - assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16", - "float16", "float16", 64, 64, 32) - assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, "float16", "float16", - "float16", 64, 64, 32) - assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16", - "float16", 64, 64, 32) + assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) + assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) + assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) def test_assert_tl_matmul_block_all_dynamic_with_pass_config(): assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( - 128, - 128, - 128, - False, - False, - "float16", - "float16", - "float16", - 64, - 64, - 32, - dynamic_alignment=8) + 128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=8 + ) assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( - 64, - 128, - 128, - False, - False, - "float16", - "float16", - "float16", - 64, - 64, - 32, - dynamic_alignment=8) + 64, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=8 + ) assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( - 64, 128, 60, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=4) + 64, 128, 60, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=4 + ) # Tail split is enabled with dynamic alignment 0 assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( - 64, 128, 64, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=0) + 64, 128, 64, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=0 + ) if __name__ == "__main__": diff --git a/testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py b/testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py index b5ccbda924f8767969b4a82bdd2a37d11ace3f5b..1bee1356f3f8e402c68401d1a66a300eea809762 100644 --- a/testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py +++ b/testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py @@ -25,10 +25,8 @@ def tl_matmul_block_static( B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) @T.prim_func - def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( - (M, N), out_dtype)): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, 
block_M), threads=num_threads) as (bx, by): + def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -137,10 +135,8 @@ def tl_matmul_block_dynamic_m( B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) @T.prim_func - def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( - (M, N), out_dtype)): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -247,10 +243,8 @@ def tl_matmul_block_dynamic_mn( B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) @T.prim_func - def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( - (M, N), out_dtype)): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -357,10 +351,8 @@ def tl_matmul_block_dynamic_mnk( B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) @T.prim_func - def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( - (M, N), out_dtype)): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -445,8 +437,7 @@ def assert_tl_matmul_block_dynamic_mnk( def run_assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K): - assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, "float16", - "float16", "float32") + assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, "float16", "float16", "float32") def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): @@ -462,10 +453,8 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): "float16", "float16", "float32", - pass_configs={ - "tl.disable_dynamic_tail_split": True, - "tl.dynamic_alignment": 8 - }) + pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8}, + ) assert_tl_matmul_block_dynamic_m( M, N, @@ -478,7 +467,8 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): "float16", "float16", "float32", - pass_configs={"tl.disable_dynamic_tail_split": False}) + 
pass_configs={"tl.disable_dynamic_tail_split": False}, + ) def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K): @@ -494,10 +484,8 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K): "float16", "float16", "float32", - pass_configs={ - "tl.disable_dynamic_tail_split": True, - "tl.dynamic_alignment": 8 - }) + pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8}, + ) assert_tl_matmul_block_dynamic_mn( M, N, @@ -510,7 +498,8 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K): "float16", "float16", "float32", - pass_configs={"tl.disable_dynamic_tail_split": False}) + pass_configs={"tl.disable_dynamic_tail_split": False}, + ) def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): @@ -526,10 +515,8 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): "float16", "float16", "float32", - pass_configs={ - "tl.disable_dynamic_tail_split": True, - "tl.dynamic_alignment": 4 - }) + pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 4}, + ) assert_tl_matmul_block_dynamic_mnk( M, N, @@ -542,7 +529,8 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): "float16", "float16", "float32", - pass_configs={"tl.disable_dynamic_tail_split": False}) + pass_configs={"tl.disable_dynamic_tail_split": False}, + ) def test_all(): diff --git a/testing/python/fastmath/test_mathops_fastmath.py b/testing/python/fastmath/test_mathops_fastmath.py index c3b5d1b5288ef09252426869a5816268e991052d..7809983e8eb36d4ba231d4eecba5485b12b6ce0d 100644 --- a/testing/python/fastmath/test_mathops_fastmath.py +++ b/testing/python/fastmath/test_mathops_fastmath.py @@ -7,16 +7,16 @@ import re def get_mathop_lines(source, mathop_name): """Extract lines containing the mathop from CUDA source for debugging""" - lines = source.split('\n') + lines = source.split("\n") relevant_lines = [] for i, line in enumerate(lines): - if mathop_name in line and ('(' in line): + if mathop_name in line and ("(" in line): # Include some context start = max(0, i - 1) end = min(len(lines), i + 2) relevant_lines.extend([f"{j}: {lines[j]}" for j in range(start, end)]) relevant_lines.append("---") - return '\n'.join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output + return "\n".join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output def check_fastmath_usage(source, mathop_name, expect_fastmath=False): @@ -27,9 +27,7 @@ def check_fastmath_usage(source, mathop_name, expect_fastmath=False): fastmath_matches = re.findall(fastmath_pattern, source) non_fastmath_matches = re.findall(non_fastmath_pattern, source) - print( - f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls" - ) + print(f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls") if len(fastmath_matches) > 0: print(f"Fastmath calls found: {fastmath_matches}") if len(non_fastmath_matches) > 0: @@ -51,13 +49,7 @@ def check_non_fastmath_usage(source, mathop_name): check_fastmath_usage(source, mathop_name, expect_fastmath=False) -def run_single_arg_mathop_test(mathop_name, - mathop_func, - M=128, - N=128, - block_M=32, - block_N=32, - dtype="float32"): +def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): """ Test single-argument mathops. 
T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath) @@ -65,13 +57,12 @@ def run_single_arg_mathop_test(mathop_name, @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, - bx * block_N + j]) + B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j]) # Test with FAST_MATH disabled kernel_no_fastmath = tilelang.compile( @@ -80,7 +71,8 @@ def run_single_arg_mathop_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - }) + }, + ) source_no_fastmath = kernel_no_fastmath.get_kernel_source() @@ -93,28 +85,22 @@ def run_single_arg_mathop_test(mathop_name, print(f"✓ {mathop_name} compilation and execution test passed") -def run_two_arg_mathop_test(mathop_name, - mathop_func, - M=128, - N=128, - block_M=32, - block_N=32, - dtype="float32"): +def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): """ Test two-argument mathops to ensure they generate non-fastmath CUDA code. """ @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, - bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], - B[by * block_M + i, bx * block_N + j]) + C[by * block_M + i, bx * block_N + j] = mathop_func( + A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j] + ) # Test with FAST_MATH disabled kernel_no_fastmath = tilelang.compile( @@ -123,7 +109,8 @@ def run_two_arg_mathop_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - }) + }, + ) # Test with FAST_MATH enabled kernel_fastmath = tilelang.compile( @@ -132,7 +119,8 @@ def run_two_arg_mathop_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, + ) source_no_fastmath = kernel_no_fastmath.get_kernel_source() source_fastmath = kernel_fastmath.get_kernel_source() @@ -171,8 +159,8 @@ def run_abs_test(): @T.prim_func def main( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), "float32"), + B: T.Tensor((M, N), "float32"), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): @@ -184,7 +172,8 @@ def run_abs_test(): target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - }) + }, + ) source = kernel.get_kernel_source() print("\n=== Testing abs (maps to fabs) ===") @@ -199,26 +188,19 @@ def run_abs_test(): print("✓ abs numerical test passed") -def run_fastmath_mathop_test(mathop_name, - mathop_func, - M=128, - N=128, - block_M=32, - block_N=32, - dtype="float32"): +def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): """ Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix). 
""" @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, - bx * block_N + j]) + B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j]) # Test with FAST_MATH enabled kernel_fastmath = tilelang.compile( @@ -227,14 +209,15 @@ def run_fastmath_mathop_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, + ) source_fastmath = kernel_fastmath.get_kernel_source() print(f"\n=== Testing {mathop_name} (fastmath version) ===") print("FAST_MATH=True:") # Strip the __ prefix for checking in the CUDA source - cuda_mathop_name = mathop_name.lstrip('_') + cuda_mathop_name = mathop_name.lstrip("_") check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True) # Test numerical correctness diff --git a/testing/python/issue/test_tilelang_issue_1001.py b/testing/python/issue/test_tilelang_issue_1001.py index 77d8cc1f15c2055180f3d6d11e5f05bc986753d3..a4283daa54bda45805ecb04a09c569350a337669 100644 --- a/testing/python/issue/test_tilelang_issue_1001.py +++ b/testing/python/issue/test_tilelang_issue_1001.py @@ -8,14 +8,15 @@ from tilelang import language as T pass_configs={ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - },) + }, +) def _cumsum_view_infer_layout(hidden): - num_tokens = T.dynamic('num_tokens') + num_tokens = T.dynamic("num_tokens") @T.prim_func - def buggy_kernel(x: T.Tensor[(num_tokens, hidden), 'float']): + def buggy_kernel(x: T.Tensor[(num_tokens, hidden), "float"]): with T.Kernel(num_tokens, threads=128) as pid: - smem = T.alloc_shared((hidden,), dtype='float') + smem = T.alloc_shared((hidden,), dtype="float") T.copy(x[pid, :], smem) T.cumsum(T.view(smem, (1, hidden)), dim=1) @@ -24,10 +25,10 @@ def _cumsum_view_infer_layout(hidden): def test_cumsum_view_infer_layout(): hidden = 128 - x = torch.randn(1, hidden, device='cuda', dtype=torch.float) + x = torch.randn(1, hidden, device="cuda", dtype=torch.float) kernel = _cumsum_view_infer_layout(hidden) kernel(x) -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/issue/test_tilelang_issue_1008.py b/testing/python/issue/test_tilelang_issue_1008.py index 395593d8cbea1efb5fa172ce457c4956df06bb73..2d86d16453b96ca9eee9c00f5975fa12e53fe376 100644 --- a/testing/python/issue/test_tilelang_issue_1008.py +++ b/testing/python/issue/test_tilelang_issue_1008.py @@ -8,12 +8,13 @@ from tilelang import language as T pass_configs={ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - },) + }, +) def _fill_with_static_region_kernel(): - num_tokens = T.symbolic('num_tokens') + num_tokens = T.symbolic("num_tokens") @T.prim_func - def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']): # noqa: F821 + def buggy_kernel(x: T.Tensor[(num_tokens,), "int64"]): # noqa: F821 with T.Kernel(num_tokens, threads=128) as _: T.fill(x[0:128], 0) @@ -24,14 +25,15 @@ def _fill_with_static_region_kernel(): pass_configs={ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - },) + }, +) def _fill_with_dynamic_region_kernel(): - num_tokens = T.symbolic('num_tokens') 
+ num_tokens = T.symbolic("num_tokens") @T.prim_func - def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']): # noqa: F821 + def buggy_kernel(x: T.Tensor[(num_tokens,), "int64"]): # noqa: F821 with T.Kernel(num_tokens, threads=128) as _: - a, b = T.alloc_var('int'), T.alloc_var('int') + a, b = T.alloc_var("int"), T.alloc_var("int") T.fill(x[a:b], 0) return buggy_kernel @@ -39,15 +41,15 @@ def _fill_with_dynamic_region_kernel(): def test_fill_with_static_region_kernel(): kernel = _fill_with_static_region_kernel() - x = torch.zeros((256,), dtype=torch.int64, device='cuda') + x = torch.zeros((256,), dtype=torch.int64, device="cuda") kernel(x) def test_fill_with_dynamic_region_kernel(): kernel = _fill_with_dynamic_region_kernel() - x = torch.zeros((256,), dtype=torch.int64, device='cuda') + x = torch.zeros((256,), dtype=torch.int64, device="cuda") kernel(x) -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/issue/test_tilelang_issue_1115.py b/testing/python/issue/test_tilelang_issue_1115.py index 1769862356540a1580b00054d7560015ed7e1d0c..ce21a3b05740ba13cf472a7134af43e072dedfe0 100644 --- a/testing/python/issue/test_tilelang_issue_1115.py +++ b/testing/python/issue/test_tilelang_issue_1115.py @@ -4,25 +4,23 @@ import tilelang.language as T def test_int64_address(): - @tilelang.jit def set_cache_kernel( S, D, - pos_ty='int64', + pos_ty="int64", dtype="float32", ): - @T.prim_func def main( - pos: T - .Tensor( + pos: T.Tensor( [ S, - ], pos_ty + ], + pos_ty, ), # type: ignore `TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32` - value: T.Tensor([S, D], dtype), # type: ignore - cache: T.Tensor([S, D], dtype), # type: ignore + value: T.Tensor([S, D], dtype), # type: ignore + cache: T.Tensor([S, D], dtype), # type: ignore ): with T.Kernel(S, threads=128) as bx: slot = pos[bx] @@ -34,11 +32,11 @@ def test_int64_address(): D = 2 S = 10 cache = torch.rand((S, D), device="cuda", dtype=torch.float32) - value = torch.rand((S, D), device='cuda', dtype=torch.float32) - pos_int64 = torch.arange(S, device='cuda', dtype=torch.int64) - pos_int32 = torch.arange(S, device='cuda', dtype=torch.int32) - kernel_int64 = set_cache_kernel(S, D, 'int64') - kernel_int32 = set_cache_kernel(S, D, 'int32') + value = torch.rand((S, D), device="cuda", dtype=torch.float32) + pos_int64 = torch.arange(S, device="cuda", dtype=torch.int64) + pos_int32 = torch.arange(S, device="cuda", dtype=torch.int32) + kernel_int64 = set_cache_kernel(S, D, "int64") + kernel_int32 = set_cache_kernel(S, D, "int32") kernel_int64(pos_int64, value, cache) torch.testing.assert_close(cache, value) kernel_int32(pos_int32, value, cache) diff --git a/testing/python/issue/test_tilelang_issue_1198.py b/testing/python/issue/test_tilelang_issue_1198.py index eb9ed45964cb766bd36b1b0f919d8c1d33f4478c..08f36822b16632dfcdc6a10d4c9582463973d455 100644 --- a/testing/python/issue/test_tilelang_issue_1198.py +++ b/testing/python/issue/test_tilelang_issue_1198.py @@ -3,13 +3,17 @@ import tilelang.language as T def test_issue_1198(): - @T.prim_func - def foo(x: T.Buffer([ - 32, - ], "int32")): + def foo( + x: T.Buffer( + [ + 32, + ], + "int32", + ), + ): pass -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/issue/test_tilelang_issue_814.py b/testing/python/issue/test_tilelang_issue_814.py index 1a9e63d2992b90ca263724d472190a20514116e0..a202bd96047bf80e28944d2b6e34a877885b48be 100644 --- 
a/testing/python/issue/test_tilelang_issue_814.py +++ b/testing/python/issue/test_tilelang_issue_814.py @@ -6,11 +6,10 @@ import torch @tilelang.jit def _tmp_var_kernel(N, block_N, dtype="float"): - @T.prim_func def kernel( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx: for i in T.Parallel(block_N): diff --git a/testing/python/issue/test_tilelang_issue_830.py b/testing/python/issue/test_tilelang_issue_830.py index 950b85835ecf4f25663e11648e7914b94cacec17..74ceed3d96ea5c3bf3ddbdb5afc96dc6d33c32d4 100644 --- a/testing/python/issue/test_tilelang_issue_830.py +++ b/testing/python/issue/test_tilelang_issue_830.py @@ -8,7 +8,6 @@ import tilelang.language as T @tilelang.jit def _empty_kernel(): - @T.prim_func def empty_kernel(): with T.Kernel(1, threads=32) as thread_idx: @@ -51,7 +50,6 @@ def test_empty_with_dead_code_kernel(): @tilelang.jit def _empty_kernel_with_binding_variants(use_tuple_binding: bool = False): - @T.prim_func def kernel_with_tuple_kernel_binding(): with T.Kernel(1, threads=32) as (pid,): diff --git a/testing/python/issue/test_tilelang_issue_96.py b/testing/python/issue/test_tilelang_issue_96.py index e42ebb59e6021bbc2822280cdd4eca9e8e862714..6ab7fe479d848255431b19c89a683f3831bb9087 100644 --- a/testing/python/issue/test_tilelang_issue_96.py +++ b/testing/python/issue/test_tilelang_issue_96.py @@ -5,18 +5,16 @@ import torch def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as ( - bx, - by, - ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as ( + bx, + by, + ): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) diff --git a/testing/python/issue/test_tilelang_issue_merge_if.py b/testing/python/issue/test_tilelang_issue_merge_if.py index 1db7f337c64d1bd328c91dc8bc30a4fbdf2c9c89..fa9432fc819c692d480b15e6d24bbe7e647af20f 100644 --- a/testing/python/issue/test_tilelang_issue_merge_if.py +++ b/testing/python/issue/test_tilelang_issue_merge_if.py @@ -6,7 +6,6 @@ import tilelang.language as T def merge_if_test(): - @T.prim_func def main(): A = T.alloc_fragment((1,), "float16") diff --git a/testing/python/jit/test_tilelang_jit_callback.py b/testing/python/jit/test_tilelang_jit_callback.py index e987368df78d214513954d089d5a7c7a2a978ad8..7d76a64d1938cd9170334116132608e894aef768 100644 --- a/testing/python/jit/test_tilelang_jit_callback.py +++ b/testing/python/jit/test_tilelang_jit_callback.py @@ -29,9 +29,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -141,9 +141,9 @@ def matmu_jit_kernel( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, 
in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -208,6 +208,7 @@ def run_gemm_jit_kernel( def ref_program(A, B): import torch + C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.__getattribute__(out_dtype)) return C diff --git a/testing/python/jit/test_tilelang_jit_gemm.py b/testing/python/jit/test_tilelang_jit_gemm.py index 25c19a058bafc5a3360528d1e6fe09697848c359..153f06cb1b3df6a7f736e1f48974a4c0ef101f23 100644 --- a/testing/python/jit/test_tilelang_jit_gemm.py +++ b/testing/python/jit/test_tilelang_jit_gemm.py @@ -31,9 +31,9 @@ def matmul_kernel_jit( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -96,6 +96,7 @@ def run_gemm_kernel_jit( def ref_program(A, B): import torch + C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.__getattribute__(out_dtype)) return C diff --git a/testing/python/jit/test_tilelang_jit_gemm_cython.py b/testing/python/jit/test_tilelang_jit_gemm_cython.py index 12524f12910ca18beee3f9c25442b27c8111936b..4ea4ba88dee651bbcfbbb9ba0d6c927300b96664 100644 --- a/testing/python/jit/test_tilelang_jit_gemm_cython.py +++ b/testing/python/jit/test_tilelang_jit_gemm_cython.py @@ -28,9 +28,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -138,9 +138,9 @@ def matmu_jit_kernel( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -208,6 +208,7 @@ def run_gemm_jit_kernel( def ref_program(A, B): import torch + C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(out_dtype) return C @@ -235,19 +236,9 @@ def test_gemm_jit_kernel(): ) -def run_cython_kernel_do_bench(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): +def run_cython_kernel_do_bench( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -287,23 +278,12 @@ def run_cython_kernel_do_bench(M, def test_cython_kernel_do_bench(): - run_cython_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, - 256, 32, 2) - - -def run_cython_kernel_multi_stream(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): + run_cython_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_cython_kernel_multi_stream( + M, N, K, 
trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -342,23 +322,12 @@ def run_cython_kernel_multi_stream(M, def test_cython_kernel_multi_stream(): - run_cython_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", - 128, 256, 32, 2) - - -def run_cython_dynamic_shape(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): + run_cython_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_cython_dynamic_shape( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -398,36 +367,20 @@ def run_cython_dynamic_shape(M, matmul_kernel(tensor_a, tensor_b, tensor_c) tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) - tilelang.testing.torch_assert_close( - tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) def test_cython_dynamic_shape(): - run_cython_dynamic_shape( - T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) - - run_cython_dynamic_shape( - T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, - 256, 32, 2) - - run_cython_dynamic_shape( - T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", - "float16", 128, 256, 32, 2) - - -def run_cython_dynamic_shape_with_out_idx(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): + run_cython_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_cython_dynamic_shape_with_out_idx( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -467,13 +420,11 @@ def run_cython_dynamic_shape_with_out_idx(M, tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) - tilelang.testing.torch_assert_close( - tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) def test_cython_dynamic_shape_with_out_idx(): - run_cython_dynamic_shape_with_out_idx( - T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_cython_dynamic_shape_with_out_idx(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) def matmul_int_variable( @@ -498,10 +449,10 @@ def matmul_int_variable( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - offset: T.int32, + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + offset: T.int32, ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), 
threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -525,10 +476,10 @@ def matmul_int_variable( return main -def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, - out_dtype, dtypeAccum, num_stages, threads): - program = matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, - out_dtype, dtypeAccum, num_stages, threads) +def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads): + program = matmul_int_variable( + M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads + ) matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=2) in_dtype = map_torch_type(in_dtype) @@ -544,8 +495,7 @@ def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B def test_matmul_int_variable(): - run_matmul_int_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", - "float32", 0, 128) + run_matmul_int_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", "float32", 0, 128) def matmul_float_variable( @@ -570,10 +520,10 @@ def matmul_float_variable( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - offset: T.float32, + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + offset: T.float32, ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -597,10 +547,10 @@ def matmul_float_variable( return main -def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, - out_dtype, dtypeAccum, num_stages, threads): - program = matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, - out_dtype, dtypeAccum, num_stages, threads) +def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads): + program = matmul_float_variable( + M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads + ) matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=2) in_dtype = map_torch_type(in_dtype) @@ -616,8 +566,7 @@ def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans def test_matmul_float_variable(): - run_matmul_float_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", - "float32", 0, 128) + run_matmul_float_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", "float32", 0, 128) if __name__ == "__main__": diff --git a/testing/python/jit/test_tilelang_jit_nullptr.py b/testing/python/jit/test_tilelang_jit_nullptr.py index cce1fce8f8662c4452cbe27a5d5e3361026e611e..8965e2ad3c325c40cad4395b1231df63fe2ecf1f 100644 --- a/testing/python/jit/test_tilelang_jit_nullptr.py +++ b/testing/python/jit/test_tilelang_jit_nullptr.py @@ -7,22 +7,13 @@ from tilelang.utils import map_torch_type @tl.jit -def tensor_null_test(M, - N, - K, - block_M, - block_N, - block_K, - dtype="float16", - accum_dtype="float", - with_bias=False): - +def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float", with_bias=False): @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), 
accum_dtype), - Bias: T.Tensor((N), accum_dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), accum_dtype), + Bias: T.Tensor((N), accum_dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -48,12 +39,10 @@ def tensor_null_test(M, def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype)) b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype)) c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype)) - kernel = tensor_null_test( - M, N, K, block_M, block_N, block_K, dtype, accum_dtype, with_bias=False) + kernel = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype, with_bias=False) kernel(a, b, c, None) diff --git a/testing/python/jit/test_tilelang_jit_nvrtc.py b/testing/python/jit/test_tilelang_jit_nvrtc.py index c70768611eb918d45fbc3d014cd783e3ecf9fcba..2b15027724f5c8d46f327c4a8355437b8b2b4a4f 100644 --- a/testing/python/jit/test_tilelang_jit_nvrtc.py +++ b/testing/python/jit/test_tilelang_jit_nvrtc.py @@ -28,9 +28,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -136,9 +136,9 @@ def matmu_jit_kernel( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -206,6 +206,7 @@ def run_gemm_jit_kernel( def ref_program(A, B): import torch + C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(out_dtype) return C @@ -233,19 +234,9 @@ def test_gemm_jit_kernel(): ) -def run_nvrtc_kernel_do_bench(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): +def run_nvrtc_kernel_do_bench( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -278,23 +269,12 @@ def run_nvrtc_kernel_do_bench(M, def test_nvrtc_kernel_do_bench(): - run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, - 256, 32, 2) - - -def run_nvrtc_kernel_multi_stream(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): + run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_nvrtc_kernel_multi_stream( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -331,23 +311,12 @@ def run_nvrtc_kernel_multi_stream(M, def test_nvrtc_kernel_multi_stream(): - run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", - 128, 256, 32, 2) - - -def run_nvrtc_dynamic_shape(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - 
block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): + run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_nvrtc_dynamic_shape( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -387,21 +356,15 @@ def run_nvrtc_dynamic_shape(M, matmul_kernel(tensor_a, tensor_b, tensor_c) tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) - tilelang.testing.torch_assert_close( - tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) def test_nvrtc_dynamic_shape(): - run_nvrtc_dynamic_shape( - T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_nvrtc_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) - run_nvrtc_dynamic_shape( - T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, - 256, 32, 2) + run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) - run_nvrtc_dynamic_shape( - T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", - "float16", 128, 256, 32, 2) + run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2) def check_hopper(): @@ -412,35 +375,18 @@ def check_hopper(): return compute_capability == (9, 0) -def convolution_im2col(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages, - threads, - dtype="float16", - accum_dtype="float"): +def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 @T.prim_func def main( - data: T.Tensor((N, H, W, C), dtype), - kernel: T.Tensor((KH, KW, C, F), dtype), - out: T.Tensor((N, OH, OW, F), dtype), + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), ): - with T.Kernel( - T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by): data_shared = T.alloc_shared((block_M, block_K), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -449,11 +395,13 @@ def convolution_im2col(N, kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) - T.annotate_layout({ - out_shared: tilelang.layout.make_swizzled_layout(out_shared), - data_shared: tilelang.layout.make_swizzled_layout(data_shared), - kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), - }) + T.annotate_layout( + { + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + data_shared: tilelang.layout.make_swizzled_layout(data_shared), + kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), + } + ) T.clear(out_local) for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): @@ -467,23 +415,9 @@ def convolution_im2col(N, return main -def 
run_nvrtc_im2col_tma_desc(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=256): +def run_nvrtc_im2col_tma_desc(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages=3, num_threads=256): """Test im2col TMA descriptor functionality in NVRTC backend.""" - program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, - num_threads) + program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, num_threads) conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc") @@ -501,32 +435,20 @@ def run_nvrtc_im2col_tma_desc(N, return C ref_c = ref_program(a, b) - tilelang.testing.torch_assert_close( - out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + tilelang.testing.torch_assert_close(out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) def test_nvrtc_im2col_tma_desc(): """Test im2col TMA descriptor with NVRTC backend.""" if not check_hopper(): import pytest + pytest.skip("Test requires Hopper GPU (compute capability 9.0)") # Small test case for im2col TMA descriptor run_nvrtc_im2col_tma_desc( - N=4, - C=64, - H=32, - W=32, - F=64, - K=3, - S=1, - D=1, - P=1, - block_M=64, - block_N=128, - block_K=32, - num_stages=3, - num_threads=256) + N=4, C=64, H=32, W=32, F=64, K=3, S=1, D=1, P=1, block_M=64, block_N=128, block_K=32, num_stages=3, num_threads=256 + ) def test_nvrtc_l2_persistent_map(): @@ -543,12 +465,11 @@ def test_nvrtc_l2_persistent_map(): block_size=256, dtype="float32", ): - @T.prim_func def kernel( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(M * N // block_size, threads=block_size) as bx: # Annotate L2 persistent cache for buffer B diff --git a/testing/python/jit/test_tilelang_jit_parcompile.py b/testing/python/jit/test_tilelang_jit_parcompile.py index e7bcec41250373380ec429d025819db434ebbf6b..0a6e9062ce5d2ed4e717a142f4f9dd7b50510372 100644 --- a/testing/python/jit/test_tilelang_jit_parcompile.py +++ b/testing/python/jit/test_tilelang_jit_parcompile.py @@ -16,9 +16,9 @@ def matmul_kernel_jit( block_K, trans_A=False, trans_B=True, - in_dtype='float16', - out_dtype='float32', - accum_dtype='float32', + in_dtype="float16", + out_dtype="float32", + accum_dtype="float32", num_stages=2, threads=128, ): @@ -31,9 +31,9 @@ def matmul_kernel_jit( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) diff --git a/testing/python/jit/test_tilelang_jit_tvm_ffi.py b/testing/python/jit/test_tilelang_jit_tvm_ffi.py index f7bde6afda7ea6ad41f444b6f91a11748ce92382..5daaf30830fa282f5845f5ef0372da5edfef1f18 100644 --- a/testing/python/jit/test_tilelang_jit_tvm_ffi.py +++ b/testing/python/jit/test_tilelang_jit_tvm_ffi.py @@ -28,9 +28,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared 
= T.alloc_shared(A_shared_shape, in_dtype) @@ -74,9 +74,9 @@ def matmu_jit_kernel( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -144,6 +144,7 @@ def run_gemm_jit_kernel( def ref_program(A, B): import torch + C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(out_dtype) return C @@ -171,19 +172,9 @@ def test_gemm_jit_kernel(): ) -def run_tvm_ffi_kernel_do_bench(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): +def run_tvm_ffi_kernel_do_bench( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -216,23 +207,12 @@ def run_tvm_ffi_kernel_do_bench(M, def test_tvm_ffi_kernel_do_bench(): - run_tvm_ffi_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, - 256, 32, 2) - - -def run_tvm_ffi_kernel_multi_stream(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): + run_tvm_ffi_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_tvm_ffi_kernel_multi_stream( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -269,23 +249,12 @@ def run_tvm_ffi_kernel_multi_stream(M, def test_tvm_ffi_kernel_multi_stream(): - run_tvm_ffi_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", - 128, 256, 32, 2) - - -def run_tvm_ffi_dynamic_shape(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): + run_tvm_ffi_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_tvm_ffi_dynamic_shape( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): program = matmul( M, N, @@ -325,21 +294,17 @@ def run_tvm_ffi_dynamic_shape(M, matmul_kernel(tensor_a, tensor_b, tensor_c) tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) - tilelang.testing.torch_assert_close( - tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) def test_tvm_ffi_dynamic_shape(): - run_tvm_ffi_dynamic_shape( - T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_tvm_ffi_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) - run_tvm_ffi_dynamic_shape( - T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, - 256, 32, 2) + run_tvm_ffi_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) run_tvm_ffi_dynamic_shape( - T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", - "float16", 128, 256, 32, 2) + T.dynamic("m"), T.dynamic("n"), 
T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2 + ) def check_hopper(): @@ -350,35 +315,18 @@ def check_hopper(): return compute_capability == (9, 0) -def convolution_im2col(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages, - threads, - dtype="float16", - accum_dtype="float"): +def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 @T.prim_func def main( - data: T.Tensor((N, H, W, C), dtype), - kernel: T.Tensor((KH, KW, C, F), dtype), - out: T.Tensor((N, OH, OW, F), dtype), + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), ): - with T.Kernel( - T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), - threads=threads) as (bx, by): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by): data_shared = T.alloc_shared((block_M, block_K), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -387,11 +335,13 @@ def convolution_im2col(N, kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) - T.annotate_layout({ - out_shared: tilelang.layout.make_swizzled_layout(out_shared), - data_shared: tilelang.layout.make_swizzled_layout(data_shared), - kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), - }) + T.annotate_layout( + { + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + data_shared: tilelang.layout.make_swizzled_layout(data_shared), + kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), + } + ) T.clear(out_local) for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): @@ -405,23 +355,9 @@ def convolution_im2col(N, return main -def run_tvm_ffi_im2col_tma_desc(N, - C, - H, - W, - F, - K, - S, - D, - P, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=256): +def run_tvm_ffi_im2col_tma_desc(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages=3, num_threads=256): """Test im2col TMA descriptor functionality in tvm_ffi backend.""" - program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, - num_threads) + program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, num_threads) conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="tvm_ffi") @@ -439,32 +375,20 @@ def run_tvm_ffi_im2col_tma_desc(N, return C ref_c = ref_program(a, b) - tilelang.testing.torch_assert_close( - out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + tilelang.testing.torch_assert_close(out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) def test_tvm_ffi_im2col_tma_desc(): """Test im2col TMA descriptor with tvm_ffi backend.""" if not check_hopper(): import pytest + pytest.skip("Test requires Hopper GPU (compute capability 9.0)") # Small test case for im2col TMA descriptor run_tvm_ffi_im2col_tma_desc( - N=4, - C=64, - H=32, - W=32, - F=64, - K=3, - S=1, - D=1, - P=1, - block_M=64, - block_N=128, - block_K=32, - num_stages=3, - num_threads=256) + N=4, C=64, H=32, W=32, F=64, K=3, S=1, D=1, P=1, block_M=64, block_N=128, block_K=32, num_stages=3, num_threads=256 + ) def test_tvm_ffi_l2_persistent_map(): @@ -481,12 +405,11 @@ 
def test_tvm_ffi_l2_persistent_map(): block_size=256, dtype="float32", ): - @T.prim_func def kernel( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(M * N // block_size, threads=block_size) as bx: # Annotate L2 persistent cache for buffer B @@ -506,8 +429,12 @@ def test_tvm_ffi_l2_persistent_map(): kernel = elementwise_add_with_l2_cache(M, N) source = kernel.get_host_source() - assert "__tvm_cuda_stream_set_access_policy_window_packed" in source, "Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source" - assert "__tvm_cuda_stream_reset_access_policy_window_packed" in source, "Expected __tvm_cuda_stream_reset_access_policy_window_packed in the kernel source" + assert "__tvm_cuda_stream_set_access_policy_window_packed" in source, ( + "Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source" + ) + assert "__tvm_cuda_stream_reset_access_policy_window_packed" in source, ( + "Expected __tvm_cuda_stream_reset_access_policy_window_packed in the kernel source" + ) # Create test tensors a = torch.randn(M, N, dtype=torch.float32).cuda() diff --git a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py index 13135d416f9ba462f6bf4773410b0e9ac47d63d6..e7d7021c55cf958271c79837093430a0a3c4bbf3 100644 --- a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py @@ -6,7 +6,8 @@ from tvm import DataType import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func from tilelang.utils.tensor import map_torch_type @@ -111,12 +112,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -124,10 +124,12 @@ def tl_matmul( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -135,7 +137,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -145,7 +146,6 @@ def tl_matmul( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, diff --git a/testing/python/kernel/test_tilelang_kernel_element_wise_add.py b/testing/python/kernel/test_tilelang_kernel_element_wise_add.py index 
3ec6ae03040da1403d6e005ec173a30b412c707c..52763c817f42a05c295787a5c431357893160e4c 100644 --- a/testing/python/kernel/test_tilelang_kernel_element_wise_add.py +++ b/testing/python/kernel/test_tilelang_kernel_element_wise_add.py @@ -16,15 +16,15 @@ def elementwise_add( @T.prim_func def main( - A: T.Tensor((M, N), in_dtype), - B: T.Tensor((M, N), in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor((M, N), in_dtype), + B: T.Tensor((M, N), in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): start_x = bx * block_N start_y = by * block_M - for (local_y, local_x) in T.Parallel(block_M, block_N): + for local_y, local_x in T.Parallel(block_M, block_N): y = start_y + local_y x = start_x + local_x diff --git a/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py b/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py index 19f327d66c3abb816e130e886dd379a5648d2513..63c82120214826c37fd88ee4a448f4e58482c732 100644 --- a/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py +++ b/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py @@ -12,12 +12,11 @@ def calc_diff(x, y): def matmul_nt(M, N, K, bM, bN, bK, in_dtype, out_dtype, accum_dtype): - @T.prim_func def main( - A: T.Tensor((M, K), in_dtype), - B: T.Tensor((N, K), in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor((M, K), in_dtype), + B: T.Tensor((N, K), in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, bN), T.ceildiv(M, bM), threads=128) as (bx, by): A_shared = T.alloc_shared((bM, bK), in_dtype) @@ -44,8 +43,7 @@ def assert_matmul_correctness(M, N, K, block_M, block_N, block_K, in_dtype, out_ C = kernel(A, B) - ref_c = torch.matmul(A.to(map_torch_type(accum_dtype)), - B.T.to(map_torch_type(accum_dtype))).to(map_torch_type(out_dtype)) + ref_c = torch.matmul(A.to(map_torch_type(accum_dtype)), B.T.to(map_torch_type(accum_dtype))).to(map_torch_type(out_dtype)) print(C) print(ref_c) diff = calc_diff(C, ref_c) diff --git a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py index 46f4e123a4a7eb1be8bcd3119302e5e4f4e6d0ba..eec3a9caf80ec069cda0d52e3101f777d2ce8a40 100644 --- a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py @@ -6,7 +6,8 @@ from tvm import DataType import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func from tilelang.utils.tensor import map_torch_type @@ -110,12 +111,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -123,10 +123,12 @@ def tl_matmul( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: 
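# Illustrative sketch (not part of the patch; assumes a PyTorch build with float8
# dtypes): the fp8 GEMM test's reference path never multiplies in fp8 -- it upcasts
# both operands to the accumulation dtype, runs a regular NT matmul, and only casts
# the result down at the end, mirroring the ref_c expression above.
import torch

M, N, K = 64, 64, 64
A = torch.randn(M, K).to(torch.float8_e4m3fn)   # fp8 storage, shape (M, K)
B = torch.randn(N, K).to(torch.float8_e4m3fn)   # fp8 storage, shape (N, K): NT layout
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(torch.float16)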
make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -134,7 +136,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -144,7 +145,6 @@ def tl_matmul( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, diff --git a/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py b/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py index afd01f337b6972727a9c2732ecc6111377985619..4a48b656f1463343ce7599931b05045de28cfc27 100644 --- a/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py +++ b/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py @@ -27,8 +27,8 @@ def gemv_simt( ): assert n_partition is not None, "n_partition must be provided" assert reduce_thread is not None, ( - "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV" - "sch_outer_reduction_with_config is not implemented") + "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented" + ) assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently" @@ -50,16 +50,15 @@ def gemv_simt( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - Bias: T.Tensor(Bias_shape, out_dtype), - C: T.Tensor(C_shape, out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + Bias: T.Tensor(Bias_shape, out_dtype), + C: T.Tensor(C_shape, out_dtype), ): - with T.Kernel( - T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as ( - bx, - by, - ): + with T.Kernel(T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as ( + bx, + by, + ): A_local = T.alloc_local((micro_size_k,), in_dtype) B_local = T.alloc_local((micro_size_k,), in_dtype) accum_res = T.alloc_local((1,), accum_dtype) @@ -88,13 +87,12 @@ def gemv_simt( ) else: for ki in T.serial(micro_size_k): - accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype( - accum_dtype) + accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(accum_dtype) with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -104,11 +102,11 @@ def gemv_simt( reduced_accum_res[0], kr, dtype="handle", - )) + ) + ) if kr == 0: if with_bias: - C[by, - bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni] + C[by, bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni] else: C[by, bx * n_partition + ni] = reduced_accum_res[0] diff --git a/testing/python/kernel/test_tilelang_kernel_gemm.py b/testing/python/kernel/test_tilelang_kernel_gemm.py index 5dcde1d5ec4f4c37414837f1fc12ba5634124997..6c01297a15153bc62ca3c849b588d9ddd4bd66b5 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm.py @@ -26,9 +26,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, 
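# Illustrative sketch (not part of the patch): in the SIMT GEMV above each thread of
# the (reduce_thread, n_partition) block accumulates a private partial sum over its
# slice of K, and tvm_thread_allreduce combines the partials.  The numerical idea,
# shown with a plain split-K reduction in PyTorch:
import torch

K, reduce_thread = 4096, 32
a = torch.randn(K, dtype=torch.float64)
b = torch.randn(K, dtype=torch.float64)
partials = (a.view(reduce_thread, -1) * b.view(reduce_thread, -1)).sum(dim=1)  # one partial per "thread"
assert torch.allclose(partials.sum(), torch.dot(a, b))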
in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -95,8 +95,8 @@ def run_gemm( if in_dtype == "float32": # Convert float32 to tfloat32 because tfloat32 mma cannot truncate # float32 automatically, -0x1000 meas - A = ((A.view(torch.int32) - 0x1000)).view(torch.float32) - B = ((B.view(torch.int32) - 0x1000)).view(torch.float32) + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.__getattribute__(out_dtype)) return C @@ -321,9 +321,9 @@ def matmul_sr( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -441,9 +441,9 @@ def matmul_rs( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared") diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py b/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py index 6e20754eb1c289994c69099d82070cb70fe9b052..3633d3ece3d4d9c239400728cb8501d68e5a2266 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py @@ -6,7 +6,8 @@ from tvm import DataType import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.transform import simplify_prim_func from tilelang.utils.tensor import map_torch_type @@ -111,12 +112,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -124,10 +124,12 @@ def tl_matmul( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -135,7 +137,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): 
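# Illustrative sketch (not part of the patch), mirroring the float32 reference path
# above: before the reference matmul the test shifts each float32 bit pattern by
# -0x1000 (reinterpreted through int32) so the reference lines up with what the
# tf32 tensor-core MMA computes on truncated inputs.  The shift only disturbs low
# mantissa bits, so values move by well under 0.1%:
import torch

A = torch.randn(128, 128, dtype=torch.float32)
A_tf32_ref = (A.view(torch.int32) - 0x1000).view(torch.float32)
assert torch.allclose(A, A_tf32_ref, rtol=1e-3)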
A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -145,7 +146,6 @@ def tl_matmul( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_simt.py b/testing/python/kernel/test_tilelang_kernel_gemm_simt.py index 548497c72989010a76f15012c29bf91bb5cfafc5..e4da44b2609887509e58647848989252eced16bb 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm_simt.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm_simt.py @@ -76,12 +76,11 @@ def tl_matmul_simt( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor(C_shape, out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor(C_shape, out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) @@ -97,7 +96,6 @@ def tl_matmul_simt( T.clear(C_local) for ko in T.serial(K // block_K): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -109,29 +107,24 @@ def tl_matmul_simt( for ki in T.serial((block_K // micro_size_k)): for i in T.serial(local_size_a): for mk in T.vectorized(micro_size_k): - A_local[i, mk] = A_shared[warp_m * local_size_a + i, - ki * micro_size_k + mk] + A_local[i, mk] = A_shared[warp_m * local_size_a + i, ki * micro_size_k + mk] for i in T.serial(local_size_b): for mk in T.vectorized(micro_size_k): - B_local[i, mk] = B_shared[warp_n * local_size_b + i, - ki * micro_size_k + mk] + B_local[i, mk] = B_shared[warp_n * local_size_b + i, ki * micro_size_k + mk] for i, j in T.grid(local_size_a, local_size_b): for mk in T.serial(micro_size_k // dp4a_size): if use_dp4a: - T.dp4a(A_local[i, mk * dp4a_size], B_local[j, mk * dp4a_size], - C_local[i * local_size_b + j]) + T.dp4a(A_local[i, mk * dp4a_size], B_local[j, mk * dp4a_size], C_local[i * local_size_b + j]) else: for dp4a_idx in T.serial(dp4a_size): - C_local[i * local_size_b + - j] += A_local[i, mk * dp4a_size + - dp4a_idx] * B_local[j, mk * dp4a_size + - dp4a_idx] + C_local[i * local_size_b + j] += ( + A_local[i, mk * dp4a_size + dp4a_idx] * B_local[j, mk * dp4a_size + dp4a_idx] + ) for i, j in T.grid(local_size_a, local_size_b): - C[by * block_M + warp_m * local_size_a + i, - bx * block_N + warp_n * local_size_b + j] = C_local[i * local_size_b + j] + C[by * block_M + warp_m * local_size_a + i, bx * block_N + warp_n * local_size_b + j] = C_local[i * local_size_b + j] return main diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py b/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py index bbc2e79e2ef609cdd014e022957e808f2a25fb79..2def480db83a3c447a948449cac3b7728532ec7f 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py @@ -5,12 +5,11 @@ import torch def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), 
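# Illustrative sketch (not part of the patch): T.dp4a in the SIMT kernel above folds
# a dot product of four int8 pairs into one int32 accumulation, which is exactly
# what the scalar fallback in the else-branch computes one product at a time:
import torch

a4 = torch.tensor([1, -2, 3, 4], dtype=torch.int8)
b4 = torch.tensor([5, 6, -7, 8], dtype=torch.int8)
acc = torch.tensor(10, dtype=torch.int32)

dp4a_result = acc + (a4.to(torch.int32) * b4.to(torch.int32)).sum()
fallback = acc.clone()
for idx in range(4):  # element by element, like the fallback loop above
    fallback += a4[idx].to(torch.int32) * b4[idx].to(torch.int32)
assert int(dp4a_result) == int(fallback) == 14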
threads=128) as (bx, by): @@ -59,7 +58,8 @@ def run_gemm_with_stride_ss(M: int, N: int, K: int, block_M: int, block_N: int, pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) # Create random input tensors on the GPU a = torch.randn(M, K, device="cuda", dtype=torch.float16) b = torch.randn(K, N, device="cuda", dtype=torch.float16) diff --git a/testing/python/kernel/test_tilelang_kernel_gemv_simt.py b/testing/python/kernel/test_tilelang_kernel_gemv_simt.py index 86d6acbda6a36c8ad73d1e90975a46f4d0626d57..5825f695ced820b4b343defc3647555f0d255f84 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemv_simt.py +++ b/testing/python/kernel/test_tilelang_kernel_gemv_simt.py @@ -27,8 +27,8 @@ def gemv_simt( ): assert n_partition is not None, "n_partition must be provided" assert reduce_thread is not None, ( - "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV" - "sch_outer_reduction_with_config is not implemented") + "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented" + ) assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently" @@ -50,16 +50,15 @@ def gemv_simt( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - Bias: T.Tensor(Bias_shape, out_dtype), - C: T.Tensor(C_shape, out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + Bias: T.Tensor(Bias_shape, out_dtype), + C: T.Tensor(C_shape, out_dtype), ): - with T.Kernel( - T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as ( - bx, - by, - ): + with T.Kernel(T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as ( + bx, + by, + ): A_local = T.alloc_local((micro_size_k,), in_dtype) B_local = T.alloc_local((micro_size_k,), in_dtype) accum_res = T.alloc_local((1,), accum_dtype) @@ -88,13 +87,12 @@ def gemv_simt( ) else: for ki in T.serial(micro_size_k): - accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype( - accum_dtype) + accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(accum_dtype) with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), - "reduce_scope", - T.reinterpret(T.uint64(0), dtype="handle"), + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), ): T.evaluate( T.tvm_thread_allreduce( @@ -104,11 +102,11 @@ def gemv_simt( reduced_accum_res[0], kr, dtype="handle", - )) + ) + ) if kr == 0: if with_bias: - C[by, - bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni] + C[by, bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni] else: C[by, bx * n_partition + ni] = reduced_accum_res[0] diff --git a/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py index 5cdd6710563906eed2da42e78d44d3f1e76659dc..affeb3ddf39f55d6cdaa67e490910071adb23e75 100644 --- a/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py @@ -4,7 +4,8 @@ from tilelang import tvm as tvm import tilelang.testing import tilelang.language as T from tilelang.intrinsics import ( - make_mma_swizzle_layout as make_swizzle_layout,) + make_mma_swizzle_layout as make_swizzle_layout, +) from tilelang.intrinsics.mma_macro_generator import ( 
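# Illustrative sketch (not part of the patch): several run_* helpers above compile
# with pass_configs that disable TMA lowering and warp specialization.  A minimal,
# hedged example of that compile call -- the tiny tile-copy kernel here is only a
# stand-in program, and its body is assumed rather than taken from the tests:
import tilelang
import tilelang.language as T

def copy_program(M, N, block_M, block_N, dtype="float16"):
    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.copy(A_shared, B[by * block_M, bx * block_N])
    return main

kernel = tilelang.compile(
    copy_program(1024, 1024, 128, 128),
    out_idx=[1],
    target="cuda",
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)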
INT4TensorCoreIntrinEmitter, @@ -91,12 +92,11 @@ def tl_matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -104,10 +104,12 @@ def tl_matmul( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -115,7 +117,6 @@ def tl_matmul( T.clear(C_local) for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] @@ -125,7 +126,6 @@ def tl_matmul( B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -168,7 +168,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): out_idx=[2], pass_configs={ tilelang.PassConfigKey.TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS: True, - }) + }, + ) print(kernel.get_kernel_source()) profiler = kernel.get_profiler() @@ -285,12 +286,11 @@ def tl_matmul_weight_only_transform( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) @@ -298,10 +298,12 @@ def tl_matmul_weight_only_transform( B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) # Improve L2 Cache T.use_swizzle(panel_size=10) @@ -309,19 +311,15 @@ def tl_matmul_weight_only_transform( T.clear(C_local) for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage): - # Load A into shared memory for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] # Load B into shared memory - for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, - micro_size_y, micro_size_k): - B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, - ko * (block_K // micro_size_k) + k, jj, kk] + for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, micro_size_y, micro_size_k): + B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, ko * 
(block_K // micro_size_k) + k, jj, kk] for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -359,6 +357,7 @@ def tl_matmul_weight_only_transform( def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): import bitblas + matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype) kernel = tilelang.compile(matmul, out_idx=[2]) profiler = kernel.get_profiler() diff --git a/testing/python/language/test_tilelang_capture.py b/testing/python/language/test_tilelang_capture.py index 875fa681bc512db9b8193beba9992ae21f6484c4..47fec999a2531f0f8d82dd5d9145175827e47e6f 100644 --- a/testing/python/language/test_tilelang_capture.py +++ b/testing/python/language/test_tilelang_capture.py @@ -6,16 +6,17 @@ import gc def test_tilelang_capture(): - @tilelang.jit( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - },) + }, + ) def get_dummy_kernel(): - @T.prim_func - def dummy_kernel(a: T.Tensor[(1,), T.float32],): + def dummy_kernel( + a: T.Tensor[(1,), T.float32], + ): with T.Kernel(1) as _: a[0] = 1 @@ -36,5 +37,5 @@ def test_tilelang_capture(): # objgraph.show_backrefs([a_upgrade], max_depth=5) -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_intimm.py b/testing/python/language/test_tilelang_intimm.py index 58fea31d971a85f6d9bcb90a7a94295e19769a20..46c2c79873ebd079aac21cf6f50ea1738f0105f5 100644 --- a/testing/python/language/test_tilelang_intimm.py +++ b/testing/python/language/test_tilelang_intimm.py @@ -4,25 +4,25 @@ import tilelang.language as T def test_tilelang_intimm(): - T.int32(0x7fffffff) - T.int32(-0x7fffffff - 1) - T.uint32(0xffffffff) - T.int64(0x7fffffffffffffff) - T.int64(-0x7fffffffffffffff - 1) - T.uint64(0xffffffffffffffff) + T.int32(0x7FFFFFFF) + T.int32(-0x7FFFFFFF - 1) + T.uint32(0xFFFFFFFF) + T.int64(0x7FFFFFFFFFFFFFFF) + T.int64(-0x7FFFFFFFFFFFFFFF - 1) + T.uint64(0xFFFFFFFFFFFFFFFF) a = T.int32() - a & 0x7fffffff + a & 0x7FFFFFFF a = T.uint32() - a & 0xffffffff + a & 0xFFFFFFFF a = T.int64() - a & 0x7fffffffffffffff + a & 0x7FFFFFFFFFFFFFFF a = T.uint64() - a & T.uint64(0xffffffffffffffff) + a & T.uint64(0xFFFFFFFFFFFFFFFF) -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_alias.py b/testing/python/language/test_tilelang_language_alias.py index c99d36102f57294bad7357ea6845b1015dc97c6a..f55d9e85e594d9f346afcb205f063597233bc559 100644 --- a/testing/python/language/test_tilelang_language_alias.py +++ b/testing/python/language/test_tilelang_language_alias.py @@ -5,12 +5,11 @@ import tilelang.language as T # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): diff --git a/testing/python/language/test_tilelang_language_all_of.py b/testing/python/language/test_tilelang_language_all_of.py index 73233ec871230ed5e5c43389c67b9b8ea08d29b7..48412127b22e24e4c60fe6c87475f1bd873012c7 100644 --- 
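# Quick plain-Python check (not part of the patch) of the integer bounds pinned by
# test_tilelang_intimm above: the hex literals are exactly the two's-complement
# int32/int64 extremes and the full uint32/uint64 masks.
assert 0x7FFFFFFF == 2**31 - 1
assert -0x7FFFFFFF - 1 == -(2**31)
assert 0xFFFFFFFF == 2**32 - 1
assert 0x7FFFFFFFFFFFFFFF == 2**63 - 1
assert -0x7FFFFFFFFFFFFFFF - 1 == -(2**63)
assert 0xFFFFFFFFFFFFFFFF == 2**64 - 1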
a/testing/python/language/test_tilelang_language_all_of.py +++ b/testing/python/language/test_tilelang_language_all_of.py @@ -13,11 +13,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K): accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device) for k in range(K // block_K): if torch.all(BlockMask[i, j, k]): - accu += A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to( - torch.float32) @ B[k * block_K:(k + 1) * block_K, - j * block_N:(j + 1) * block_N].to(torch.float32) - ref_c[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = ( - accu.to(torch.float16)) + accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[ + k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N + ].to(torch.float32) + ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16) return ref_c @@ -35,15 +34,14 @@ def blocksparse_matmul_global( dtype="float16", accum_dtype="float", ): - block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -80,15 +78,14 @@ def blocksparse_matmul_shared( dtype="float16", accum_dtype="float", ): - block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -130,15 +127,14 @@ def blocksparse_matmul_local( dtype="float16", accum_dtype="float", ): - block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -237,7 +233,8 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) block_mask = torch.rand(mask_shape).cuda() > sparsity @@ -284,7 +281,8 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) block_mask = 
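# Illustrative stand-alone version (not part of the patch) of the block-sparse
# reference above: a block of C is accumulated in float32 only when every entry of
# its mask slice is True, which is what the all_of-based kernels are checked against.
import torch

def blocksparse_ref(A, B, mask, bM, bN, bK):
    M, K = A.shape
    N = B.shape[1]
    C = torch.zeros(M, N, dtype=torch.float16, device=A.device)
    for i in range(M // bM):
        for j in range(N // bN):
            acc = torch.zeros(bM, bN, dtype=torch.float32, device=A.device)
            for k in range(K // bK):
                if torch.all(mask[i, j, k]):
                    acc += A[i * bM : (i + 1) * bM, k * bK : (k + 1) * bK].float() @ B[k * bK : (k + 1) * bK, j * bN : (j + 1) * bN].float()
            C[i * bM : (i + 1) * bM, j * bN : (j + 1) * bN] = acc.half()
    return C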
torch.rand(mask_shape).cuda() > sparsity diff --git a/testing/python/language/test_tilelang_language_alloc.py b/testing/python/language/test_tilelang_language_alloc.py index 149a1c285baaea0e9b382a6f415c5f1e43bd475d..6695e9348f5c1815a7db68a73e24c0412c07aaae 100644 --- a/testing/python/language/test_tilelang_language_alloc.py +++ b/testing/python/language/test_tilelang_language_alloc.py @@ -10,8 +10,8 @@ def alloc_var( @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: A_shared = T.alloc_shared([block_N], dtype) @@ -50,8 +50,8 @@ def alloc_var_add( @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: A_shared = T.alloc_shared([block_N], dtype) @@ -91,8 +91,8 @@ def alloc_var_with_initializer( @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: tmp = T.alloc_var(dtype, init_value) @@ -129,8 +129,8 @@ def alloc_multi_vars_with_initializer( @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: tmp0 = T.alloc_var(dtype, 1) diff --git a/testing/python/language/test_tilelang_language_annot.py b/testing/python/language/test_tilelang_language_annot.py index 7425bf5c075ba175822e2b3744617380043467ea..5c9aeeac6b137189655ef21c6740572077b339ba 100644 --- a/testing/python/language/test_tilelang_language_annot.py +++ b/testing/python/language/test_tilelang_language_annot.py @@ -5,13 +5,14 @@ import torch def test_tensor_annot_mul(): - @tilelang.jit def example_tensor_annot(): - n = T.symbolic('n') + n = T.symbolic("n") @T.prim_func - def kernel(A: T.Tensor((n * 4,), T.int32),): + def kernel( + A: T.Tensor((n * 4,), T.int32), + ): with T.Kernel(1) as _: for i in range(n * 4): A[i] = 0 @@ -19,20 +20,21 @@ def test_tensor_annot_mul(): return kernel ker = example_tensor_annot() - A = torch.arange(16, dtype=torch.int32, device='cuda') + A = torch.arange(16, dtype=torch.int32, device="cuda") ker(A) - expected = torch.zeros(16, dtype=torch.int32, device='cuda') + expected = torch.zeros(16, dtype=torch.int32, device="cuda") assert torch.equal(A, expected) def test_tensor_annot_add(): - @tilelang.jit def example_tensor_annot(): - n = T.symbolic('n') + n = T.symbolic("n") @T.prim_func - def kernel(A: T.Tensor((n + 1,), T.int32),): + def kernel( + A: T.Tensor((n + 1,), T.int32), + ): with T.Kernel(1) as _: for i in range(n + 1): A[i] = 0 @@ -40,20 +42,21 @@ def test_tensor_annot_add(): return kernel ker = example_tensor_annot() - A = torch.arange(16, dtype=torch.int32, device='cuda') + A = torch.arange(16, dtype=torch.int32, device="cuda") ker(A) - expected = torch.zeros(16, dtype=torch.int32, device='cuda') + expected = torch.zeros(16, dtype=torch.int32, device="cuda") assert torch.equal(A, expected) def test_tensor_annot_mul_add(): - @tilelang.jit def example_tensor_annot(): - n = T.symbolic('n') + n = T.symbolic("n") @T.prim_func - def kernel(A: T.Tensor((n * 3 + 1,), T.int32),): + def kernel( + A: T.Tensor((n * 3 + 1,), T.int32), + ): with T.Kernel(1) as _: for i in range(n * 3 + 1): A[i] = 0 @@ -61,11 +64,11 @@ def 
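# Hedged sketch (not part of the patch): the alloc_var test bodies above are only
# partially shown, so how the scalar is consumed here (added to each element) is an
# assumption; what the snippet does illustrate is T.alloc_var(dtype, init_value)
# declaring a per-thread scalar with an initial value, as those tests do.
import tilelang.language as T

def add_constant(N, block_N, init_value, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
            tmp = T.alloc_var(dtype, init_value)  # scalar variable, initialized once
            tid = T.get_thread_binding()
            B[bx * block_N + tid] = A[bx * block_N + tid] + tmp

    return main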
test_tensor_annot_mul_add(): return kernel ker = example_tensor_annot() - A = torch.arange(16, dtype=torch.int32, device='cuda') + A = torch.arange(16, dtype=torch.int32, device="cuda") ker(A) - expected = torch.zeros(16, dtype=torch.int32, device='cuda') + expected = torch.zeros(16, dtype=torch.int32, device="cuda") assert torch.equal(A, expected) -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_annotate_safe_value.py b/testing/python/language/test_tilelang_language_annotate_safe_value.py index 3d616ac1e0bc37c2bb0a75bb17229b52f6a2c98f..442172b6f0e101755500b53788768a9f67f65d04 100644 --- a/testing/python/language/test_tilelang_language_annotate_safe_value.py +++ b/testing/python/language/test_tilelang_language_annotate_safe_value.py @@ -7,11 +7,10 @@ import torch # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -30,13 +29,8 @@ def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0): def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16", pad_value=0): program = tilelang_copy(M, N, block_M, block_N, dtype, pad_value=pad_value) kernel = tilelang.compile( - program, - out_idx=[1], - target="cuda", - pass_configs={ - "tl.disable_warp_specialized": True, - "tl.disable_tma_lower": True - }) + program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} + ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) ref_b = torch.zeros_like(a) diff --git a/testing/python/language/test_tilelang_language_any_of.py b/testing/python/language/test_tilelang_language_any_of.py index 354d32cd07cc709b68889437c0a350edf2421ee6..37605e5a03899bddff372a9957531532729dc1ca 100644 --- a/testing/python/language/test_tilelang_language_any_of.py +++ b/testing/python/language/test_tilelang_language_any_of.py @@ -13,11 +13,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K): accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device) for k in range(K // block_K): if torch.any(BlockMask[i, j, k]): - accu += A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to( - torch.float32) @ B[k * block_K:(k + 1) * block_K, - j * block_N:(j + 1) * block_N].to(torch.float32) - ref_c[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = ( - accu.to(torch.float16)) + accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[ + k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N + ].to(torch.float32) + ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16) return ref_c @@ -35,15 +34,14 @@ def blocksparse_matmul_global( dtype="float16", accum_dtype="float", ): - block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: 
T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -80,15 +78,14 @@ def blocksparse_matmul_shared( dtype="float16", accum_dtype="float", ): - block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -130,15 +127,14 @@ def blocksparse_matmul_local( dtype="float16", accum_dtype="float", ): - block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -237,7 +233,8 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) block_mask = torch.rand(mask_shape).cuda() > sparsity @@ -284,7 +281,8 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) block_mask = torch.rand(mask_shape).cuda() > sparsity diff --git a/testing/python/language/test_tilelang_language_assume.py b/testing/python/language/test_tilelang_language_assume.py index 9c75a5ac7fccb1883610a7be92890a127bbe8d5b..32e6b1c317fb8dca33c02136d870af42370ab4ad 100644 --- a/testing/python/language/test_tilelang_language_assume.py +++ b/testing/python/language/test_tilelang_language_assume.py @@ -4,10 +4,9 @@ import tilelang.testing def test_assume_remove_boundary_check(): - @tilelang.jit def kernel_with_assume(): - N = T.dynamic('N') + N = T.dynamic("N") @T.prim_func def main(A: T.Tensor((N,), "float32"), l: T.int32, r: T.int32): @@ -21,20 +20,19 @@ def test_assume_remove_boundary_check(): jit_kernel = kernel_with_assume() source = jit_kernel.get_kernel_source() - assert ("if (" not in source) + assert "if (" not in source def test_assume_enable_vectorization(): - @tilelang.jit def kernel_vectorize(M): - N = T.dynamic('N') + N = T.dynamic("N") vectorize_size = 4 @T.prim_func def main( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), "float32"), + B: T.Tensor((M, N), "float32"), ): with T.Kernel(1, threads=32) as _: tid = T.get_thread_binding() @@ -55,16 +53,15 @@ def test_assume_enable_vectorization(): def test_assume_complex_indexing(): - @tilelang.jit def kernel_complex(): - M = T.dynamic('M') - N = T.dynamic('N') + M = T.dynamic("M") 
+ N = T.dynamic("N") @T.prim_func def main( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), "float32"), + B: T.Tensor((M, N), "float32"), ): with T.Kernel(1, threads=32) as _: tid = T.get_thread_binding() @@ -82,8 +79,8 @@ def test_assume_complex_indexing(): jit_kernel = kernel_complex() source = jit_kernel.get_kernel_source() - assert ("if (" not in source) + assert "if (" not in source -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_atomic_add.py b/testing/python/language/test_tilelang_language_atomic_add.py index b157966a43911e40c596bf1a34744bbbf36c43c7..eaf5ae1ed046ab46183ae54420b2891dedd21799 100644 --- a/testing/python/language/test_tilelang_language_atomic_add.py +++ b/testing/python/language/test_tilelang_language_atomic_add.py @@ -4,14 +4,12 @@ import tilelang.language as T @tilelang.jit def atomic_add_program(K, M, N, block_M, block_N, dtype="float"): - @T.prim_func def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_N), dtype) - T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], - A_shared) + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) for i, j in T.Parallel(block_M, block_N): T.atomic_add(B[bx * block_M + i, by * block_N + j], A_shared[i, j]) @@ -39,14 +37,12 @@ def run_atomic_add(K, M, N, block_M, block_N, dtype="float32"): @tilelang.jit def tile_atomic_add_program(K, M, N, block_M, block_N, dtype="float"): - @T.prim_func def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_N), dtype) - T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], - A_shared) + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) T.atomic_add(B[bx * block_M, by * block_N], A_shared) @@ -76,14 +72,12 @@ def run_tile_atomic_add(K, M, N, block_M, block_N, dtype="float32"): @tilelang.jit def atomic_max_program(K, M, N, block_M, block_N, dtype="float"): - @T.prim_func def atomic_max(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_N), dtype) - T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], - A_shared) + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) for i, j in T.Parallel(block_M, block_N): T.atomic_max(B[bx * block_M + i, by * block_N + j], A_shared[i, j]) @@ -111,14 +105,12 @@ def run_atomic_max(K, M, N, block_M, block_N, dtype="float32"): @tilelang.jit def atomic_min_program(K, M, N, block_M, block_N, dtype="float"): - @T.prim_func def atomic_min(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_N), dtype) - T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], - A_shared) + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) for i, j in T.Parallel(block_M, block_N): T.atomic_min(B[bx * 
block_M + i, by * block_N + j], A_shared[i, j]) @@ -137,7 +129,7 @@ def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"): B[i, j] = min(B[i, j], A[k, i, j]) A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda() - B = torch.full((M, N), float('inf'), dtype=getattr(torch, dtype)).cuda() + B = torch.full((M, N), float("inf"), dtype=getattr(torch, dtype)).cuda() ref_B = B.clone() ref_program(A, ref_B) kernel(A, B) @@ -146,7 +138,6 @@ def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"): @tilelang.jit def atomic_load_store_program(M, N, block_M, block_N, dtype="float"): - @T.prim_func def atomic_load_store(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): @@ -172,18 +163,15 @@ def run_atomic_load_store(M, N, block_M, block_N, dtype="float32"): @tilelang.jit def atomic_memory_order_program(K, M, N, block_M, block_N, dtype="float"): - @T.prim_func def atomic_with_memory_order(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_N), dtype) - T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], - A_shared) + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) for i, j in T.Parallel(block_M, block_N): - T.atomic_add( - B[bx * block_M + i, by * block_N + j], A_shared[i, j], memory_order="relaxed") + T.atomic_add(B[bx * block_M + i, by * block_N + j], A_shared[i, j], memory_order="relaxed") return atomic_with_memory_order @@ -208,7 +196,6 @@ def run_atomic_memory_order(K, M, N, block_M, block_N, dtype="float32"): @tilelang.jit def atomic_addx2_program(M, N, block_M, block_N): - @T.prim_func def atomic_addx2(A: T.Tensor((M, N), "float16"), B: T.Tensor((M, N), "float16")): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): @@ -262,10 +249,10 @@ def test_atomic_addx2(): @tilelang.jit def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float"): - @T.prim_func - def atomic_different_orders(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), C: T.Tensor( - (M, N), dtype), D: T.Tensor((M, N), dtype)): + def atomic_different_orders( + A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype), D: T.Tensor((M, N), dtype) + ): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): for i, j in T.Parallel(block_M, block_N): idx_i = bx * block_M + i @@ -286,18 +273,17 @@ def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype="float32"): A = torch.randn(M, N, dtype=getattr(torch, dtype)).cuda() B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() C = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda() - D = torch.full((M, N), float('inf'), dtype=getattr(torch, dtype)).cuda() + D = torch.full((M, N), float("inf"), dtype=getattr(torch, dtype)).cuda() kernel(A, B, C, D) torch.testing.assert_close(B, A, atol=1e-3, rtol=1e-3) torch.testing.assert_close(C, torch.maximum(torch.zeros_like(A), A)) - torch.testing.assert_close(D, torch.minimum(torch.full_like(A, float('inf')), A)) + torch.testing.assert_close(D, torch.minimum(torch.full_like(A, float("inf")), A)) @tilelang.jit def atomic_addx4_program(M, N, block_M, block_N): - @T.prim_func def atomic_addx4(A: T.Tensor((M, N), "float32"), B: T.Tensor((M, N), "float32")): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, 
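# Illustrative note (not part of the patch): the atomic_min test above seeds B with
# +inf so that the running T.atomic_min over the K grid blocks converges to the
# element-wise minimum over the K axis.
import torch

K, M, N = 4, 8, 8
A = torch.randn(K, M, N)
B = torch.full((M, N), float("inf"))
for k in range(K):  # what the K blocks do via T.atomic_min
    B = torch.minimum(B, A[k])
assert torch.equal(B, A.min(dim=0).values)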
block_N), threads=32) as (bx, by): @@ -330,17 +316,14 @@ def run_atomic_addx4(M, N, block_M, block_N): @tilelang.jit def atomic_return_prev_program(M, N, block_M, block_N, dtype="float"): - @T.prim_func - def atomic_with_return_prev(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), - old_vals: T.Tensor((M, N), dtype)): + def atomic_with_return_prev(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), old_vals: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): for i, j in T.Parallel(block_M, block_N): idx_i = bx * block_M + i idx_j = by * block_N + j if idx_i < M and idx_j < N: - old_vals[idx_i, idx_j] = T.atomic_add( - B[idx_i, idx_j], A[idx_i, idx_j], return_prev=True) + old_vals[idx_i, idx_j] = T.atomic_add(B[idx_i, idx_j], A[idx_i, idx_j], return_prev=True) return atomic_with_return_prev diff --git a/testing/python/language/test_tilelang_language_ceildiv.py b/testing/python/language/test_tilelang_language_ceildiv.py index 35201a074397cede52b8b01b53949a73ca3f4989..66215abc55b70131306c2ad0cec261ce6d55ce08 100644 --- a/testing/python/language/test_tilelang_language_ceildiv.py +++ b/testing/python/language/test_tilelang_language_ceildiv.py @@ -5,7 +5,6 @@ import torch @tilelang.jit(out_idx=[-1]) def _ceildiv_kernel(a: int, b: int): - @T.prim_func def ceildiv_kernel(A: T.Tensor((1,), "int32")): with T.Kernel(1, threads=1) as _: @@ -30,7 +29,6 @@ def test_ceildiv(): @tilelang.jit def _ceildiv_kernel_dyn(b: int): - @T.prim_func def ceildiv_kernel(A: T.Tensor((1,), "int32"), a: T.int32): with T.Kernel(1, threads=1) as _: diff --git a/testing/python/language/test_tilelang_language_chain_equal.py b/testing/python/language/test_tilelang_language_chain_equal.py index 696a9c70b046a11af08be8ebe1d5572be784fe7b..0a9623fa9da8b6c5d1eaf4f0e4f8b9424b5f0a54 100644 --- a/testing/python/language/test_tilelang_language_chain_equal.py +++ b/testing/python/language/test_tilelang_language_chain_equal.py @@ -8,14 +8,14 @@ import torch pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - },) + }, +) def chain_equal(N, block_size, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), - C: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), + C: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as bx: for lane in T.Parallel(block_size): diff --git a/testing/python/language/test_tilelang_language_clamp.py b/testing/python/language/test_tilelang_language_clamp.py index 4a2f177918cea6a34ac1b17a34b133abe1ada0f2..06e558fda610fc57d25ae6746dbe4250abf191c7 100644 --- a/testing/python/language/test_tilelang_language_clamp.py +++ b/testing/python/language/test_tilelang_language_clamp.py @@ -13,8 +13,8 @@ def clamp_within_bounds( @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: A_shared = T.alloc_shared([block_N], dtype) @@ -56,8 +56,8 @@ def clamp_value_range( @T.prim_func def main( - A: T.Tensor((1, N), dtype), - B: T.Tensor((1, N), dtype), + A: T.Tensor((1, N), dtype), + B: T.Tensor((1, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: # A_shared = T.alloc_shared([1, block_N], dtype=dtype) diff --git a/testing/python/language/test_tilelang_language_clear.py 
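# Quick check (not part of the patch) of the ceiling-division identity exercised by
# the _ceildiv_kernel test above: T.ceildiv(a, b) is the usual (a + b - 1) // b.
import math

for a, b in ((1000, 128), (1024, 128), (1, 7)):
    assert (a + b - 1) // b == -(-a // b) == math.ceil(a / b)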
b/testing/python/language/test_tilelang_language_clear.py index be3d808f4d254c0fdc6aeb0fd62fb24e66ed1360..19ae0bbd5340e0147edbb0326fcbfe1cabcead3b 100644 --- a/testing/python/language/test_tilelang_language_clear.py +++ b/testing/python/language/test_tilelang_language_clear.py @@ -5,12 +5,11 @@ import tilelang.language as T # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -42,10 +41,10 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): program = matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype) - kernel = tilelang.compile( - program, out_idx=[2], target="cuda", pass_configs={"tl.disable_tma_lower": True}) + kernel = tilelang.compile(program, out_idx=[2], target="cuda", pass_configs={"tl.disable_tma_lower": True}) import torch from tilelang.utils import map_torch_type + a = torch.randn((M, K), dtype=map_torch_type(dtype)).cuda() b = torch.randn((N, K), dtype=map_torch_type(dtype)).cuda() c = kernel(a, b) diff --git a/testing/python/language/test_tilelang_language_composable_index.py b/testing/python/language/test_tilelang_language_composable_index.py index ac2254f3034f51993c9374803db3dc04a7e7a806..8a586956bbd3e62cd555f3b5a5b56659fb1cb267 100644 --- a/testing/python/language/test_tilelang_language_composable_index.py +++ b/testing/python/language/test_tilelang_language_composable_index.py @@ -7,11 +7,10 @@ import torch # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def tilelang_composable_copy(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M * N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M * N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -35,7 +34,8 @@ def run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128, dtype pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b.flatten(), a.flatten(), rtol=1e-2, atol=1e-2) diff --git a/testing/python/language/test_tilelang_language_copy.py b/testing/python/language/test_tilelang_language_copy.py index 4a2ddee8e4b327d96274fd8eb99c9c7d52ab22e5..367f8ed1dd8e784eb5205cd427d0affb6b288de2 100644 --- a/testing/python/language/test_tilelang_language_copy.py +++ b/testing/python/language/test_tilelang_language_copy.py @@ -7,11 +7,10 @@ import tilelang.testing # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def tilelang_copy(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -27,10 +26,8 @@ def 
run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16") program, out_idx=[1], target="cuda", - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True - }) + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}, + ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) @@ -43,11 +40,10 @@ def test_tilelang_copy(): def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.StridedTensor((M, N), (NN, 1), dtype), - B: T.Tensor((M, N), dtype), + A: T.StridedTensor((M, N), (NN, 1), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -57,12 +53,7 @@ def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"): return main -def run_tilelang_copy_with_stride(M=1024, - N=1024, - NN=2048, - block_M=128, - block_N=128, - dtype="float16"): +def run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128, dtype="float16"): if isinstance(NN, int): assert NN > N, "NN must be greater than N" program = tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype) @@ -73,7 +64,8 @@ def run_tilelang_copy_with_stride(M=1024, pass_configs={ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - }) + }, + ) if isinstance(NN, T.Var): NN = N * 2 a = torch.randn(M, NN, device="cuda", dtype=getattr(torch, dtype)) @@ -87,11 +79,10 @@ def test_tilelang_copy_with_stride(): def tilelang_copy_bufferload(num_tokens, dtype="float16"): - @T.prim_func def main( - indices: T.Tensor((num_tokens,), "int32"), - x: T.Tensor((num_tokens,), dtype), + indices: T.Tensor((num_tokens,), "int32"), + x: T.Tensor((num_tokens,), dtype), ): with T.Kernel(num_tokens, threads=32) as pid: idx = T.alloc_local([1], "int32") @@ -107,10 +98,8 @@ def run_tilelang_copy_bufferload(num_tokens=128, dtype="float16"): tilelang.compile( program, out_idx=[1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True - }) + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}, + ) def test_tilelang_copy_bufferload(): @@ -118,11 +107,10 @@ def test_tilelang_copy_bufferload(): def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -132,20 +120,14 @@ def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float return main -def run_tilelang_copy_buffer_load_with_parallel(M=1024, - N=1024, - block_M=128, - block_N=128, - dtype="float16"): +def run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): program = tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype) kernel = tilelang.compile( program, out_idx=[1], target="cuda", - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True - }) + 
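# Illustrative note (not part of the patch) on T.StridedTensor((M, N), (NN, 1), ...)
# above: the kernel reads an (M, N) logical view whose row stride is NN elements,
# i.e. the left N columns of a wider (M, NN) allocation.  The same view in PyTorch:
import torch

M, N, NN = 1024, 1024, 2048
a = torch.randn(M, NN)
view = a[:, :N]  # logical shape (M, N), strides (NN, 1)
assert view.shape == (M, N) and view.stride() == (NN, 1)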
pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}, + ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) diff --git a/testing/python/language/test_tilelang_language_cumsum.py b/testing/python/language/test_tilelang_language_cumsum.py index 0046405351df44fbd606a7872e877439439eca0e..76982a4e8069873166b56832de7fb7209322c42f 100644 --- a/testing/python/language/test_tilelang_language_cumsum.py +++ b/testing/python/language/test_tilelang_language_cumsum.py @@ -9,8 +9,8 @@ def cumsum_smem_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float3 @T.prim_func def cumsum( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -28,8 +28,8 @@ def cumsum_fragment_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="fl @T.prim_func def cumsum( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -57,13 +57,16 @@ def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32", sc ref_b = torch.empty_like(A) for i in range(M // block_M): for j in range(N // block_N): - ref_b[i * block_M:(i + 1) * block_M, - j * block_N:(j + 1) * block_N] = A[i * block_M:(i + 1) * block_M, j * - block_N:(j + 1) * block_N].cumsum(dim=dim) + ref_b[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = A[ + i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N + ].cumsum(dim=dim) if reverse: - ref_b[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * - block_N] = A[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * - block_N].flip(dims=[dim]).cumsum(dim=dim).flip(dims=[dim]) + ref_b[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = ( + A[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] + .flip(dims=[dim]) + .cumsum(dim=dim) + .flip(dims=[dim]) + ) return ref_b tilelang_res = jit_kernel(A) @@ -76,8 +79,8 @@ def cumsum_smem_test_1d(N, block_N, reverse=False, dtype="float32"): @T.prim_func def cumsum( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: A_shared = T.alloc_shared((block_N,), dtype) @@ -94,8 +97,8 @@ def cumsum_fragment_test_1d(N, block_N, reverse=False, dtype="float32"): @T.prim_func def cumsum( - A: T.Tensor((N,), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: A_shared = T.alloc_shared((block_N,), dtype) diff --git a/testing/python/language/test_tilelang_language_frontend_v2.py b/testing/python/language/test_tilelang_language_frontend_v2.py index 78d38f3a2d9df3b3a871f0c0994b852f62ba5c72..b0191b4d3f1b983570fe4a0d672df33aac60b67f 100644 --- a/testing/python/language/test_tilelang_language_frontend_v2.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -8,7 +8,6 @@ from tvm.tir.expr import IntImm, Var def test_argument(): - @T.prim_func def test_argument( t_1: T.bool, @@ -41,6 +40,7 @@ def test_argument(): def test_expr(): from tilelang.language.v2.dtypes 
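# Quick check (not part of the patch) of the reverse-cumsum identity used by the
# run_cumsum reference above: a suffix cumsum is flip -> cumsum -> flip.
import torch

x = torch.arange(1.0, 6.0)              # [1, 2, 3, 4, 5]
rev = x.flip(0).cumsum(0).flip(0)       # suffix sums
assert torch.equal(rev, torch.tensor([15.0, 14.0, 12.0, 9.0, 5.0]))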
import _all_dtypes + errors = [] for name in _all_dtypes: dtype = getattr(T, name) @@ -116,33 +116,32 @@ def test_expr(): def test_dtype_str_repr(): - @T.prim_func def test_str_repr(): - buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared') # noqa F841 - buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared') # noqa F841 - buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared') # noqa F841 - buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841 - buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared') # noqa F841 - buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared') # noqa F841 - buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841 - buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared') # noqa F841 - buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared') # noqa F841 - buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared') # noqa F841 - buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared') # noqa F841 - buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared') # noqa F841 - buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared') # noqa F841 - buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared') # noqa F841 - buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared') # noqa F841 - buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared') # noqa F841 - buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared') # noqa F841 - buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared') # noqa F841 - buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared') # noqa F841 - buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared') # noqa F841 - buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared') # noqa F841 - buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared') # noqa F841 - buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared') # noqa F841 - buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared') # noqa F841 + buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope="shared") # noqa F841 + buf_2 = T.alloc_buffer((1,), dtype=T.short, scope="shared") # noqa F841 + buf_3 = T.alloc_buffer((1,), dtype=T.int, scope="shared") # noqa F841 + buf_4 = T.alloc_buffer((1,), dtype=T.long, scope="shared") # noqa F841 + buf_5 = T.alloc_buffer((1,), dtype=T.half, scope="shared") # noqa F841 + buf_6 = T.alloc_buffer((1,), dtype=T.float, scope="shared") # noqa F841 + buf_7 = T.alloc_buffer((1,), dtype=T.long, scope="shared") # noqa F841 + buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope="shared") # noqa F841 + buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope="shared") # noqa F841 + buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope="shared") # noqa F841 + buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope="shared") # noqa F841 + buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope="shared") # noqa F841 + buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope="shared") # noqa F841 + buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope="shared") # noqa F841 + buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope="shared") # noqa F841 + buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope="shared") # noqa F841 + buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope="shared") # noqa F841 + buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope="shared") # noqa F841 + buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope="shared") # noqa F841 + buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope="shared") # noqa F841 
+ buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope="shared") # noqa F841 + buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope="shared") # noqa F841 + buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope="shared") # noqa F841 + buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope="shared") # noqa F841 # not supported now @@ -205,7 +204,6 @@ def test_dtype_str_repr(): def test_var_assign(): - @tilelang.jit(out_idx=-1) @T.prim_func def test_var_assign(A: T.Tensor((2,), T.int32)): @@ -223,7 +221,6 @@ def test_var_assign(): def test_marco_return(): - @T.macro def macro_return_constant(): return 0 @@ -258,11 +255,10 @@ def test_marco_return(): def test_prim_func_generator(): - @T.prim_func(generator=True) def prim_func_gen( - A=T.Tensor((128,), T.float32), # noqa: B008 - B=T.Tensor((128,), T.float32), # noqa: B008 + A=T.Tensor((128,), T.float32), # noqa: B008 + B=T.Tensor((128,), T.float32), # noqa: B008 ): with T.Kernel(128) as (tx,): T.copy(A[tx], B[tx]) @@ -277,7 +273,6 @@ def test_prim_func_generator(): def test_serial_for_with_step(): - @tilelang.jit(out_idx=-1) @T.prim_func def test_stepped_serial(A: T.Tensor((10,), T.int32)): @@ -291,7 +286,7 @@ def test_serial_for_with_step(): ker = test_stepped_serial() res = ker() - ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device='cuda') + ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device="cuda") assert torch.all(res == ref), f"Expected {ref}, but got {res}" @tilelang.jit(out_idx=-1) @@ -304,17 +299,16 @@ def test_serial_for_with_step(): ker = test_serial_step_neg() res = ker() - ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device='cuda') + ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device="cuda") assert torch.all(res == ref), f"Expected {ref}, but got {res}" assert isinstance(T.serial(1, 10, 1), IRBuilderFrame) - assert isinstance(T.serial(1, 10, IntImm('int32', 1)), IRBuilderFrame) - assert not isinstance(T.serial(1, 10, Var('tmp', 'int32')), IRBuilderFrame) + assert isinstance(T.serial(1, 10, IntImm("int32", 1)), IRBuilderFrame) + assert not isinstance(T.serial(1, 10, Var("tmp", "int32")), IRBuilderFrame) assert not isinstance(T.serial(10, -1, -1), IRBuilderFrame) def test_swap_logic(): - @tilelang.jit @T.prim_func def swap_var(A: T.Tensor[(2,), T.float32]): @@ -344,7 +338,6 @@ def test_swap_logic(): def test_while_loop(): - @tilelang.jit(out_idx=-1) @T.prim_func def test_while_loop(A: T.Tensor((1,), T.int32)): @@ -374,7 +367,7 @@ def test_var_macro(): x = T.alloc_var(T.int32) macro_with_var(x) - assert 'x[0] = 1' in prim_call_macro.script() + assert "x[0] = 1" in prim_call_macro.script() finally: pass @@ -406,7 +399,7 @@ def test_var_macro(): x = T.alloc_var(T.int32) macro_with_var(x) - assert 'x[0] = 1' in prim_call_macro.script() + assert "x[0] = 1" in prim_call_macro.script() finally: pass @@ -428,10 +421,8 @@ def test_var_macro(): def test_frame_inside_macro(): - @tilelang.jit def get_sample_kernel(): - @T.macro def transform(x): return x + 1 @@ -442,7 +433,7 @@ def test_frame_inside_macro(): idx_out: T.Tensor[(32,), T.int32], ): with T.Kernel(num_blocks, threads=32) as block_idx: # noqa: F841 - fragment = T.alloc_fragment(32, 'int32') + fragment = T.alloc_fragment(32, "int32") T.copy(idx_out, fragment) for i in T.Parallel(32): @@ -467,10 +458,10 @@ def test_buffer_slice_step(): def test_boolop(): - a = Var('a', 'int32') - b = Var('b', 'int32') - c = Var('c', 'int32') - d = Var('d', 'int32') + a = Var("a", "int32") + b = Var("b", 
"int32") + c = Var("c", "int32") + d = Var("d", "int32") @T.macro def cond(): @@ -479,5 +470,5 @@ def test_boolop(): cond() -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_get_warp_info.py b/testing/python/language/test_tilelang_language_get_warp_info.py index 68b65fcd4e5c8f92e8747917111ef968f5791748..edbc511d0d74dfb703c6d42e2b31d3ca133de867 100644 --- a/testing/python/language/test_tilelang_language_get_warp_info.py +++ b/testing/python/language/test_tilelang_language_get_warp_info.py @@ -23,7 +23,6 @@ def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int: @tilelang.jit(out_idx=[-1]) def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None): - @T.prim_func def laneid_kernel(A: T.Tensor((num_threads,), "int32")): with T.Kernel(1, threads=num_threads) as _: @@ -35,7 +34,6 @@ def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None): @tilelang.jit(out_idx=[-1]) def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] = None): - @T.prim_func def warp_idx_sync_kernel(A: T.Tensor((num_threads,), "int32")): with T.Kernel(1, threads=num_threads) as _: @@ -47,7 +45,6 @@ def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] = @tilelang.jit(out_idx=[-1]) def _get_warp_idx_kernel(num_threads: int = 128, warp_size: Optional[int] = None): - @T.prim_func def warp_idx_kernel(A: T.Tensor((num_threads,), "int32")): with T.Kernel(1, threads=num_threads) as _: @@ -63,7 +60,6 @@ def _get_warp_group_idx_kernel( warp_size: Optional[int] = None, warps_per_group: Optional[int] = None, ): - @T.prim_func def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")): with T.Kernel(1, threads=num_threads) as _: @@ -75,7 +71,6 @@ def _get_warp_group_idx_kernel( @tilelang.jit(out_idx=[-1]) def _shuffle_elect_kernel(num_threads: int = 128, thread_extent: int = 64): - @T.prim_func def shuffle_elect_kernel(A: T.Tensor((num_threads,), "int32")): with T.Kernel(1, threads=num_threads) as _: diff --git a/testing/python/language/test_tilelang_language_if_range.py b/testing/python/language/test_tilelang_language_if_range.py index b3550f589b37d827069b9145cbfcdbebe5219bb7..9c98456904fb79c9604d8706bdafaf034cd5a65d 100644 --- a/testing/python/language/test_tilelang_language_if_range.py +++ b/testing/python/language/test_tilelang_language_if_range.py @@ -4,13 +4,14 @@ import torch import tilelang.testing -@tilelang.jit(out_idx=[1],) +@tilelang.jit( + out_idx=[1], +) def tilelang_if_range(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): diff --git a/testing/python/language/test_tilelang_language_infinity.py b/testing/python/language/test_tilelang_language_infinity.py index 0779bff57abf09b0bc4732f76a7a29319e004df2..5d2518661e4d0b15f66c9c5c9caf32b6537a24c6 100644 --- a/testing/python/language/test_tilelang_language_infinity.py +++ b/testing/python/language/test_tilelang_language_infinity.py @@ -5,7 +5,6 @@ import tilelang.language as T @tilelang.jit(out_idx=-1) def get_inf_kernel(dtype: str): - @T.prim_func def main(A: T.Tensor((32,), dtype)): with T.Kernel(1, threads=32): @@ -18,7 +17,7 @@ def _test_infinity(dtype: str): kernel = get_inf_kernel(dtype) output = kernel() - assert 
torch.all(output == torch.inf), f'check failed for {dtype=}' + assert torch.all(output == torch.inf), f"check failed for {dtype=}" @tilelang.testing.requires_cuda diff --git a/testing/python/language/test_tilelang_language_intrinsics_codegen.py b/testing/python/language/test_tilelang_language_intrinsics_codegen.py index f817be26db5582c074d3c106152f8549651d43db..80318242cbb9b9d786d0edb2b563568b7fea790d 100644 --- a/testing/python/language/test_tilelang_language_intrinsics_codegen.py +++ b/testing/python/language/test_tilelang_language_intrinsics_codegen.py @@ -9,8 +9,8 @@ def test_language_ldg_codegen(): @T.prim_func def main( - x: T.Tensor((N,), "float32"), - y: T.Tensor((N,), "float32"), + x: T.Tensor((N,), "float32"), + y: T.Tensor((N,), "float32"), ): with T.Kernel(N, threads=32) as pid: # Explicitly request read-only cache load for x[pid] diff --git a/testing/python/language/test_tilelang_language_lazy_jit.py b/testing/python/language/test_tilelang_language_lazy_jit.py index d3b20c6b97256d914e0c244c3eaa619829807b76..31da09c548f2c586d766005561b48546b1d79604 100644 --- a/testing/python/language/test_tilelang_language_lazy_jit.py +++ b/testing/python/language/test_tilelang_language_lazy_jit.py @@ -8,7 +8,6 @@ import torch def _gemm_impl(): - @T.macro def gemm_impl( A: T.Tensor[[int, int], Any], @@ -37,7 +36,6 @@ def _gemm_impl(): def test_jit2_gemm_annot(): - @tilelang.lazy_jit def gemm( A: T.Tensor[[int, int], Any], @@ -54,24 +52,24 @@ def test_jit2_gemm_annot(): return C prod = product([T.float16, T.float32], [T.float32]) - gemm.par_compile([{ - 'A': T.Tensor((1024, 1024), dtype=in_dtype), - 'B': T.Tensor((1024, 1024), dtype=in_dtype), - 'out_dtype': out_dtype - } for in_dtype, out_dtype in prod]) + gemm.par_compile( + [ + {"A": T.Tensor((1024, 1024), dtype=in_dtype), "B": T.Tensor((1024, 1024), dtype=in_dtype), "out_dtype": out_dtype} + for in_dtype, out_dtype in prod + ] + ) for in_dtype, out_dtype in prod: in_dtype = in_dtype.torch() out_dtype = out_dtype.torch() - A = torch.randn(1024, 1024, dtype=in_dtype, device='cuda') - B = torch.randn(1024, 1024, dtype=in_dtype, device='cuda') + A = torch.randn(1024, 1024, dtype=in_dtype, device="cuda") + B = torch.randn(1024, 1024, dtype=in_dtype, device="cuda") C_ref = out_dtype(A @ B) C = gemm(A, B) torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2) def test_jit2_gemm_ptr(): - @tilelang.lazy_jit def gemm_ptr( A: T.ptr, @@ -92,23 +90,19 @@ def test_jit2_gemm_ptr(): _gemm_impl()(A, B, C, out_dtype, block_M, block_N, block_K) prod = product([T.float16, T.float32], [T.float32]) - gemm_ptr.par_compile([{ - 'A': T.ptr(), - 'B': T.ptr(), - 'C': T.ptr(), - 'M': 1024, - 'N': 1024, - 'K': 1024, - 'dtype': in_dtype, - 'out_dtype': out_dtype - } for in_dtype, out_dtype in prod]) + gemm_ptr.par_compile( + [ + {"A": T.ptr(), "B": T.ptr(), "C": T.ptr(), "M": 1024, "N": 1024, "K": 1024, "dtype": in_dtype, "out_dtype": out_dtype} + for in_dtype, out_dtype in prod + ] + ) for in_dtype, out_dtype in prod: in_dtype = in_dtype.torch() out_dtype = out_dtype.torch() - A = torch.randn(1024, 1024, dtype=in_dtype, device='cuda') - B = torch.randn(1024, 1024, dtype=in_dtype, device='cuda') + A = torch.randn(1024, 1024, dtype=in_dtype, device="cuda") + B = torch.randn(1024, 1024, dtype=in_dtype, device="cuda") C_ref = out_dtype(A @ B) - C = torch.empty(1024, 1024, dtype=out_dtype, device='cuda') + C = torch.empty(1024, 1024, dtype=out_dtype, device="cuda") gemm_ptr(A, B, C, 1024, 1024, 1024, in_dtype, out_dtype) torch.testing.assert_close(C, C_ref, atol=1e-2, 
rtol=1e-2) @@ -129,8 +123,7 @@ def test_jit2_annot(): AnnotTest( annot=T.Tensor[[int, int], T.float32], promote=False, - match_ok=[torch.randn(1, 1, dtype=torch.float32), - T.Tensor((1, 1), dtype=T.float32)], + match_ok=[torch.randn(1, 1, dtype=torch.float32), T.Tensor((1, 1), dtype=T.float32)], match_ng=[ torch.randn(1, 1, dtype=torch.float16), T.Tensor(1, dtype=T.float32), @@ -146,8 +139,8 @@ def test_jit2_annot(): T.Tensor((1,), dtype=T.float32), T.Tensor((1,), dtype=T.float16), ], - match_ng=[torch.randn((1, 1), dtype=torch.float32), - T.Tensor((1, 1), dtype=T.float16)]), + match_ng=[torch.randn((1, 1), dtype=torch.float32), T.Tensor((1, 1), dtype=T.float16)], + ), AnnotTest( annot=T.Tensor[[int, 1], Any], promote=False, @@ -157,8 +150,8 @@ def test_jit2_annot(): T.Tensor((12, 1), T.float32), T.Tensor((12, 1), T.float16), ], - match_ng=[torch.randn(12, 12, dtype=torch.float32), - T.Tensor((12, 12), T.float32)]), + match_ng=[torch.randn(12, 12, dtype=torch.float32), T.Tensor((12, 12), T.float32)], + ), AnnotTest( annot=T.Tensor[[T.dyn, 1], Any], promote=False, @@ -168,43 +161,39 @@ def test_jit2_annot(): T.Tensor((12, 1), T.float32), T.Tensor((12, 1), T.float16), ], - match_ng=[torch.randn(12, 12, dtype=torch.float32), - T.Tensor((12, 12), T.float32)]), + match_ng=[torch.randn(12, 12, dtype=torch.float32), T.Tensor((12, 12), T.float32)], + ), AnnotTest( annot=T.Tensor[[1024, 1024], T.float32], promote=True, ), - AnnotTest(annot=T.dyn[int, 'X'], promote=False, match_ok=[1, 2, 3, 4]), - AnnotTest(annot=T.dyn, promote=False, match_ok=[1, 2, 3, 4]) + AnnotTest(annot=T.dyn[int, "X"], promote=False, match_ok=[1, 2, 3, 4]), + AnnotTest(annot=T.dyn, promote=False, match_ok=[1, 2, 3, 4]), ] for test in tests: promote = test.annot.promote() promoted = promote is not None if promoted != test.promote: - raise AssertionError( - f'Promote mismatch for {test.annot}: expected {test.promote}, got {promoted}') - with Builder().prim_func('_test'): + raise AssertionError(f"Promote mismatch for {test.annot}: expected {test.promote}, got {promoted}") + with Builder().prim_func("_test"): for match_ok in test.match_ok: try: vt = ArgVarTable() - test.annot.create_prim_func_arg('arg', match_ok, vt) + test.annot.create_prim_func_arg("arg", match_ok, vt) except Exception as e: traceback.print_exc() - raise AssertionError( - f'Match failed for {test.annot} with value {match_ok}: {e}') from e + raise AssertionError(f"Match failed for {test.annot} with value {match_ok}: {e}") from e for match_ng in test.match_ng: try: vt = ArgVarTable() - test.annot.create_prim_func_arg('arg', match_ng, vt) - raise AssertionError( - f'Match unexpectedly succeeded for {test.annot} with value {match_ng}') + test.annot.create_prim_func_arg("arg", match_ng, vt) + raise AssertionError(f"Match unexpectedly succeeded for {test.annot} with value {match_ng}") except Exception: pass def test_jit2_many_annot(): - @T.macro def copy_impl(A, B): M, N = A.shape @@ -213,8 +202,7 @@ def test_jit2_many_annot(): assert N == N_, f"N mismatch {N} {N_}" # assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}" with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by): - T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128], B[bx * 128:bx * 128 + 128, - by * 128:by * 128 + 128]) + T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) @tilelang.lazy_jit def copy1( @@ -259,20 +247,19 @@ def test_jit2_many_annot(): copy_impl(A, B) for copy 
in [copy1, copy2, copy3, copy4]: - A = torch.randn(128, 128, device='cuda') - B = torch.empty(128, 128, device='cuda') + A = torch.randn(128, 128, device="cuda") + B = torch.empty(128, 128, device="cuda") copy(A, B) assert torch.equal(B, A) for copy in [copy5, copy6]: - A = torch.randn(128, 2, 128, 2, device='cuda') - B = torch.randn(128, 2, 128, 2, device='cuda') + A = torch.randn(128, 2, 128, 2, device="cuda") + B = torch.randn(128, 2, 128, 2, device="cuda") copy(A[:, 0, :, 0], B[:, 0, :, 0]) assert torch.equal(A[:, 0, :, 0], B[:, 0, :, 0]) def test_jit2_return(): - @T.macro def copy_impl(A): M, N = A.shape @@ -283,8 +270,7 @@ def test_jit2_return(): assert N == N_, f"N mismatch {N} {N_}" # assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}" with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by): - T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128], B[bx * 128:bx * 128 + 128, - by * 128:by * 128 + 128]) + T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) return B @tilelang.lazy_jit @@ -292,41 +278,52 @@ def test_jit2_return(): return copy_impl(A) @tilelang.lazy_jit - def copy1(A: T.Tensor[[int, int], T.float32],): + def copy1( + A: T.Tensor[[int, int], T.float32], + ): return copy_impl(A) @tilelang.lazy_jit - def copy2(A: T.Tensor[[128, 128], T.float32],): + def copy2( + A: T.Tensor[[128, 128], T.float32], + ): return copy_impl(A) @tilelang.lazy_jit - def copy3(A: T.Tensor[[int, 128], T.float32],): + def copy3( + A: T.Tensor[[int, 128], T.float32], + ): return copy_impl(A) @tilelang.lazy_jit - def copy4(A: T.Tensor[[T.dyn, int], T.float32],): + def copy4( + A: T.Tensor[[T.dyn, int], T.float32], + ): return copy_impl(A) @tilelang.lazy_jit - def copy5(A: T.StridedTensor[[int, int], [int, int], T.float32],): + def copy5( + A: T.StridedTensor[[int, int], [int, int], T.float32], + ): return copy_impl(A) @tilelang.lazy_jit - def copy6(A: T.StridedTensor[[T.dyn, int], [int, int], T.float32],): + def copy6( + A: T.StridedTensor[[T.dyn, int], [int, int], T.float32], + ): return copy_impl(A) for copy in [copy0, copy1, copy2, copy3, copy4]: - A = torch.randn(128, 128, device='cuda') + A = torch.randn(128, 128, device="cuda") B = copy(A) assert torch.equal(B, A) for copy in [copy5, copy6]: - A = torch.randn(128, 2, 128, 2, device='cuda') + A = torch.randn(128, 2, 128, 2, device="cuda") B = copy(A[:, 0, :, 0]) assert torch.equal(A[:, 0, :, 0], B) def test_jit2_deepseek_deepgemm(): - @tilelang.lazy_jit def deep_gemm( A: T.Tensor[[int, int], T.float8_e4m3], @@ -351,13 +348,9 @@ def test_jit2_deepseek_deepgemm(): N, K = B.shape C = T.empty(M, N, dtype=out_dtype) - assert out_dtype in [ - T.bfloat16, T.float32 - ], f"Expect out_dtype to be one of [T.float16, T.float32], got {out_dtype}" - assert scales_a.shape == [M, T.ceildiv(K, group_size) - ], f"Expect scales_a shape to be f{[M, T.ceildiv(K, group_size)]}" - assert scales_b.shape == [N, T.ceildiv(K, group_size) - ], f"Expect scales_b shape to be f{[N, T.ceildiv(K, group_size)]}" + assert out_dtype in [T.bfloat16, T.float32], f"Expect out_dtype to be one of [T.float16, T.float32], got {out_dtype}" + assert scales_a.shape == [M, T.ceildiv(K, group_size)], f"Expect scales_a shape to be f{[M, T.ceildiv(K, group_size)]}" + assert scales_b.shape == [N, T.ceildiv(K, group_size)], f"Expect scales_b shape to be f{[N, T.ceildiv(K, group_size)]}" with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, 
by): A_shared = T.alloc_shared((block_M, block_K), in_dtype) @@ -421,5 +414,5 @@ def test_jit2_deepseek_deepgemm(): # M, N, K = 1024, 1024, 8192 # A = torch.randn((M, K), dtype=torch.float8_e4m3fn, ) -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_let.py b/testing/python/language/test_tilelang_language_let.py index a2af09c672fc32d177f9ca10120363c9d9b0da86..a2905952b403f6ba888e6320231f7fee24039fe0 100644 --- a/testing/python/language/test_tilelang_language_let.py +++ b/testing/python/language/test_tilelang_language_let.py @@ -4,7 +4,6 @@ from tilelang import language as T def test_let_vectorize_load(): - @T.prim_func def main(A_ptr: T.handle): A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) diff --git a/testing/python/language/test_tilelang_language_mask_op.py b/testing/python/language/test_tilelang_language_mask_op.py index ad90785f4994fdced27c3b7945fcdf1aa86e88f7..37b5204510aedc1e22163b79e42ba0ad3e0ac9a6 100644 --- a/testing/python/language/test_tilelang_language_mask_op.py +++ b/testing/python/language/test_tilelang_language_mask_op.py @@ -6,11 +6,10 @@ import torch # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -30,13 +29,8 @@ def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"): def run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): program = tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype) kernel = tilelang.compile( - program, - out_idx=[1], - target="cuda", - pass_configs={ - "tl.disable_warp_specialized": True, - "tl.disable_tma_lower": True - }) + program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} + ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) @@ -49,11 +43,10 @@ def test_tilelang_copy_mask_parallel(): # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -72,13 +65,8 @@ def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"): def run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): program = tilelang_copy_mask_copy(M, N, block_M, block_N, dtype) kernel = tilelang.compile( - program, - out_idx=[1], - target="cuda", - pass_configs={ - "tl.disable_warp_specialized": True, - "tl.disable_tma_lower": True - }) + program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} + ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) @@ -91,11 +79,10 @@ def test_tilelang_copy_mask_copy(): # add decorator @tilelang.jit if you want to return a torch 
function # @tilelang.jit def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -112,20 +99,11 @@ def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"): return main -def run_tilelang_copy_mask_parallel_range(M=1024, - N=1024, - block_M=128, - block_N=128, - dtype="float16"): +def run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): program = tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype) kernel = tilelang.compile( - program, - out_idx=[1], - target="cuda", - pass_configs={ - "tl.disable_warp_specialized": True, - "tl.disable_tma_lower": True - }) + program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} + ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) @@ -138,11 +116,10 @@ def test_tilelang_copy_mask_parallel_range(): # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): @@ -161,13 +138,8 @@ def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"): def run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): program = tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype) kernel = tilelang.compile( - program, - out_idx=[1], - target="cuda", - pass_configs={ - "tl.disable_warp_specialized": True, - "tl.disable_tma_lower": True - }) + program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} + ) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) diff --git a/testing/python/language/test_tilelang_language_negative_index.py b/testing/python/language/test_tilelang_language_negative_index.py index 4a0df878b9b828215b44a3fee75b1445b701c59b..c052ccb92c657f244e4a54f8b6b5a00ac6b8d0b1 100644 --- a/testing/python/language/test_tilelang_language_negative_index.py +++ b/testing/python/language/test_tilelang_language_negative_index.py @@ -31,8 +31,7 @@ def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,) @T.prim_func -def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"), - B: T.Buffer((16,), "float32")): +def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): T.func_attr({"tir.noalias": True}) for i in T.serial(16): B[i] = A[shift + i] diff --git a/testing/python/language/test_tilelang_language_parallel.py b/testing/python/language/test_tilelang_language_parallel.py index b51ca8b680ceb87a0d972cfc9e7c848ff04d2c2e..b0e85ff47c8e4568ce7d245fc6dd09d6c0188af7 100644 --- a/testing/python/language/test_tilelang_language_parallel.py +++ b/testing/python/language/test_tilelang_language_parallel.py @@ -9,11 +9,10 @@ 
tilelang.testing.set_random_seed() @tilelang.jit(out_idx=[1]) def parallel_elementwise_static(length=256, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((length,), dtype), - B: T.Tensor((length,), dtype), + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), ): with T.Kernel(1, threads=length) as _: for i in T.Parallel(length): @@ -24,12 +23,11 @@ def parallel_elementwise_static(length=256, dtype="float32"): @tilelang.jit(out_idx=[1]) def parallel_elementwise_dynamic(max_len=512, threads=256, dtype="float32"): - @T.prim_func def main( - A: T.Tensor((max_len,), dtype), - B: T.Tensor((max_len,), dtype), - valid_len: T.int32, + A: T.Tensor((max_len,), dtype), + B: T.Tensor((max_len,), dtype), + valid_len: T.int32, ): with T.Kernel(1, threads=threads) as _: for i in T.Parallel(max_len): diff --git a/testing/python/language/test_tilelang_language_pipeline.py b/testing/python/language/test_tilelang_language_pipeline.py index 212f281ea9bebd1d771133d432acaef84b5f1ef9..54e10550b6603cb8c98dc682815730dd7a5c5606 100644 --- a/testing/python/language/test_tilelang_language_pipeline.py +++ b/testing/python/language/test_tilelang_language_pipeline.py @@ -27,9 +27,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -90,7 +90,8 @@ def run_gemm( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler() def ref_program(A, B): @@ -103,8 +104,8 @@ def run_gemm( if in_dtype == "float32": # Convert float32 to tfloat32 because tfloat32 mma cannot truncate # float32 automatically, -0x1000 meas - A = ((A.view(torch.int32) - 0x1000)).view(torch.float32) - B = ((B.view(torch.int32) - 0x1000)).view(torch.float32) + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.__getattribute__(out_dtype)) return C @@ -124,27 +125,19 @@ def test_pipeline_order_stage(): pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) -def blocksparse_matmul(M, - N, - K, - block_M, - block_N, - block_K, - num_stages, - dtype="float16", - accum_dtype="float"): - + }, +) +def blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, dtype="float16", accum_dtype="float"): block_mask_shape = (M // block_M, N // block_N, K // block_K) import tilelang.language as T @T.prim_func def block_sparse_matmul( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - BlockMask: T.Tensor(block_mask_shape, "bool"), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -183,8 +176,7 @@ def run_blocksparse_matmul(num_stages): a = torch.randn(M, K).cuda().half() b = torch.randn(K, N).cuda().half() - kernel = blocksparse_matmul( - M, N, K, block_M=block_M, block_N=block_N, block_K=block_K, num_stages=num_stages) + kernel = 
blocksparse_matmul(M, N, K, block_M=block_M, block_N=block_N, block_K=block_K, num_stages=num_stages) print(kernel.get_kernel_source()) # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) @@ -200,12 +192,10 @@ def run_blocksparse_matmul(num_stages): accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device) for k in range(K // block_K): if BlockMask[i, j, k]: - accu += ( - A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to( - torch.float32) @ B[k * block_K:(k + 1) * block_K, - j * block_N:(j + 1) * block_N].to(torch.float32)) - ref_c[i * block_M:(i + 1) * block_M, - j * block_N:(j + 1) * block_N] = accu.to(torch.float16) + accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[ + k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N + ].to(torch.float32) + ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16) return ref_c # Compute the reference result using the naive PyTorch implementation diff --git a/testing/python/language/test_tilelang_language_ptr.py b/testing/python/language/test_tilelang_language_ptr.py index e4659ecc580697b36f9c69b93fdd167a873a3b97..0e60ddd7290ea6e19ecda3726057463f0fb80e45 100644 --- a/testing/python/language/test_tilelang_language_ptr.py +++ b/testing/python/language/test_tilelang_language_ptr.py @@ -7,7 +7,6 @@ from tilelang.utils import map_torch_type def matmul_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( a_ptr: T.ptr, diff --git a/testing/python/language/test_tilelang_language_reduce.py b/testing/python/language/test_tilelang_language_reduce.py index cecfaa097c7aee740b395ff933b5f4e2ccd4ad2d..7ec500391462b09c4c5838346f606905628adc04 100644 --- a/testing/python/language/test_tilelang_language_reduce.py +++ b/testing/python/language/test_tilelang_language_reduce.py @@ -10,8 +10,8 @@ def _make_shared_reduce(M, N, dtype, reduce_cb): @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M,), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), ): with T.Kernel(1) as _: A_shared = T.alloc_shared((M, N), dtype) @@ -35,8 +35,8 @@ def reduce_max_test(M, N, dtype="float16"): @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M,), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), ): with T.Kernel(1) as _: A_local = T.alloc_fragment((M, N), dtype) @@ -54,8 +54,8 @@ def reduce_sum_test(M, N, dtype="float32"): @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M,), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), ): with T.Kernel(1) as _: A_local = T.alloc_fragment((M, N), dtype) @@ -145,8 +145,8 @@ def reduce_sum_test_clear(M, N, dtype="float32"): @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M,), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), ): with T.Kernel(1, threads=32) as _: A_local = T.alloc_fragment((M, N), dtype) @@ -186,8 +186,8 @@ def reduce_max_test_clear(M, N, dtype="float16"): @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M,), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), ): with T.Kernel(1, threads=32) as _: A_local = T.alloc_fragment((M, N), dtype) diff --git a/testing/python/language/test_tilelang_language_reshape.py b/testing/python/language/test_tilelang_language_reshape.py index 
60588b4afe66a38a0cbff9bd2e82a0aaa0475bc6..3c343309a7af73470902dd1e71a5b545bb4eac7b 100644 --- a/testing/python/language/test_tilelang_language_reshape.py +++ b/testing/python/language/test_tilelang_language_reshape.py @@ -10,8 +10,8 @@ def reshape_test(N, M, dtype): @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N // M, M), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N // M, M), dtype), ): with T.Kernel(1) as _: A_reshaped = T.reshape(A, [N // M, M]) @@ -30,7 +30,8 @@ def run_reshape(N, M, dtype): pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = jit_kernel.get_profiler() def ref_program(A): @@ -50,8 +51,8 @@ def reshape_test_smem_1d_2_2d(N, M, dtype): @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N // M, M), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N // M, M), dtype), ): with T.Kernel(1) as _: A_shared = T.alloc_shared((N,), dtype) @@ -74,7 +75,8 @@ def run_reshape_smem_1d_2_2d(N, M, dtype): pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = jit_kernel.get_profiler() def ref_program(A): @@ -93,8 +95,8 @@ def reshape_test_smem_2d_2_1d(N, M, dtype): @T.prim_func def main( - A: T.Tensor((N // M, M), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N // M, M), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(1) as _: A_shared = T.alloc_shared((N // M, M), dtype) @@ -117,7 +119,8 @@ def run_reshape_smem_2d_2_1d(N, M, dtype): pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = jit_kernel.get_profiler() def ref_program(A): @@ -136,8 +139,8 @@ def reshape_fragment_test(N, M, dtype): @T.prim_func def main( - A: T.Tensor((N // M, M), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N // M, M), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(1, threads=32) as _: A_shared = T.alloc_shared((N // M, M), dtype, scope="shared") @@ -161,7 +164,8 @@ def run_reshape_fragment(N, M, dtype): pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = jit_kernel.get_profiler() def ref_program(A): @@ -181,15 +185,17 @@ def reshape_layout_transform_shared(N, M, dtype): @T.prim_func def main( - A: T.Tensor((N // M, M), dtype), - B: T.Tensor((N,), dtype), + A: T.Tensor((N // M, M), dtype), + B: T.Tensor((N,), dtype), ): with T.Kernel(1, threads=32) as _: A_shared = T.alloc_shared((N // M, M), dtype, scope="shared") - T.annotate_layout({ - A_shared: make_mma_swizzle_layout(A_shared), - }) + T.annotate_layout( + { + A_shared: make_mma_swizzle_layout(A_shared), + } + ) T.copy(A, A_shared) A_shared_reshape = T.reshape(A_shared, [N]) T.copy(A_shared_reshape, B) @@ -205,7 +211,8 @@ def run_reshape_layout_transform_shared(N, M, dtype): pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = jit_kernel.get_profiler() def ref_program(A): @@ -224,8 +231,8 @@ def reduce_after_reshape_test(N, M, dtype): @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N // M,), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N // M,), dtype), ): with T.Kernel(1, threads=32) as _: A_shared = T.alloc_shared((N,), dtype, scope="shared") @@ -249,7 +256,8 @@ def run_reduce_after_reshape(N, M, 
dtype): pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = jit_kernel.get_profiler() def ref_program(A): @@ -268,8 +276,8 @@ def reshape_shape_mismatch_test(N, M, dtype): @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor((N // M, M), dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor((N // M, M), dtype), ): with T.Kernel(1) as _: A_reshaped = T.reshape(A, [N // M, M + 1]) diff --git a/testing/python/language/test_tilelang_language_ternary.py b/testing/python/language/test_tilelang_language_ternary.py index 821231ab40021fbaa00f84c49ed0e39ca6fc7fd4..632dcf7b46ba78bec268dbe06e834f71085583c1 100644 --- a/testing/python/language/test_tilelang_language_ternary.py +++ b/testing/python/language/test_tilelang_language_ternary.py @@ -4,19 +4,19 @@ import torch import tilelang.testing -@tilelang.jit(out_idx=[1],) +@tilelang.jit( + out_idx=[1], +) def tilelang_ternary(M, N, block_M, block_N, dtype="float16"): - @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - B[by * block_M + i, bx * block_N + j] = ( - A[by * block_M + i, bx * block_N + j] if (by * block_M + i) < (M // 2) else 0) + B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j] if (by * block_M + i) < (M // 2) else 0 return main diff --git a/testing/python/language/test_tilelang_language_tma_1d.py b/testing/python/language/test_tilelang_language_tma_1d.py index efb665ba345b47b5af9c900c9335b6abbdfec751..90022b5ec56674dc0897b5b4f62fb2007a53eb9e 100644 --- a/testing/python/language/test_tilelang_language_tma_1d.py +++ b/testing/python/language/test_tilelang_language_tma_1d.py @@ -9,10 +9,8 @@ def ref_program(x, y): @tilelang.jit(out_idx=[-1]) def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): - @T.prim_func - def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor( - (M, N), out_dtype)): + def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_N), in_dtype) B_shared = T.alloc_shared((block_M, block_N), in_dtype) @@ -21,7 +19,7 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): T.copy(A[by * block_M, bx * block_N], A_shared) T.copy(B[by * block_M, bx * block_N], B_shared) - for (local_y, local_x) in T.Parallel(block_M, block_N): + for local_y, local_x in T.Parallel(block_M, block_N): C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x] T.copy(C_local, C_shared) T.copy(C_shared, C[by * block_M, bx * block_N]) diff --git a/testing/python/language/test_tilelang_language_unroll.py b/testing/python/language/test_tilelang_language_unroll.py index 1796302e36fdff504c78f27b51bcf6835b897b98..416840a1399475b6fcdbcb5f64ab110db6f44572 100644 --- a/testing/python/language/test_tilelang_language_unroll.py +++ b/testing/python/language/test_tilelang_language_unroll.py @@ -4,7 +4,6 @@ from tilelang import language as T def test_unroll_with_step(): - @T.prim_func def main(A_ptr: T.handle): A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) @@ -19,7 +18,6 @@ def 
test_unroll_with_step(): def test_unroll_with_unroll_factor(): - @T.prim_func def main(A_ptr: T.handle): A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) diff --git a/testing/python/language/test_tilelang_language_var_init.py b/testing/python/language/test_tilelang_language_var_init.py index a5a7ddeda85b3f0c0602ec2a4d26931e2f28847c..d4f9062b8f58f61ba6d5bdca7513e4042abeb5a5 100644 --- a/testing/python/language/test_tilelang_language_var_init.py +++ b/testing/python/language/test_tilelang_language_var_init.py @@ -4,17 +4,15 @@ import tilelang.testing def test_var_assign() -> None: - @tilelang.jit(out_idx=-1) def jit_kernel(): - @T.prim_func - def test_var_assign(A: T.Tensor((2,), 'int32')): + def test_var_assign(A: T.Tensor((2,), "int32")): with T.Kernel(1) as _: - a = T.alloc_var('int32', init=1) - b = T.alloc_var('int32', init=a) # b gets value of a + a = T.alloc_var("int32", init=1) + b = T.alloc_var("int32", init=a) # b gets value of a a = 2 - d = T.alloc_var('int32', init=a) # c gets new value of a + d = T.alloc_var("int32", init=a) # c gets new value of a A[0] = b A[1] = d @@ -28,5 +26,5 @@ def test_var_assign() -> None: assert res[1] == 2 -if __name__ == '__main__': +if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_vectorize.py b/testing/python/language/test_tilelang_language_vectorize.py index bc2d3144673e6a1eeaeeedb0f804b6861d0ff84e..6867079c3c81d8674c3b0d43d5c76b46f612d5a9 100644 --- a/testing/python/language/test_tilelang_language_vectorize.py +++ b/testing/python/language/test_tilelang_language_vectorize.py @@ -5,11 +5,10 @@ import tilelang.language as T @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True}) def vectorize_test(N, M, stride_A, stride_B): - @T.prim_func def main( - A: T.StridedTensor[(N, M), (1, stride_A), "float32"], # noqa: F821 - B: T.StridedTensor[(N, M), (1, stride_B), "float32"], # noqa: F821 + A: T.StridedTensor[(N, M), (1, stride_A), "float32"], # noqa: F821 + B: T.StridedTensor[(N, M), (1, stride_B), "float32"], # noqa: F821 ): with T.Kernel(M // 128, threads=128) as (bx): tx = T.get_thread_binding(0) @@ -39,9 +38,7 @@ def run_vectorize(N, M, stride_A, stride_B): code = jit_kernel.get_kernel_source() vectorize_size = 1 - while vectorize_size <= 2 and \ - stride_A % (vectorize_size * 2) == 0 and \ - stride_B % (vectorize_size * 2) == 0: + while vectorize_size <= 2 and stride_A % (vectorize_size * 2) == 0 and stride_B % (vectorize_size * 2) == 0: vectorize_size *= 2 if vectorize_size == 4: @@ -61,12 +58,11 @@ def test_vectorize(): @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True}) def vectorize_test_invariant_index(N, M, K): - @T.prim_func def main( - A: T.Tensor[(N, M), "float32"], # noqa: F821 - B: T.Tensor[(N, M), "float32"], # noqa: F821 - C: T.Tensor[(N, M // K), "float32"], # noqa: F821 + A: T.Tensor[(N, M), "float32"], # noqa: F821 + B: T.Tensor[(N, M), "float32"], # noqa: F821 + C: T.Tensor[(N, M // K), "float32"], # noqa: F821 ): with T.Kernel(N // 128, threads=128) as (bx): tx = T.get_thread_binding(0) diff --git a/testing/python/language/test_tilelang_language_vectorized_cast.py b/testing/python/language/test_tilelang_language_vectorized_cast.py index afb8a05d3e280e351a919a80281cc4a8f2a631f4..adb59a6bd2563f8c67338c3742c083727f565b54 100644 --- a/testing/python/language/test_tilelang_language_vectorized_cast.py +++ b/testing/python/language/test_tilelang_language_vectorized_cast.py @@ -17,8 +17,8 @@ def 
vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str): @T.prim_func def main( - A: T.Tensor[(M,), dtype_A], # noqa: F821 - B: T.Tensor[(M,), dtype_B], # noqa: F821 + A: T.Tensor[(M,), dtype_A], # noqa: F821 + B: T.Tensor[(M,), dtype_B], # noqa: F821 ): with T.Kernel(1, threads=128): T.copy(A, B) @@ -32,8 +32,8 @@ def parallel_vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str): @T.prim_func def main( - A: T.Tensor[(M,), dtype_A], # noqa: F821 - B: T.Tensor[(M,), dtype_B], # noqa: F821 + A: T.Tensor[(M,), dtype_A], # noqa: F821 + B: T.Tensor[(M,), dtype_B], # noqa: F821 ): with T.Kernel(1, threads=128): A_local = T.alloc_fragment((M,), dtype_A) @@ -73,8 +73,7 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str, code = kernel.get_kernel_source() code_parallel = kernel_parallel.get_kernel_source() - assert check_str in code and check_str in code_parallel, \ - f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!" + assert check_str in code and check_str in code_parallel, f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!" def test_vectorized_cast(): diff --git a/testing/python/language/test_tilelang_language_view.py b/testing/python/language/test_tilelang_language_view.py index a79d428bd4936b211d0cca1abe805be3bda019c8..ff050e312900857b89d7a566464be0a50e2cca94 100644 --- a/testing/python/language/test_tilelang_language_view.py +++ b/testing/python/language/test_tilelang_language_view.py @@ -10,6 +10,7 @@ def view_test(N, M, dtype, new_dtype=None): new_shape = [N // M, M] if new_dtype: from tvm import DataType + dtype_src = DataType(dtype) dtype_dst = DataType(new_dtype) src_bits = dtype_src.bits @@ -19,8 +20,8 @@ def view_test(N, M, dtype, new_dtype=None): @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor(new_shape, new_dtype if new_dtype else dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor(new_shape, new_dtype if new_dtype else dtype), ): with T.Kernel(1) as _: A_viewed = T.view(A, new_shape, dtype=new_dtype) @@ -37,6 +38,7 @@ def run_view(N, M, dtype, new_dtype=None): def ref_program(A): if new_dtype: from tilelang.utils.tensor import map_torch_type + torch_dtype = map_torch_type(new_dtype) return A.view(N // M, M).view(dtype=torch_dtype) return A.view(N // M, M) @@ -45,7 +47,6 @@ def run_view(N, M, dtype, new_dtype=None): def test_reshape_view(): - # Test view with same dtype run_view(1024, 32, "float32") run_view(2048, 64, "float16") @@ -61,6 +62,7 @@ def view_shape_mismatch_test(N, M, dtype, new_dtype=None): new_shape = [N // M, M + 1] if new_dtype: from tvm import DataType + dtype_src = DataType(dtype) dtype_dst = DataType(new_dtype) src_bits = dtype_src.bits @@ -70,8 +72,8 @@ def view_shape_mismatch_test(N, M, dtype, new_dtype=None): @T.prim_func def main( - A: T.Tensor((N,), dtype), - B: T.Tensor(new_shape, new_dtype if new_dtype else dtype), + A: T.Tensor((N,), dtype), + B: T.Tensor(new_shape, new_dtype if new_dtype else dtype), ): with T.Kernel(1) as _: A_viewed = T.view(A, new_shape, dtype=new_dtype) diff --git a/testing/python/language/test_tilelang_language_warp_reduce.py b/testing/python/language/test_tilelang_language_warp_reduce.py index 681b23470872ce2f8d558bb746631918afb12dbd..0a0fb70bb22339e4eacb66aea99afee81f406167 100644 --- a/testing/python/language/test_tilelang_language_warp_reduce.py +++ b/testing/python/language/test_tilelang_language_warp_reduce.py @@ -7,7 +7,6 @@ import tilelang.language as T @tilelang.jit def get_kernel(reduce_op: str, dtype: str): - assert 
reduce_op in ["sum", "max", "min", "bitand", "bitor"] @T.prim_func @@ -33,16 +32,16 @@ def get_kernel(reduce_op: str, dtype: str): def test_warp_reduce_sum(): - a = torch.randn((32,), dtype=torch.float32, device='cuda') - kernel = get_kernel('sum', 'float32') + a = torch.randn((32,), dtype=torch.float32, device="cuda") + kernel = get_kernel("sum", "float32") ref = torch.full_like(a, a.sum()) kernel(a) torch.testing.assert_close(a, ref) def test_warp_reduce_max(): - a = torch.randn((32,), dtype=torch.float32, device='cuda') - kernel = get_kernel("max", 'float32') + a = torch.randn((32,), dtype=torch.float32, device="cuda") + kernel = get_kernel("max", "float32") print(kernel.get_kernel_source()) ref = torch.full_like(a, a.max()) kernel(a) @@ -50,16 +49,16 @@ def test_warp_reduce_max(): def test_warp_reduce_min(): - a = torch.randn((32,), dtype=torch.float32, device='cuda') - kernel = get_kernel("min", 'float32') + a = torch.randn((32,), dtype=torch.float32, device="cuda") + kernel = get_kernel("min", "float32") ref = torch.full_like(a, a.min()) kernel(a) torch.testing.assert_close(a, ref) def test_warp_reduce_bitand(): - a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda') - kernel = get_kernel("bitand", 'int32') + a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda") + kernel = get_kernel("bitand", "int32") ref_val = a[0] for i in range(1, a.shape[0]): ref_val = ref_val & a[i] @@ -69,8 +68,8 @@ def test_warp_reduce_bitand(): def test_warp_reduce_bitor(): - a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda') - kernel = get_kernel("bitor", 'int32') + a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda") + kernel = get_kernel("bitor", "int32") ref_val = a[0] for i in range(1, a.shape[0]): ref_val = ref_val | a[i] diff --git a/testing/python/layout/test_tilelang_layout_fused_replicate.py b/testing/python/layout/test_tilelang_layout_fused_replicate.py index d67a87bc3bf165a128891a880fea68d0ae06544c..6d3c26820795762fe31a13566c1d87f89be9b9f2 100644 --- a/testing/python/layout/test_tilelang_layout_fused_replicate.py +++ b/testing/python/layout/test_tilelang_layout_fused_replicate.py @@ -12,17 +12,16 @@ VEC_SIZE = 32 @tilelang.jit def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int): - @T.prim_func def main( - a: T.Buffer((B, M, N), "bfloat16"), - a_out: T.Buffer((B, M, N), "float32"), + a: T.Buffer((B, M, N), "bfloat16"), + a_out: T.Buffer((B, M, N), "float32"), ): with T.Kernel( - T.ceildiv(M, BLOCK_MN), - T.ceildiv(N, BLOCK_K), - B, - threads=128, + T.ceildiv(M, BLOCK_MN), + T.ceildiv(N, BLOCK_K), + B, + threads=128, ) as (pid_m, pid_n, pid_b): a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), "float32") offs_m = pid_m * BLOCK_MN diff --git a/testing/python/math/test_math_bitwise_reduce.py b/testing/python/math/test_math_bitwise_reduce.py index 9c22946692a510d202051b85ef51d721f63d9f3a..8d7f5a1ac15966027b8c2fdba9a2dab6f74ae8b9 100644 --- a/testing/python/math/test_math_bitwise_reduce.py +++ b/testing/python/math/test_math_bitwise_reduce.py @@ -19,12 +19,11 @@ def bitwise_reduce( func, clear=True, ): - @T.prim_func def reduce_func( - A: T.Tensor((M, N), "int32"), - B: T.Tensor((M), "int32"), - Output: T.Tensor((M), "int32"), + A: T.Tensor((M, N), "int32"), + B: T.Tensor((M), "int32"), + Output: T.Tensor((M), "int32"), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_N), "int32") @@ -64,7 +63,7 @@ 
def run_single_bitwise_reduce( row_pattern = (i & 0xF) << (i % 4) # 4-bit patterns shifted by row # Column-based pattern: different bit positions set based on column - col_pattern = (1 << (j % 31)) # Single bit set at different positions + col_pattern = 1 << (j % 31) # Single bit set at different positions # Combine patterns with XOR to create diverse bit distributions # Add some deterministic "noise" based on position @@ -76,7 +75,7 @@ def run_single_bitwise_reduce( if i % 4 == 0: a[i, j] &= ~(0x1 << (i // 4)) elif i % 2 == 0: - a[i, j] |= (0x1 << (i // 2)) + a[i, j] |= 0x1 << (i // 2) if name == "reduce_bitand": expected = torch.full((M,), -1, device="cuda", dtype=torch.int32) diff --git a/testing/python/math/test_math_fast_math.py b/testing/python/math/test_math_fast_math.py index c3b5d1b5288ef09252426869a5816268e991052d..7809983e8eb36d4ba231d4eecba5485b12b6ce0d 100644 --- a/testing/python/math/test_math_fast_math.py +++ b/testing/python/math/test_math_fast_math.py @@ -7,16 +7,16 @@ import re def get_mathop_lines(source, mathop_name): """Extract lines containing the mathop from CUDA source for debugging""" - lines = source.split('\n') + lines = source.split("\n") relevant_lines = [] for i, line in enumerate(lines): - if mathop_name in line and ('(' in line): + if mathop_name in line and ("(" in line): # Include some context start = max(0, i - 1) end = min(len(lines), i + 2) relevant_lines.extend([f"{j}: {lines[j]}" for j in range(start, end)]) relevant_lines.append("---") - return '\n'.join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output + return "\n".join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output def check_fastmath_usage(source, mathop_name, expect_fastmath=False): @@ -27,9 +27,7 @@ def check_fastmath_usage(source, mathop_name, expect_fastmath=False): fastmath_matches = re.findall(fastmath_pattern, source) non_fastmath_matches = re.findall(non_fastmath_pattern, source) - print( - f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls" - ) + print(f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls") if len(fastmath_matches) > 0: print(f"Fastmath calls found: {fastmath_matches}") if len(non_fastmath_matches) > 0: @@ -51,13 +49,7 @@ def check_non_fastmath_usage(source, mathop_name): check_fastmath_usage(source, mathop_name, expect_fastmath=False) -def run_single_arg_mathop_test(mathop_name, - mathop_func, - M=128, - N=128, - block_M=32, - block_N=32, - dtype="float32"): +def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): """ Test single-argument mathops. 
T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath) @@ -65,13 +57,12 @@ def run_single_arg_mathop_test(mathop_name, @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, - bx * block_N + j]) + B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j]) # Test with FAST_MATH disabled kernel_no_fastmath = tilelang.compile( @@ -80,7 +71,8 @@ def run_single_arg_mathop_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - }) + }, + ) source_no_fastmath = kernel_no_fastmath.get_kernel_source() @@ -93,28 +85,22 @@ def run_single_arg_mathop_test(mathop_name, print(f"✓ {mathop_name} compilation and execution test passed") -def run_two_arg_mathop_test(mathop_name, - mathop_func, - M=128, - N=128, - block_M=32, - block_N=32, - dtype="float32"): +def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): """ Test two-argument mathops to ensure they generate non-fastmath CUDA code. """ @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, - bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], - B[by * block_M + i, bx * block_N + j]) + C[by * block_M + i, bx * block_N + j] = mathop_func( + A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j] + ) # Test with FAST_MATH disabled kernel_no_fastmath = tilelang.compile( @@ -123,7 +109,8 @@ def run_two_arg_mathop_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - }) + }, + ) # Test with FAST_MATH enabled kernel_fastmath = tilelang.compile( @@ -132,7 +119,8 @@ def run_two_arg_mathop_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, + ) source_no_fastmath = kernel_no_fastmath.get_kernel_source() source_fastmath = kernel_fastmath.get_kernel_source() @@ -171,8 +159,8 @@ def run_abs_test(): @T.prim_func def main( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), "float32"), + B: T.Tensor((M, N), "float32"), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): @@ -184,7 +172,8 @@ def run_abs_test(): target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - }) + }, + ) source = kernel.get_kernel_source() print("\n=== Testing abs (maps to fabs) ===") @@ -199,26 +188,19 @@ def run_abs_test(): print("✓ abs numerical test passed") -def run_fastmath_mathop_test(mathop_name, - mathop_func, - M=128, - N=128, - block_M=32, - block_N=32, - dtype="float32"): +def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): """ Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix). 
""" @T.prim_func def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, - bx * block_N + j]) + B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j]) # Test with FAST_MATH enabled kernel_fastmath = tilelang.compile( @@ -227,14 +209,15 @@ def run_fastmath_mathop_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }) + }, + ) source_fastmath = kernel_fastmath.get_kernel_source() print(f"\n=== Testing {mathop_name} (fastmath version) ===") print("FAST_MATH=True:") # Strip the __ prefix for checking in the CUDA source - cuda_mathop_name = mathop_name.lstrip('_') + cuda_mathop_name = mathop_name.lstrip("_") check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True) # Test numerical correctness diff --git a/testing/python/math/test_math_ieee_math.py b/testing/python/math/test_math_ieee_math.py index 0b04e3bab1b8962b0eca2dbaa8033bc706c64ead..193092ec771a7ae294cf4c2c7293fba114fa1bef 100644 --- a/testing/python/math/test_math_ieee_math.py +++ b/testing/python/math/test_math_ieee_math.py @@ -5,14 +5,7 @@ import tilelang.testing import pytest -def run_ieee_math_test(mathop_name, - mathop_func, - rounding_mode="rn", - M=128, - N=128, - block_M=32, - block_N=32, - dtype="float32"): +def run_ieee_math_test(mathop_name, mathop_func, rounding_mode="rn", M=128, N=128, block_M=32, block_N=32, dtype="float32"): """ Test IEEE-compliant math operations with specified rounding modes. """ @@ -22,18 +15,19 @@ def run_ieee_math_test(mathop_name, @T.prim_func def main_func( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), - C: T.Tensor((M, N), dtype), - D: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), + D: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - D[by * block_M + i, - bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], - B[by * block_M + i, bx * block_N + j], - C[by * block_M + i, - bx * block_N + j], rounding_mode) + D[by * block_M + i, bx * block_N + j] = mathop_func( + A[by * block_M + i, bx * block_N + j], + B[by * block_M + i, bx * block_N + j], + C[by * block_M + i, bx * block_N + j], + rounding_mode, + ) out_idx = [3] num_inputs = 3 @@ -41,16 +35,15 @@ def run_ieee_math_test(mathop_name, @T.prim_func def main_func( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, - bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], - B[by * block_M + i, - bx * block_N + j], rounding_mode) + C[by * block_M + i, bx * block_N + j] = mathop_func( + A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j], rounding_mode + ) out_idx = [2] num_inputs = 2 @@ -58,14 +51,12 @@ def run_ieee_math_test(mathop_name, @T.prim_func def main_func( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M, N), dtype), + A: T.Tensor((M, N), dtype), + 
B: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): - B[by * block_M + i, - bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], - rounding_mode) + B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], rounding_mode) out_idx = [1] num_inputs = 1 @@ -77,7 +68,8 @@ def run_ieee_math_test(mathop_name, target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - }) + }, + ) print(f"\n=== Testing {mathop_name} with rounding mode {rounding_mode} ===") print(f"✓ {mathop_name} compilation test passed") @@ -194,8 +186,8 @@ def test_ieee_frsqrt_rn_only(): @T.prim_func def main( - A: T.Tensor((128, 128), "float32"), - B: T.Tensor((128, 128), "float32"), + A: T.Tensor((128, 128), "float32"), + B: T.Tensor((128, 128), "float32"), ): with T.Kernel(T.ceildiv(128, 32), T.ceildiv(128, 32), threads=128) as (bx, by): for i, j in T.Parallel(32, 32): @@ -207,7 +199,8 @@ def test_ieee_frsqrt_rn_only(): target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, - }) + }, + ) print("\n=== Testing ieee_frsqrt (rn only) ===") print("✓ ieee_frsqrt compilation test passed") diff --git a/testing/python/metal/test_metal_codegen.py b/testing/python/metal/test_metal_codegen.py index 22f4beb89b8fe26a58d3f467504e6e84ae935196..ea088aea9c85468383637588e5dcd9b572ed525c 100644 --- a/testing/python/metal/test_metal_codegen.py +++ b/testing/python/metal/test_metal_codegen.py @@ -5,18 +5,17 @@ import tilelang.language as T import torch -@tilelang.jit(execution_backend='torch') +@tilelang.jit(execution_backend="torch") def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"): - @T.prim_func def gemm( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype, scope='shared') - B_shared = T.alloc_shared((block_K, block_N), dtype, scope='shared') + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared") + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared") C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) @@ -48,13 +47,13 @@ def assert_gemm( torch_dtype = getattr(torch, dtype) a, b = None, None - if 'int' in dtype: - a = torch.randint(100, (M, K), dtype=torch_dtype, device='mps') - b = torch.randint(100, (K, N), dtype=torch_dtype, device='mps') + if "int" in dtype: + a = torch.randint(100, (M, K), dtype=torch_dtype, device="mps") + b = torch.randint(100, (K, N), dtype=torch_dtype, device="mps") else: - a = torch.randn(M, K, dtype=torch_dtype, device='mps') - b = torch.randn(K, N, dtype=torch_dtype, device='mps') - c = torch.zeros(M, N, dtype=torch_dtype, device='mps') + a = torch.randn(M, K, dtype=torch_dtype, device="mps") + b = torch.randn(K, N, dtype=torch_dtype, device="mps") + c = torch.zeros(M, N, dtype=torch_dtype, device="mps") jit_kernel(a, b, c) @@ -70,12 +69,12 @@ def test_gemm_float32(): @tilelang.testing.requires_metal def test_gemm_float16(): - assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype='float16', atol=1) + assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="float16", atol=1) @tilelang.testing.requires_metal def test_gemm_int32(): - assert_gemm(1024, 1024, 1024, 16, 16, 
16, dtype='int32', atol=1) + assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="int32", atol=1) if __name__ == "__main__": diff --git a/testing/python/primitives/test_tilelang_primitives_mma.py b/testing/python/primitives/test_tilelang_primitives_mma.py index fcda9878cd6a15ee046cef5db64d27299c416156..97ce323158c73a4bd8ec096ee1c9bd5ab19bf756 100644 --- a/testing/python/primitives/test_tilelang_primitives_mma.py +++ b/testing/python/primitives/test_tilelang_primitives_mma.py @@ -27,9 +27,9 @@ def matmul_ssr( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) @@ -88,7 +88,8 @@ def run_matmul_ssr( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler() def ref_program(A, B): @@ -106,24 +107,9 @@ def run_matmul_ssr( def test_gemm_f16f16f16_nt_ssr(): - run_matmul_ssr( - 16, 16, 16, False, True, "float16", "float16", "float16", 16, 16, 16, 0, num_threads=32) - run_matmul_ssr( - 128, 128, 128, False, True, "float16", "float16", "float16", 32, 32, 32, 0, num_threads=64) - run_matmul_ssr( - 1024, - 1024, - 1024, - False, - True, - "float16", - "float16", - "float16", - 128, - 128, - 32, - 2, - num_threads=128) + run_matmul_ssr(16, 16, 16, False, True, "float16", "float16", "float16", 16, 16, 16, 0, num_threads=32) + run_matmul_ssr(128, 128, 128, False, True, "float16", "float16", "float16", 32, 32, 32, 0, num_threads=64) + run_matmul_ssr(1024, 1024, 1024, False, True, "float16", "float16", "float16", 128, 128, 32, 2, num_threads=128) def matmul_rsr( @@ -151,9 +137,9 @@ def matmul_rsr( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) @@ -214,7 +200,8 @@ def run_matmul_rsr( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler() def ref_program(A, B): @@ -276,9 +263,9 @@ def matmul_rrr( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -342,7 +329,8 @@ def run_matmul_rrr( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler() def ref_program(A, B): diff --git a/testing/python/profiler/test_tilelang_profiler.py b/testing/python/profiler/test_tilelang_profiler.py index ee46725b9b64568940dc43a386342241aafd1a27..8aa54708445ed722bdb4a5f2a03dac1554bea09c 100644 --- a/testing/python/profiler/test_tilelang_profiler.py +++ b/testing/python/profiler/test_tilelang_profiler.py @@ -4,12 
+4,11 @@ import tilelang.language as T @tilelang.jit(out_idx=[-1]) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def gemm( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py index d984ad4bcdd910db47b708e8a71f2ef626d7e7b6..a13e4533e88349f208561a67968fce5800585209 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py @@ -27,9 +27,9 @@ def matmul( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -89,7 +89,8 @@ def run_gemm_ss( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) @@ -159,9 +160,9 @@ def matmul_rs( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -169,9 +170,11 @@ def matmul_rs( A_frag = T.alloc_fragment(A_frag_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - }) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + } + ) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): if trans_A: T.copy(A[k * block_K, by * block_M], A_shared) @@ -225,7 +228,8 @@ def run_gemm_rs( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) def ref_program(A, B): @@ -294,9 +298,9 @@ def matmul_sr( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -304,9 +308,11 @@ def matmul_sr( B_frag = T.alloc_fragment(B_frag_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) - T.annotate_layout({ - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): if trans_A: T.copy(A[k * block_K, by * block_M], 
A_shared) @@ -360,7 +366,8 @@ def run_gemm_sr( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) def ref_program(A, B): @@ -430,9 +437,9 @@ def matmul_rr( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -441,10 +448,12 @@ def matmul_rr( B_frag = T.alloc_fragment(B_frag_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - }) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): if trans_A: T.copy(A[k * block_K, by * block_M], A_shared) @@ -499,7 +508,8 @@ def run_gemm_rr( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) def ref_program(A, B): diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py index cefe986a098a1372864a099238f54522c728e43d..4ced4f83773dfee2e0beef0cf85947cad6099389 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py @@ -20,27 +20,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype): low, high = (0, 4) if is_unsigned else (-2, 2) else: low, high = (0, 128) if is_unsigned else (-64, 64) - A = randint_semi_sparse( - M, - K, - low=low, - high=high, - dtype=map_torch_type(in_dtype), - device='cuda', - transposed=trans_A) - B = torch.randint( - size=(N, K) if trans_B else (K, N), - low=low, - high=high, - dtype=map_torch_type(in_dtype), - device='cuda') + A = randint_semi_sparse(M, K, low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A) + B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda") else: - A = randn_semi_sparse( - M, K, dtype=torch.float32, device='cuda', - transposed=trans_A).to(map_torch_type(in_dtype)) - B = torch.randn( - (N, K) if trans_B else (K, N), device='cuda', - dtype=torch.float32).to(map_torch_type(in_dtype)) + A = randn_semi_sparse(M, K, dtype=torch.float32, device="cuda", transposed=trans_A).to(map_torch_type(in_dtype)) + B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(map_torch_type(in_dtype)) return A, B @@ -69,24 +53,22 @@ def matmul_sp_sm90( @T.prim_func def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), 'uint8'), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), "uint8"), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, 
block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) - E_shared = T.alloc_shared((block_M, block_K // E_factor), 'uint8') + E_shared = T.alloc_shared((block_M, block_K // E_factor), "uint8") C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout({ - E: - make_cutlass_metadata_layout( - E, mma_dtype=in_dtype, arch="9.0", block_k=block_K), - E_shared: - make_cutlass_metadata_layout( - E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="9.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K), + } + ) T.disable_warp_group_reg_alloc() T.clear(C_frag) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): @@ -121,7 +103,7 @@ def matmul_sp_sm80( trans_B, ): is_8_bit = "8" in in_dtype - metadata_dtype = 'int32' if is_8_bit else 'int16' + metadata_dtype = "int32" if is_8_bit else "int16" E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) B_shape = (K, N) if not trans_B else (N, K) @@ -132,20 +114,22 @@ def matmul_sp_sm80( @T.prim_func def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), metadata_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout({ - E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + } + ) T.clear(C_frag) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(E[by * block_M, k * block_K // E_factor], E_shared) @@ -216,7 +200,7 @@ def run_gemm_sp( C = _matmul(A, B) - if 'float8' in in_dtype: + if "float8" in in_dtype: diff = calc_diff(C_sp, C) assert diff < 1e-3, f"{diff=}" else: @@ -332,15 +316,11 @@ def test_gemm_sp_sm90(): run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 0, 128) run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 2, 128) - run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, - True) - run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, - False) - run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, - True) + run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True) + run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, False) + run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, True) - 
run_gemm_sp_sm90(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False, - True) + run_gemm_sp_sm90(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False, True) run_gemm_sp_sm90(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True) @@ -352,12 +332,9 @@ def test_gemm_sp_sm80(): run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32) run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128) - run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False, - True) - run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, - True) - run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, - True) + run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False, True) + run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, True) + run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True) run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 1, 128) run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128) diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py index a82c29f388ae0fb577272d7ff9f612920e5e5a67..276bce4d9730b3e2be6b22b4d3d55be075954942 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py @@ -34,20 +34,22 @@ def matmul( @T.prim_func def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), metadata_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout({ - E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), - }) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + } + ) T.clear(C_frag) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(E[by * block_M, k * block_K // E_factor], E_shared) @@ -80,7 +82,7 @@ def run_gemm_ss( num_stages=3, num_threads=128, ): - metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16' + metadata_dtype = "int32" if ("8" in in_dtype) else "int16" program = matmul( M, N, @@ -105,7 +107,8 @@ def run_gemm_ss( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") @@ -142,26 +145,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype): low, high = (0, 4) 
if is_unsigned else (-2, 2) else: low, high = (0, 128) if is_unsigned else (-64, 64) - A = randint_semi_sparse( - M, - K, - low=low, - high=high, - dtype=map_torch_type(in_dtype), - device='cuda', - transposed=trans_A) - B = torch.randint( - size=(N, K) if trans_B else (K, N), - low=low, - high=high, - dtype=map_torch_type(in_dtype), - device='cuda') + A = randint_semi_sparse(M, K, low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A) + B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda") else: - A = randn_semi_sparse( - M, K, dtype=map_torch_type(in_dtype), device='cuda', transposed=trans_A) - B = torch.randn( - (N, K) if trans_B else (K, N), device='cuda', - dtype=torch.float32).to(map_torch_type(in_dtype)) + A = randn_semi_sparse(M, K, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A) + B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(map_torch_type(in_dtype)) return A, B @@ -184,8 +172,7 @@ def test_gemm_ss(): run_gemm_ss(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2) # float8 tests - run_gemm_ss(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, - 2) + run_gemm_ss(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2) run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2) # tfloat32 test @@ -222,10 +209,10 @@ def matmul_rs( @T.prim_func def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), metadata_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -233,11 +220,13 @@ def matmul_rs( E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) A_frag = T.alloc_fragment(A_frag_shape, in_dtype) C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), - }) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + } + ) T.clear(C_frag) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(E[by * block_M, k * block_K // E_factor], E_shared) @@ -271,7 +260,7 @@ def run_gemm_rs( num_stages=3, num_threads=128, ): - metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16' + metadata_dtype = "int32" if ("8" in in_dtype) else "int16" program = matmul_rs( M, N, @@ -296,7 +285,8 @@ def run_gemm_rs( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") C_sp = kernel(A_sparse, E, B) @@ -376,10 +366,10 @@ def matmul_sr( @T.prim_func def main( - A_sparse: T.Tensor(A_sparse_shape, 
in_dtype), - E: T.Tensor((M, K // E_factor), metadata_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -387,11 +377,13 @@ def matmul_sr( E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) B_frag = T.alloc_fragment(B_frag_shape, in_dtype) C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout({ - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), - }) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + } + ) T.clear(C_frag) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(E[by * block_M, k * block_K // E_factor], E_shared) @@ -425,7 +417,7 @@ def run_gemm_sr( num_stages=3, num_threads=128, ): - metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16' + metadata_dtype = "int32" if ("8" in in_dtype) else "int16" program = matmul_sr( M, N, @@ -450,7 +442,8 @@ def run_gemm_sr( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") C_sp = kernel(A_sparse, E, B) @@ -531,10 +524,10 @@ def matmul_rr( @T.prim_func def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), metadata_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) @@ -543,12 +536,14 @@ def matmul_rr( A_frag = T.alloc_fragment(A_frag_shape, in_dtype) B_frag = T.alloc_fragment(B_frag_shape, in_dtype) C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout({ - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), - }) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + } + ) T.clear(C_frag) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(E[by * block_M, k * block_K // E_factor], E_shared) @@ -583,7 +578,7 @@ def run_gemm_rr( num_stages=3, num_threads=128, ): - metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16' + metadata_dtype = "int32" if ("8" in in_dtype) else "int16" program = 
matmul_rr( M, N, @@ -608,7 +603,8 @@ def run_gemm_rr( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) + }, + ) A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") C_sp = kernel(A_sparse, E, B) diff --git a/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py b/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py index 7cb1b55174688b36ae06e5f97c91175ae22debc0..d3f45c5ebef3ae93972c155b4d5548715864b542 100644 --- a/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py +++ b/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py @@ -11,22 +11,14 @@ def _check(original, transformed): mod = tl.transform.Simplify()(mod) mod = tl.transform.LowerOpaqueBlock()(mod) mod = tl.transform.Simplify()(mod) - tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), - True) + tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), True) def test_trival_pipeline(): - @T.prim_func def before(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")): for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - for i in T.serial( - 0, - 1, - annotations={ - "software_pipeline_stage": [0, 1], - "software_pipeline_order": [0, 1] - }): + for i in T.serial(0, 1, annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]}): with T.block(): T.reads(A[tx, i]) T.writes(C[tx, i]) diff --git a/testing/python/transform/test_tilelang_transform_cluster_planning.py b/testing/python/transform/test_tilelang_transform_cluster_planning.py index 8029305aec94cd4f787940323b626fdd953193ba..2ec6321e8619a6edb54be35f2a60d40e2434a96e 100644 --- a/testing/python/transform/test_tilelang_transform_cluster_planning.py +++ b/testing/python/transform/test_tilelang_transform_cluster_planning.py @@ -21,10 +21,8 @@ def _check(original, transformed): def test_cluster_planning(): - @T.prim_func - def before(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor( - (1024, 1024), "float16")): + def before(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor((1024, 1024), "float16")): with T.Kernel(8, 8, threads=128) as (bx, by): A_shared = T.alloc_shared((128, 32), "float16") B_shared = T.alloc_shared((32, 128), "float16") @@ -41,8 +39,7 @@ def test_cluster_planning(): T.copy(C_local, C[by * 128, bx * 128]) @T.prim_func - def after(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor( - (1024, 1024), "float16")): + def after(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor((1024, 1024), "float16")): T.func_attr({"clusterIdx.y": T.int32(2)}) with T.Kernel(8, 8, threads=128) as (bx, by): A_shared = T.alloc_shared((128, 32), "float16") diff --git a/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py b/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py index 1ef1589a7f8df6055d46119346e6c161f036cdcd..339b283e0628afc0cad3191cf690d7cd0b02196e 100644 --- a/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py +++ b/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py @@ -9,7 +9,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_N = 64 num_stages = 0 
threads = 128 - scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) batch = T.int32(batch) heads = T.int32(heads) @@ -24,7 +24,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_mask_dtype = "bool" def kernel_func(block_M, block_N, num_stages, threads): - @T.macro def MMA0( K: T.Tensor(shape, dtype), @@ -36,37 +35,36 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) by: T.int32, bz: T.int32, ): - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @T.macro def MMA1( - V: T.Tensor(shape, dtype), - V_shared: T.Tensor([block_M, dim], dtype), - acc_s_cast: T.Tensor([block_M, block_N], dtype), - acc_o: T.Tensor([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, + V: T.Tensor(shape, dtype), + V_shared: T.Tensor([block_M, dim], dtype), + acc_s_cast: T.Tensor([block_M, block_N], dtype), + acc_o: T.Tensor([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) @T.macro def Softmax( - acc_s: T.Tensor([block_M, block_N], accum_dtype), - acc_s_cast: T.Tensor([block_M, block_N], dtype), - scores_max: T.Tensor([block_M], accum_dtype), - scores_max_prev: T.Tensor([block_M], accum_dtype), - scores_scale: T.Tensor([block_M], accum_dtype), - scores_sum: T.Tensor([block_M], accum_dtype), - logsum: T.Tensor([block_M], accum_dtype), + acc_s: T.Tensor([block_M, block_N], accum_dtype), + acc_s_cast: T.Tensor([block_M, block_N], dtype), + scores_max: T.Tensor([block_M], accum_dtype), + scores_max_prev: T.Tensor([block_M], accum_dtype), + scores_scale: T.Tensor([block_M], accum_dtype), + scores_sum: T.Tensor([block_M], accum_dtype), + logsum: T.Tensor([block_M], accum_dtype), ): T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -92,22 +90,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) @T.macro def Rescale( - acc_o: T.Tensor([block_M, dim], accum_dtype), - scores_scale: T.Tensor([block_M], accum_dtype), + acc_o: T.Tensor([block_M, dim], accum_dtype), + scores_scale: T.Tensor([block_M], accum_dtype), ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] @T.prim_func def main( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), - Output: T.Tensor(shape, dtype), + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(shape, dtype), ): - with T.Kernel( - T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = 
T.alloc_shared([block_N, dim], dtype) @@ -122,7 +119,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) logsum = T.alloc_fragment([block_M], accum_dtype) block_mask = T.alloc_local([downsample_len], block_mask_dtype) - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -131,19 +128,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) block_mask[vj] = BlockSparseMask[bz, by, bx, vj] loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) for k in T.Pipelined(loop_range, num_stages=num_stages): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) return main diff --git a/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py b/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py index 2859821cafa694067c9b13ac92862a5df0c9996b..854a2617280c1f0e46b6346d45d17ef5fa2aa15b 100644 --- a/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py +++ b/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py @@ -22,7 +22,6 @@ def _check(original, transformed): def test_lower_fence_proxy(): - @T.prim_func def before(): with T.Kernel(8): @@ -30,12 +29,15 @@ def test_lower_fence_proxy(): B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn") C_local = T.decl_buffer((32,), scope="local") for i in T.unroll(16): - C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2) - T.call_intrin("handle", tir.op.Op.get("tl.tl_gemm"), - "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>", - T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3)) + C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2) + T.call_intrin( + "handle", + tir.op.Op.get("tl.tl_gemm"), + "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>", + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + ) @T.prim_func def after(): @@ -44,19 +46,21 @@ def test_lower_fence_proxy(): B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn") C_local = T.decl_buffer((32,), scope="local") for i in T.unroll(16): - C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2) + C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2) T.fence_proxy_async() - T.call_intrin("handle", tir.op.Op.get("tl.tl_gemm"), - "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>", - 
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3)) + T.call_intrin( + "handle", + tir.op.Op.get("tl.tl_gemm"), + "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>", + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + ) _check(before, after) def test_async_to_generic_no_double_fence(): - @T.prim_func def before(): with T.Kernel(8): @@ -90,7 +94,6 @@ def test_async_to_generic_no_double_fence(): def test_proxy_hint_override(): - @T.prim_func def before(): with T.Kernel(8): @@ -123,7 +126,6 @@ def test_proxy_hint_override(): def test_tma_store_sync_injection(): - @T.prim_func def before(): with T.Kernel(8): @@ -154,7 +156,6 @@ def test_tma_store_sync_injection(): def test_wgmma_marked_async(): - @T.prim_func def before(): with T.Kernel(1): @@ -164,9 +165,24 @@ def test_wgmma_marked_async(): C_local = T.decl_buffer((32,), "float16", scope="local") A_shared[0] = T.float16(0) T.warpgroup_arrive() - T.ptx_wgmma_ss("float16", "m64n64k16", T.bool(True), T.bool(True), "fp16", "fp16", - "fp16", desc_a.data, T.int32(0), desc_b.data, T.int32(0), C_local.data, - T.int32(0), T.bool(True), 1, 1) + T.ptx_wgmma_ss( + "float16", + "m64n64k16", + T.bool(True), + T.bool(True), + "fp16", + "fp16", + "fp16", + desc_a.data, + T.int32(0), + desc_b.data, + T.int32(0), + C_local.data, + T.int32(0), + T.bool(True), + 1, + 1, + ) mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) mod = tvm.tir.transform.BindTarget(auto_target)(mod) diff --git a/testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py b/testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py index 95cbf2db5d4eebea0294c7851229547fe60b42ab..0cc79b92fd8f33b5050eba9f59902b5eee484d61 100644 --- a/testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py +++ b/testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py @@ -35,26 +35,25 @@ def test_inject_set_max_nreg(): T.mbarrier_wait_parity(T.get_mbarrier(k % 3 + 3), T.bitwise_xor(k // 3 % 2, 1)) if v - 128 == 0: T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, - 0, 2, 2, 0), T.get_mbarrier(k % 3), - T.tvm_access_ptr( - T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), - k * 32, by * 64) - T.evaluate( - tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)])) + T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), + T.get_mbarrier(k % 3), + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), + k * 32, + by * 64, + ) + T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)])) else: # Consumer branch - should have set_max_nreg(240, 1) for k in range(16): T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2) T.call_extern( - "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", - T.tvm_access_ptr( - T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr( - T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3)) - T.evaluate( - tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)])) + 
"handle", + "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + ) + T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)])) # Apply the InjectSetMaxNReg pass func = before @@ -67,15 +66,18 @@ def test_inject_set_max_nreg(): set_max_nreg_calls = [] def collect_set_max_nreg(stmt): - if (isinstance(stmt, tvm.tir.Evaluate) and hasattr(stmt.value, 'op') and - hasattr(stmt.value.op, 'name') and stmt.value.op.name == "tl.set_max_nreg"): + if ( + isinstance(stmt, tvm.tir.Evaluate) + and hasattr(stmt.value, "op") + and hasattr(stmt.value.op, "name") + and stmt.value.op.name == "tl.set_max_nreg" + ): set_max_nreg_calls.append(stmt.value) tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg) # We should have at least 2 set_max_nreg calls (one for producer, one for consumer) - assert len(set_max_nreg_calls - ) >= 2, f"Expected at least 2 set_max_nreg calls, got {len(set_max_nreg_calls)}" + assert len(set_max_nreg_calls) >= 2, f"Expected at least 2 set_max_nreg calls, got {len(set_max_nreg_calls)}" print("InjectSetMaxNReg test passed!") @@ -116,16 +118,18 @@ def test_inject_set_max_nreg_no_set_max_nreg(): set_max_nreg_calls = [] def collect_set_max_nreg(stmt): - if (isinstance(stmt, tvm.tir.Evaluate) and hasattr(stmt.value, 'op') and - hasattr(stmt.value.op, 'name') and stmt.value.op.name == "tl.set_max_nreg"): + if ( + isinstance(stmt, tvm.tir.Evaluate) + and hasattr(stmt.value, "op") + and hasattr(stmt.value.op, "name") + and stmt.value.op.name == "tl.set_max_nreg" + ): set_max_nreg_calls.append(stmt.value) tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg) # Should have no set_max_nreg calls when no_set_max_nreg is present - assert len( - set_max_nreg_calls - ) == 0, f"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got {len(set_max_nreg_calls)}" + assert len(set_max_nreg_calls) == 0, f"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got {len(set_max_nreg_calls)}" print("InjectSetMaxNReg with no_set_max_nreg test passed!") diff --git a/testing/python/transform/test_tilelang_transform_layout_inference.py b/testing/python/transform/test_tilelang_transform_layout_inference.py index 66415aacb950db5059d15c48aead5788019ccfe5..270dd31ee772272b93ba2963cca5e9c0af3f0bf4 100644 --- a/testing/python/transform/test_tilelang_transform_layout_inference.py +++ b/testing/python/transform/test_tilelang_transform_layout_inference.py @@ -8,17 +8,21 @@ import pytest auto_target = tvm.target.Target(determine_target("auto")) -@pytest.mark.parametrize("block_M, block_N, block_K, threads, vec_load_b, dtype", [ - (64, 64, 32, 128, 8, "float16"), -]) +@pytest.mark.parametrize( + "block_M, block_N, block_K, threads, vec_load_b, dtype", + [ + (64, 64, 32, 128, 8, "float16"), + ], +) def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): N = tvm.te.var("n") K = tvm.te.var("k") def before(): - @T.prim_func - def main(B: T.Tensor((K, N), dtype),): + def main( + B: T.Tensor((K, N), dtype), + ): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared((block_K, block_N), dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") @@ -26,58 +30,62 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, 
dtype): t = thread_bindings for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)): for vec in T.Parallel(vec_load_b): - B_shared[i * (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b), t % (block_N // vec_load_b) * - (block_N // vec_load_b) + vec] = T.if_then_else( - k * block_K + i * (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b) < K and bx * block_N + t % - (block_N // vec_load_b) * (block_N // vec_load_b) < N, - B[k * block_K + i * (threads * vec_load_b // block_N) + - t // (block_N // vec_load_b), bx * block_N + t % - (block_N // vec_load_b) * (block_N // vec_load_b) + vec], - T.float16(0)) + B_shared[ + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ] = T.if_then_else( + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K + and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N, + B[ + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ], + T.float16(0), + ) - return tvm.IRModule({'main': main}) + return tvm.IRModule({"main": main}) def after(): - @T.prim_func - def main(B: T.Tensor((K, N), dtype),): + def main( + B: T.Tensor((K, N), dtype), + ): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared((block_K, block_N), dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): t = thread_bindings for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)): - if (k * block_K + i * (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b)) * N % vec_load_b == 0: + if (k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b)) * N % vec_load_b == 0: for vec in T.vectorized(vec_load_b): - B_shared[i * (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b), t % (block_N // vec_load_b) * - (block_N // vec_load_b) + vec] = T.if_then_else( - k * block_K + i * - (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b) < K and bx * block_N + t % - (block_N // vec_load_b) * (block_N // vec_load_b) < N, - B[k * block_K + i * (threads * vec_load_b // block_N) + - t // (block_N // vec_load_b), - bx * block_N + t % (block_N // vec_load_b) * - (block_N // vec_load_b) + vec], T.float16(0)) + B_shared[ + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ] = T.if_then_else( + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K + and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N, + B[ + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ], + T.float16(0), + ) else: for vec in T.serial(vec_load_b): - B_shared[i * (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b), t % (block_N // vec_load_b) * - (block_N // vec_load_b) + vec] = T.if_then_else( - k * block_K + i * - (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b) < K and bx * block_N + t % - (block_N // vec_load_b) * (block_N // vec_load_b) < N, - B[k * block_K + i * (threads * vec_load_b // block_N) + - t // (block_N // vec_load_b), - bx * block_N + t % (block_N // vec_load_b) * - (block_N // 
vec_load_b) + vec], T.float16(0)) + B_shared[ + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ] = T.if_then_else( + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K + and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N, + B[ + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ], + T.float16(0), + ) - return tvm.IRModule({'main': main}) + return tvm.IRModule({"main": main}) with tvm.target.Target(auto_target): mod = tvm.tir.transform.BindTarget(auto_target)(before()) diff --git a/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py index 5202ab6475f7a384278d1249fb6b23553a480f65..35a85aaf0b41ad9f58b6402e2a9efdbc299a60a9 100644 --- a/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py +++ b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py @@ -8,7 +8,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off dtype = "float32" @T.prim_func - def main(A: T.Tensor((M, N), dtype=dtype),): + def main( + A: T.Tensor((M, N), dtype=dtype), + ): with T.Kernel(1, 1, threads=M) as (bx, by): A_shared = T.alloc_shared((M, N), dtype=dtype) tid = T.get_thread_binding() @@ -16,17 +18,18 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off A_shared[tid, j] = A[tid + M_offset, j + N_offset] @T.prim_func - def expected(A: T.Tensor((M, N), dtype=dtype),): + def expected( + A: T.Tensor((M, N), dtype=dtype), + ): with T.Kernel(1, 1, threads=M) as (bx, by): A_shared = T.alloc_shared((M, N), dtype=dtype) tid = T.get_thread_binding() - T.reads(A[tid + M_offset, N_offset:N + N_offset]) + T.reads(A[tid + M_offset, N_offset : N + N_offset]) for j in T.serial(N): A_shared[tid, j] = T.if_then_else( - j + N_offset < N, - T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset], - T.float32(0)), T.float32(0)) + j + N_offset < N, T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset], T.float32(0)), T.float32(0) + ) return main, expected @@ -41,13 +44,13 @@ def assert_vectorize_access(M: int = 64, N: int = 64): def issue_1013_buggy_kernel(): # NOTE: This kernel is mainly to test some corner cases in boundary check - num_tokens = T.dynamic('num_tokens') + num_tokens = T.dynamic("num_tokens") num_threads = 128 @T.prim_func def main(x: T.Tensor((num_tokens,), dtype="int64")): with T.Kernel(1, threads=num_threads) as _: - count = T.alloc_var('int') + count = T.alloc_var("int") thread_idx = T.get_thread_binding() for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)): idx = thread_idx + i * num_threads @@ -59,24 +62,22 @@ def issue_1013_buggy_kernel(): @T.prim_func def expected(x: T.Tensor((num_tokens,), dtype="int64")): with T.Kernel(1, threads=num_threads) as _: - count = T.alloc_var('int') + count = T.alloc_var("int") thread_idx = T.get_thread_binding() for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)): idx = thread_idx + i * num_threads - count += T.Cast("int32", - T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2)) + count += T.Cast("int32", T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2)) return main, expected -def 
vectorize_access_with_atmoic_add_legalize(M: int = 64, - N: int = 64, - M_offset: int = 2, - N_offset: int = 2): +def vectorize_access_with_atmoic_add_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2): dtype = "float32" @T.prim_func - def main(A: T.Tensor((M, N), dtype=dtype),): + def main( + A: T.Tensor((M, N), dtype=dtype), + ): with T.Kernel(1, 1, threads=M) as (bx, by): A_shared = T.alloc_shared((M, N), dtype=dtype) tid = T.get_thread_binding() @@ -85,17 +86,18 @@ def vectorize_access_with_atmoic_add_legalize(M: int = 64, T.atomic_add(A[tid + M_offset, j + N_offset], 1) @T.prim_func - def expected(A: T.Tensor((M, N), dtype=dtype),): + def expected( + A: T.Tensor((M, N), dtype=dtype), + ): with T.Kernel(1, 1, threads=M) as (bx, by): A_shared = T.alloc_shared((M, N), dtype=dtype) tid = T.get_thread_binding() - T.reads(A[tid + M_offset, N_offset:N + N_offset]) + T.reads(A[tid + M_offset, N_offset : N + N_offset]) for j in T.serial(N): A_shared[tid, j] = T.if_then_else( - j + N_offset < N, - T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset], - T.float32(0)), T.float32(0)) + j + N_offset < N, T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset], T.float32(0)), T.float32(0) + ) # Nest if-then-else is expected, do not flatten it to pass structural equal check if j + N_offset < N: # noqa: SIM102 if tid + M_offset < M: @@ -115,17 +117,21 @@ def oob_store_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: in dtype = "float32" @T.prim_func - def main(A: T.Tensor((M, N), dtype=dtype),): + def main( + A: T.Tensor((M, N), dtype=dtype), + ): with T.Kernel(1, 1, threads=M) as (bx, by): tid = T.get_thread_binding() for j in T.serial(N): A[tid + M_offset, j + N_offset] = 1 @T.prim_func - def expected(A: T.Tensor((M, N), dtype=dtype),): + def expected( + A: T.Tensor((M, N), dtype=dtype), + ): with T.Kernel(1, 1, threads=M) as (bx, by): tid = T.get_thread_binding() - T.writes(A[tid + M_offset, N_offset:N + N_offset]) + T.writes(A[tid + M_offset, N_offset : N + N_offset]) for j in T.serial(N): if j + N_offset < N: # noqa: SIM102 if tid + M_offset < M: diff --git a/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py b/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py index c95af8777a0ef411b4631f8e44bc0ae50c6213ab..ec570d4181c46a524f35668a0c0ee6cba1f6fc5a 100644 --- a/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py +++ b/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py @@ -9,7 +9,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64): vec_len = 8 @T.prim_func - def main(A: T.Tensor((M, N, vec_len), dtype="float32"),): + def main( + A: T.Tensor((M, N, vec_len), dtype="float32"), + ): with T.Kernel(1, 1, threads=M) as (bx, by): A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype) tid = T.get_thread_binding() @@ -18,7 +20,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64): A_shared[tid, j, v] = A[tid, j, v] @T.prim_func - def expected(A: T.Tensor((M, N, vec_len), dtype="float32"),): + def expected( + A: T.Tensor((M, N, vec_len), dtype="float32"), + ): with T.Kernel(1, 1, threads=M) as (bx, by): A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype) tid = T.get_thread_binding() diff --git a/testing/python/transform/test_tilelang_transform_let_inline.py b/testing/python/transform/test_tilelang_transform_let_inline.py index aa2638af10e678e5ec1f1d8c0f52ce4fa76df4af..6603ecab3bead44e76a85c24bef27ab20bab9fc1 100644 --- 
a/testing/python/transform/test_tilelang_transform_let_inline.py +++ b/testing/python/transform/test_tilelang_transform_let_inline.py @@ -8,12 +8,10 @@ def _check(original, transformed): func = original mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tl.transform.LetInline()(mod) - tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), - True) + tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), True) def test_let_binding(): - @T.prim_func def before(A: T.Tensor((128, 128), "float32"), B: T.Tensor((128, 128), "float32")): for i in range(128): @@ -34,7 +32,6 @@ def test_let_binding(): def test_parallel_scope(): - @T.prim_func def before(A: T.Tensor((128,), "float32")): for i in T.Parallel(128): diff --git a/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py b/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py index ca5042e0f4f45f683fa349bfd0005821a84d0fdb..f411b3d5b5779ec5b069deb5c6d9da4d1ba9ee26 100644 --- a/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py +++ b/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py @@ -24,7 +24,6 @@ def _check(original, transformed): def test_lower_hopper_intrin_barrier(): - @T.prim_func def before(): with T.Kernel(8): @@ -37,18 +36,10 @@ def test_lower_hopper_intrin_barrier(): v_1 = T.launch_thread("threadIdx.x", 128) T.evaluate(tir.Call("handle", "tir.create_barriers", [4])) with T.If(v_1 == 0), T.Then(): - T.evaluate( - tir.Call("handle", "tir.ptx_init_barrier_thread_count", - [T.get_mbarrier(0), 128])) - T.evaluate( - tir.Call("handle", "tir.ptx_init_barrier_thread_count", - [T.get_mbarrier(1), 128])) - T.evaluate( - tir.Call("handle", "tir.ptx_init_barrier_thread_count", - [T.get_mbarrier(2), 128])) - T.evaluate( - tir.Call("handle", "tir.ptx_init_barrier_thread_count", - [T.get_mbarrier(3), 128])) + T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(0), 128])) + T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(1), 128])) + T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(2), 128])) + T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(3), 128])) T.evaluate(tir.Call("handle", "tir.tvm_storage_sync", ["shared"])) _check(before, after) diff --git a/testing/python/transform/test_tilelang_transform_lower_tile_op.py b/testing/python/transform/test_tilelang_transform_lower_tile_op.py index 07dbd53f1f2162100cfadcea64676f2b9600a5a2..ac58418597890ad631ea59b63c7b0038f1970563 100644 --- a/testing/python/transform/test_tilelang_transform_lower_tile_op.py +++ b/testing/python/transform/test_tilelang_transform_lower_tile_op.py @@ -8,63 +8,69 @@ import pytest auto_target = tvm.target.Target(determine_target("auto")) -@pytest.mark.parametrize("block_M, block_N, block_K, threads, vec_load_b, dtype", [ - (64, 64, 32, 128, 8, "float16"), -]) +@pytest.mark.parametrize( + "block_M, block_N, block_K, threads, vec_load_b, dtype", + [ + (64, 64, 32, 128, 8, "float16"), + ], +) def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): N = tvm.te.var("n") K = tvm.te.var("k") def before(): - @T.prim_func - def main(B: T.Tensor((K, N), dtype),): + def main( + B: T.Tensor((K, N), dtype), + ): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared((block_K, block_N), dtype) for k in T.Pipelined(T.ceildiv(K, 
block_K), num_stages=3): T.copy(B[k * block_K, bx * block_N], B_shared) - return tvm.IRModule({'main': main}) + return tvm.IRModule({"main": main}) def after(): - @T.prim_func - def main(B: T.Tensor((K, N), dtype),): + def main( + B: T.Tensor((K, N), dtype), + ): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared((block_K, block_N), dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): t = thread_bindings for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)): - if (k * block_K + i * (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b)) * N % vec_load_b == 0: + if (k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b)) * N % vec_load_b == 0: for vec in T.vectorized(vec_load_b): - B_shared[i * (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b), t % (block_N // vec_load_b) * - (block_N // vec_load_b) + vec] = T.if_then_else( - k * block_K + i * - (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b) < K and bx * block_N + t % - (block_N // vec_load_b) * (block_N // vec_load_b) < N, - B[k * block_K + i * (threads * vec_load_b // block_N) + - t // (block_N // vec_load_b), - bx * block_N + t % (block_N // vec_load_b) * - (block_N // vec_load_b) + vec], T.float16(0)) + B_shared[ + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ] = T.if_then_else( + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K + and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N, + B[ + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ], + T.float16(0), + ) else: for vec in T.serial(vec_load_b): - B_shared[i * (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b), t % (block_N // vec_load_b) * - (block_N // vec_load_b) + vec] = T.if_then_else( - k * block_K + i * - (threads * vec_load_b // block_N) + t // - (block_N // vec_load_b) < K and bx * block_N + t % - (block_N // vec_load_b) * (block_N // vec_load_b) < N, - B[k * block_K + i * (threads * vec_load_b // block_N) + - t // (block_N // vec_load_b), - bx * block_N + t % (block_N // vec_load_b) * - (block_N // vec_load_b) + vec], T.float16(0)) + B_shared[ + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ] = T.if_then_else( + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K + and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N, + B[ + k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b), + bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec, + ], + T.float16(0), + ) - return tvm.IRModule({'main': main}) + return tvm.IRModule({"main": main}) with tvm.transform.PassContext(): mod = tvm.tir.transform.BindTarget(auto_target)(before()) diff --git a/testing/python/transform/test_tilelang_transform_make_packed_api.py b/testing/python/transform/test_tilelang_transform_make_packed_api.py index ff448732601cb38c0c2c4339c6539a5116ed5c94..2508a9d12e4c844152322246fbd5f72c57acf4f2 100644 --- a/testing/python/transform/test_tilelang_transform_make_packed_api.py +++ 
b/testing/python/transform/test_tilelang_transform_make_packed_api.py @@ -80,7 +80,6 @@ def test_target_host_removed(): @I.ir_module class before: - @T.prim_func def main(A: T.Buffer(1, "float32")): T.func_attr({"global_symbol": "main", "target": T.target("cuda", host=host)}) @@ -102,7 +101,6 @@ def test_internal_subroutine_call(): @I.ir_module class before: - @T.prim_func def main(A: T.Buffer(1, "float32")): T.func_attr({"target": T.target("llvm", host="llvm")}) @@ -121,7 +119,8 @@ def test_internal_subroutine_call(): subroutine_call_op = compute_scope.body.value.op assert isinstance(subroutine_call_op, tvm.ir.GlobalVar), ( f"The main function's CallNode should use the subroutine's GLobalVar as the operation, " - f"but instead has an operation of type {subroutine_call_op}") + f"but instead has an operation of type {subroutine_call_op}" + ) def test_subroutine_call_to_externally_visible_subroutine(): @@ -135,7 +134,6 @@ def test_subroutine_call_to_externally_visible_subroutine(): @I.ir_module class before: - @T.prim_func def main(A: T.Buffer(1, "float32")): T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")}) @@ -154,11 +152,10 @@ def test_subroutine_call_to_externally_visible_subroutine(): assert subroutine_compute_scope is not None subroutine_call_op = main_compute_scope.body.value.op - assert ( - isinstance(subroutine_call_op, tvm.ir.Op) and - subroutine_call_op.name == "tir.tvm_call_cpacked" - ), (f"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', " - f"but instead has an operation of type {subroutine_call_op}") + assert isinstance(subroutine_call_op, tvm.ir.Op) and subroutine_call_op.name == "tir.tvm_call_cpacked", ( + f"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', " + f"but instead has an operation of type {subroutine_call_op}" + ) @tilelang.testing.requires_llvm @@ -167,10 +164,10 @@ def test_function_call_with_wrong_argument_count(): @T.prim_func def func( - A: T.Buffer([16, 16], "int32"), - B: T.Buffer([16, 16], "int32"), - C: T.Buffer([16, 16], "int32"), - D: T.Buffer([16, 16], "int32"), + A: T.Buffer([16, 16], "int32"), + B: T.Buffer([16, 16], "int32"), + C: T.Buffer([16, 16], "int32"), + D: T.Buffer([16, 16], "int32"), ): pass diff --git a/testing/python/transform/test_tilelang_transform_multi_version_buffer.py b/testing/python/transform/test_tilelang_transform_multi_version_buffer.py index ddb7f66627f6077daaf8674c8fbe6346b939d532..0d56ab1a8d4fef4596bd039fa4d989f7b7ca5ae8 100644 --- a/testing/python/transform/test_tilelang_transform_multi_version_buffer.py +++ b/testing/python/transform/test_tilelang_transform_multi_version_buffer.py @@ -31,7 +31,6 @@ block_K = 32 def test_multi_version_buffer(): - @T.prim_func def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): bx = T.launch_thread("blockIdx.x", 8) @@ -49,21 +48,27 @@ def test_multi_version_buffer(): for k in T.serial(16, annotations={"num_stages": T.int32(3)}): if v == 0: T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, - 2, 0), 0, + T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), + 0, T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 2), - k * 32, by * 64) + k * 32, + by * 64, + ) if v == 0: T.tma_load( - T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, - 2, 0), 0, + T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), + 0, 
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 2), - bx * 64, k * 32) + bx * 64, + k * 32, + ) T.call_extern( - "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + "handle", + "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3)) + T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + ) @T.prim_func def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): @@ -82,31 +87,32 @@ def test_multi_version_buffer(): for k in T.serial(16, annotations={"num_stages": T.int32(3)}): if v == 0: T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, - 2, 0), 0, - T.tvm_access_ptr( - T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), - k * 32, by * 64) + T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), + 0, + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), + k * 32, + by * 64, + ) if v == 0: T.tma_load( - T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, - 2, 0), 0, - T.tvm_access_ptr( - T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), - bx * 64, k * 32) + T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), + 0, + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), + bx * 64, + k * 32, + ) T.call_extern( - "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", - T.tvm_access_ptr( - T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr( - T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3)) + "handle", + "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + ) _check(before, after) def test_multi_version_buffer_with_let(): - @T.prim_func def before(scales: T.Tensor((4,), "float32")): with T.block("root"): diff --git a/testing/python/transform/test_tilelang_transform_pipeline_planning.py b/testing/python/transform/test_tilelang_transform_pipeline_planning.py index b7448a2045e516e68549cbd9ac33e1cd2884a651..f38d6079e5699d8e896e04c21a8e33f1056316e6 100644 --- a/testing/python/transform/test_tilelang_transform_pipeline_planning.py +++ b/testing/python/transform/test_tilelang_transform_pipeline_planning.py @@ -19,10 +19,8 @@ def _check(original, transformed): def test_simple_pipeline(): - @T.prim_func - def before(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor( - (1024, 1024), "float32")): + def before(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor((1024, 1024), "float32")): with T.Kernel(8, 8, threads=128) as (bx, by): A_shared = T.alloc_shared((128, 32), "float32") B_shared = T.alloc_shared((32, 128), "float32") @@ -39,8 +37,7 @@ def test_simple_pipeline(): T.copy(C_local, C[by * 128, bx * 128]) @T.prim_func - def after(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor( - (1024, 1024), "float32")): + def after(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 
1024), "float32"), C: T.Tensor((1024, 1024), "float32")): with T.Kernel(8, 8, threads=128) as (bx, by): A_shared = T.alloc_shared((128, 32), "float32") B_shared = T.alloc_shared((32, 128), "float32") @@ -49,14 +46,13 @@ def test_simple_pipeline(): T.clear(C_local) for ko in T.serial( - 32, - annotations={ - "software_pipeline_async_stages": [T.int32(0)], - "software_pipeline_order": [T.int32(0), T.int32(1), - T.int32(2)], - "software_pipeline_stage": [T.int32(3), T.int32(3), - T.int32(3)] - }): + 32, + annotations={ + "software_pipeline_async_stages": [T.int32(0)], + "software_pipeline_order": [T.int32(0), T.int32(1), T.int32(2)], + "software_pipeline_stage": [T.int32(3), T.int32(3), T.int32(3)], + }, + ): T.copy(A[by * 128, ko * 32], A_shared) T.copy(B[ko * 32, bx * 128], B_shared) T.gemm(A_shared, B_shared, C_local) diff --git a/testing/python/transform/test_tilelang_transform_simplify.py b/testing/python/transform/test_tilelang_transform_simplify.py index e1f4f9469ac8c5cee11e1578456871c60d0ebb00..657a2e8a401c2cdd25fb30e519fce2af68a5155e 100644 --- a/testing/python/transform/test_tilelang_transform_simplify.py +++ b/testing/python/transform/test_tilelang_transform_simplify.py @@ -8,14 +8,13 @@ def modify( with_B: bool = False, with_bias: bool = False, ): - @T.prim_func def main( - A: T.Tensor((64, 64)), - B: T.Tensor((64, 64)), - C: T.Tensor((64, 64)), - D: T.Tensor((64, 64)), - bias: T.Tensor((64, 64)), + A: T.Tensor((64, 64)), + B: T.Tensor((64, 64)), + C: T.Tensor((64, 64)), + D: T.Tensor((64, 64)), + bias: T.Tensor((64, 64)), ): if with_B: if with_bias: @@ -42,7 +41,6 @@ def test_modify(with_B=False, with_bias=False): def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def main( a: T.handle, @@ -76,6 +74,7 @@ def test_matmul(): kernel = tl.compile(mod["main"], out_idx=[2]) import torch + a = torch.randn(1024, 1024, dtype=torch.float16).cuda().half() b = torch.randn(1024, 1024, dtype=torch.float16).cuda().half() c = kernel(a, b) diff --git a/testing/python/transform/test_tilelang_transform_thread_sync.py b/testing/python/transform/test_tilelang_transform_thread_sync.py index c0b705567262e54779ea61fd43b8d9e3b9dcab44..046ed447a49d0250ff29faac7ecd13c99e6be0ce 100644 --- a/testing/python/transform/test_tilelang_transform_thread_sync.py +++ b/testing/python/transform/test_tilelang_transform_thread_sync.py @@ -11,11 +11,7 @@ def run_passes(func: tvm.tir.PrimFunc): cuda_target = tvm.target.Target("cuda", host="llvm") - mod = tvm.tir.transform.Apply(lambda f: f.with_attr({ - "global_symbol": "test", - "target": cuda_target - }))( - mod) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}))(mod) mod = tvm.tir.transform.AnnotateDeviceRegions()(mod) mod = tvm.tir.transform.SplitHostDevice()(mod) @@ -24,7 +20,6 @@ def run_passes(func: tvm.tir.PrimFunc): @tilelang.testing.requires_cuda def test_sync_if_with_same_index(): - @T.prim_func(check_well_formed=False) def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) -> None: threadIdx_x = T.env_thread("threadIdx.x") @@ -47,7 +42,6 @@ def test_sync_if_with_same_index(): @tilelang.testing.requires_cuda def test_sync_read_thread_id_independent_location(): - @T.prim_func def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) -> None: threadIdx_x = T.env_thread("threadIdx.x") @@ -71,7 +65,6 @@ def test_sync_read_thread_id_independent_location(): @tilelang.testing.requires_cuda def test_sync_shared(): - 
@T.prim_func(private=True) def func(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")): blockIdx_x = T.launch_thread("blockIdx.x", 1) @@ -113,7 +106,6 @@ def test_sync_shared(): @tvm.testing.requires_cuda def test_sync_let_stmt(): - @T.prim_func(private=True) def func(A: T.Buffer((16 * 512), "float32")): blockIdx_x = T.launch_thread("blockIdx.x", 16) @@ -136,9 +128,9 @@ def test_sync_let_stmt(): in_thread_A_temp_1[0] = A_temp cross_thread_A_temp_1 = T.Buffer((1,), data=cross_thread_A_temp, scope="local") with T.attr( - T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), - "reduce_scope", - T.reinterpret("handle", T.uint64(0)), + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), ): T.tvm_thread_allreduce( T.uint32(1), @@ -190,16 +182,19 @@ def test_sync_let_stmt(): @tilelang.testing.requires_cuda def test_sync_shared_dyn_stmatrix_loop_hoist(): - @T.prim_func def func(): buf_dyn_shmem = T.alloc_buffer((98304,), "uint8", scope="shared.dyn") tx = T.launch_thread("threadIdx.x", 384) for i in T.unroll(8): off = ( - i // 4 * 8192 + tx // 32 * 1024 + tx % 16 * 64 + - (tx % 8 // 4 + i % 4 // 2) % 2 * 32 + (tx % 4 // 2 + i % 2) % 2 * 16 + - (tx % 32 // 16 + tx % 2) % 2 * 8) + i // 4 * 8192 + + tx // 32 * 1024 + + tx % 16 * 64 + + (tx % 8 // 4 + i % 4 // 2) % 2 * 32 + + (tx % 4 // 2 + i % 2) % 2 * 16 + + (tx % 32 // 16 + tx % 2) % 2 * 8 + ) T.evaluate( T.call_intrin( "handle", @@ -214,7 +209,8 @@ def test_sync_shared_dyn_stmatrix_loop_hoist(): 2, ), T.int32(2), - )) + ) + ) mod = tvm.IRModule({"main": func}) mod = tilelang.transform.ThreadSync("shared.dyn")(mod) diff --git a/testing/python/transform/test_tilelang_transform_warp_specialized.py b/testing/python/transform/test_tilelang_transform_warp_specialized.py index 063ae2940c6c8cfbda125e059afc0e3904da08d8..2e101bf82256e82c1c4e982ab5787ba8f4d2c69b 100644 --- a/testing/python/transform/test_tilelang_transform_warp_specialized.py +++ b/testing/python/transform/test_tilelang_transform_warp_specialized.py @@ -32,7 +32,6 @@ block_K = 32 def test_warp_specialized(): - @T.prim_func def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): bx = T.launch_thread("blockIdx.x", 8) @@ -47,25 +46,27 @@ def test_warp_specialized(): for k in T.serial(16, annotations={"num_stages": T.int32(3)}): if v == 0: T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, - 2, 0), 0, - T.tvm_access_ptr( - T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), - k * 32, by * 64) + T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), + 0, + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), + k * 32, + by * 64, + ) if v == 0: T.tma_load( - T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, - 2, 0), 0, - T.tvm_access_ptr( - T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), - bx * 64, k * 32) + T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), + 0, + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), + bx * 64, + k * 32, + ) T.call_extern( - "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", - T.tvm_access_ptr( - T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr( - T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3)) + "handle", + 
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + ) @T.prim_func def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): @@ -85,34 +86,35 @@ def test_warp_specialized(): T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096) if v - 128 == 0: T.tma_load( - T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, - 2, 0), T.get_mbarrier(k % 3), - T.tvm_access_ptr( - T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), - k * 32, by * 64) + T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), + T.get_mbarrier(k % 3), + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), + k * 32, + by * 64, + ) if v - 128 == 0: T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096) if v - 128 == 0: T.tma_load( - T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, - 2, 0), T.get_mbarrier(k % 3), - T.tvm_access_ptr( - T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), - bx * 64, k * 32) + T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), + T.get_mbarrier(k % 3), + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), + bx * 64, + k * 32, + ) T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)])) else: T.set_max_nreg(240, 1) for k in range(16): T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2) T.call_extern( - "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", - T.tvm_access_ptr( - T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr( - T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3)) - T.evaluate( - tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)])) + "handle", + "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + ) + T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)])) _check(before, after) diff --git a/testing/python/utils/test_compress_utils.py b/testing/python/utils/test_compress_utils.py index 1ec4cace8e242b34a7abf7b22c1a4fe1891c4fad..e8fc20539eb9bf89a90d7645d9d5d31f6e3f6428 100644 --- a/testing/python/utils/test_compress_utils.py +++ b/testing/python/utils/test_compress_utils.py @@ -6,7 +6,7 @@ from tilelang.utils.sparse import compress_sm90, randn_semi_sparse def _test_compress_sm90(M, K, block_k, dtype): - A = randn_semi_sparse(M, K, dtype=dtype, device='cuda') + A = randn_semi_sparse(M, K, dtype=dtype, device="cuda") A_sparse, E = compress_sm90(A, block_k, False) diff --git a/testing/python/webgpu/test_webgpu_codegen.py b/testing/python/webgpu/test_webgpu_codegen.py index 0fe4f196d60631837c429a8dcf0f3752a84e2a31..ed1752796055b219a9bfdaac734469bfe85c6e64 100644 --- a/testing/python/webgpu/test_webgpu_codegen.py +++ b/testing/python/webgpu/test_webgpu_codegen.py @@ -5,12 +5,11 @@ import tilelang.language as T def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func def 
main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): diff --git a/tilelang/__init__.py b/tilelang/__init__.py index 0d8c21bac99adcd3a3c0bf4f720eb34511edba0f..1f2a4f40483489bd78ba3b11d11c062119cdfcad 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -23,6 +23,7 @@ def _compute_version() -> str: if version_file.is_file(): try: from version_provider import dynamic_metadata # type: ignore + return dynamic_metadata("version") except Exception: # Fall back to the raw VERSION file if provider isn't available. @@ -33,6 +34,7 @@ def _compute_version() -> str: try: from importlib.metadata import version as _dist_version # py3.8+ + return _dist_version("tilelang") except Exception as exc: warnings.warn( diff --git a/tilelang/analysis/fragment_loop_checker.py b/tilelang/analysis/fragment_loop_checker.py index 3186b23e7949810955b7daab24174f79e16a428c..94900a5cc6c9cc167e539a47c6bf4d3e70b5db21 100644 --- a/tilelang/analysis/fragment_loop_checker.py +++ b/tilelang/analysis/fragment_loop_checker.py @@ -1,6 +1,6 @@ from __future__ import annotations from tvm import tir -from tvm.tir import (PyStmtExprVisitor, BufferStore, For, Var, PrimFunc, BufferLoad, IntImm) +from tvm.tir import PyStmtExprVisitor, BufferStore, For, Var, PrimFunc, BufferLoad, IntImm from tvm.tir.transform import prim_func_pass from tvm.tir.stmt_functor import post_order_visit @@ -22,14 +22,14 @@ class _LoopVarUseAnalyzer(PyStmtExprVisitor): def collect_local_buffer_accesses(statement) -> list[BufferLoad | BufferStore]: """ - Collect local buffer accesses in the loop body. + Collect local buffer accesses in the loop body. - Args: - statement: The TIR statement to analyze + Args: + statement: The TIR statement to analyze - Returns: - Tuple of buffer accesses in the loop body. - """ + Returns: + Tuple of buffer accesses in the loop body. + """ buffer_accesses = [] @@ -44,7 +44,6 @@ def collect_local_buffer_accesses(statement) -> list[BufferLoad | BufferStore]: @tir.functor.visitor class _FragmentLoopCheckVisitor(PyStmtExprVisitor): - def __init__(self) -> None: super().__init__() @@ -75,7 +74,8 @@ class _FragmentLoopCheckVisitor(PyStmtExprVisitor): raise ValueError( "[Tilelang Semantic Check] " f"Loop variable {loop.loop_var} in a T.Parallel loop with symbolic range (min={loop.min}, extent={loop.extent}) is used to index " - "a local/fragment buffer, which is not allowed in Tilelang.") + "a local/fragment buffer, which is not allowed in Tilelang." 
+ ) return diff --git a/tilelang/analysis/layout_visual.py b/tilelang/analysis/layout_visual.py index 782b9126d17ef53f2e5ef25b8a48758d9e812e1f..141fb808c49f978fb2fe3089e5903f814bda30fc 100644 --- a/tilelang/analysis/layout_visual.py +++ b/tilelang/analysis/layout_visual.py @@ -23,10 +23,7 @@ def print_fragment_format(layout: T.Fragment) -> str: if isinstance(layout, T.Fragment): input_shape = layout.get_input_shape() output_shape = layout.get_output_shape() - lines = [ - f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", - f" Index: {layout.forward_index}" - ] + lines = [f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", f" Index: {layout.forward_index}"] print("\n".join(lines)) else: raise ValueError(f"Expected T.Fragment, but got {type(layout).__name__}") @@ -82,7 +79,6 @@ class _LayoutVisualVisitor(PyStmtExprVisitor): def LayoutVisual(formats: str = ""): - def pass_fn(func: tir.PrimFunc, mod, ctx): _LayoutVisualVisitor(formats=formats).visit_stmt(func.body) return func diff --git a/tilelang/analysis/nested_loop_checker.py b/tilelang/analysis/nested_loop_checker.py index eff0fc2dbb6b8c4c2771e282813ba1b35faeda3e..51da7f4c8ea31c8957e79c366ab7981bd67102af 100644 --- a/tilelang/analysis/nested_loop_checker.py +++ b/tilelang/analysis/nested_loop_checker.py @@ -11,10 +11,7 @@ from tvm.tir.transform import prim_func_pass def is_pipelined_for(op: For) -> bool: """Check if a for loop is pipelined.""" - anno_keys = [ - "num_stages", "tl_pipeline_order", "tl_pipeline_stage", "tl_pipeline_sync", - "tl_pipeline_group" - ] + anno_keys = ["num_stages", "tl_pipeline_order", "tl_pipeline_stage", "tl_pipeline_sync", "tl_pipeline_group"] return any(key in op.annotations for key in anno_keys) @@ -26,7 +23,6 @@ def is_tile_op(op: Call) -> bool: @tir.functor.visitor class _NestedLoopCheckVisitor(PyStmtExprVisitor): - def __init__(self) -> None: super().__init__() self.in_parallel_context = False @@ -42,27 +38,24 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor): # Otherwise if self.in_parallel_context: - raise ValueError("[Tilelang Semantic Check] " - "Nested parallel loops are not allowed. " - "Please check your loop structure.") + raise ValueError("[Tilelang Semantic Check] Nested parallel loops are not allowed. Please check your loop structure.") self.in_parallel_context = True super().visit_for_(op) self.in_parallel_context = False return elif is_pipelined_for(op): if self.in_parallel_context: - raise ValueError("[Tilelang Semantic Check] " - "Pipelined loop cannot be nested inside a parallel loop. " - "Please check your loop structure.") + raise ValueError( + "[Tilelang Semantic Check] Pipelined loop cannot be nested inside a parallel loop. Please check your loop structure." + ) super().visit_for_(op) def visit_call_(self, op: Call) -> None: if self.in_parallel_context and is_tile_op(op): - raise ValueError("[Tilelang Semantic Check] " - "Only elementwise operations are allowed inside a parallel loop. " \ - f"Got a tile-op \"{op.op}\"." - ) + raise ValueError( + f'[Tilelang Semantic Check] Only elementwise operations are allowed inside a parallel loop. Got a tile-op "{op.op}".' 
+ ) def NestedLoopChecker(): diff --git a/tilelang/autotuner/capture.py b/tilelang/autotuner/capture.py index 27c24f14eecc5846af5889c2d451326552c007d8..428a6da9047bea8d517dbc16dd5e403bc5c29d48 100644 --- a/tilelang/autotuner/capture.py +++ b/tilelang/autotuner/capture.py @@ -85,8 +85,7 @@ def _get_current_stack() -> CaptureStack: class AutotuneInputsCapture: - - __slots__ = ("tensors") + __slots__ = "tensors" def __init__(self, tensors: list[Any]): self.tensors = tensors diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py index 4c8d9a94d21f1ea09783404542442f94111dd9e2..69ad49c79d91127c7676e904fd256f2f8f56f659 100644 --- a/tilelang/autotuner/param.py +++ b/tilelang/autotuner/param.py @@ -1,5 +1,5 @@ -"""The auto-tune parameters. -""" +"""The auto-tune parameters.""" + from __future__ import annotations import tilelang @@ -50,7 +50,7 @@ class CompileArgs: out_idx: list[int] | int | None = None execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto" - target: Literal['auto', 'cuda', 'hip'] = 'auto' + target: Literal["auto", "cuda", "hip"] = "auto" target_host: str | Target = None verbose: bool = False pass_configs: dict[str, Any] | None = None @@ -62,24 +62,20 @@ class CompileArgs: target=self.target, target_host=self.target_host, verbose=self.verbose, - pass_configs=self.pass_configs) + pass_configs=self.pass_configs, + ) def __hash__(self): data = { - "execution_backend": - self.execution_backend, - "target": - str(self.target), - "target_host": - str(self.target_host) if self.target_host else None, - "verbose": - self.verbose, - "pass_configs": - json.dumps(self.pass_configs, sort_keys=True) if self.pass_configs else None, + "execution_backend": self.execution_backend, + "target": str(self.target), + "target_host": str(self.target_host) if self.target_host else None, + "verbose": self.verbose, + "pass_configs": json.dumps(self.pass_configs, sort_keys=True) if self.pass_configs else None, } - hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode('utf-8')) - return int.from_bytes(hash_obj.digest(), byteorder='big') + hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode("utf-8")) + return int.from_bytes(hash_obj.digest(), byteorder="big") @dataclass(frozen=True) @@ -104,6 +100,7 @@ class ProfileArgs: manual_check_prog: Callable = None cache_input_tensors: bool = True """ + warmup: int = 25 rep: int = 100 timeout: int = 30 @@ -127,8 +124,8 @@ class ProfileArgs: "atol": self.atol, "max_mismatched_ratio": self.max_mismatched_ratio, } - hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode('utf-8')) - return int.from_bytes(hash_obj.digest(), byteorder='big') + hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode("utf-8")) + return int.from_bytes(hash_obj.digest(), byteorder="big") @dataclass(frozen=True) @@ -143,6 +140,7 @@ class AutotuneResult: func: Optimized function. kernel: Compiled kernel function. 
""" + latency: float | None = None config: dict | None = None ref_latency: float | None = None @@ -199,8 +197,7 @@ class AutotuneResult: if verbose: logger.debug(f"Saving kernel source code to file: {device_kernel_path}") if kernel.kernel_source is not None: - self._safe_write_file(device_kernel_path, "w", - lambda f: f.write(kernel.kernel_source)) + self._safe_write_file(device_kernel_path, "w", lambda f: f.write(kernel.kernel_source)) except Exception as e: logger.error(f"Error saving kernel source code to disk: {e}") @@ -211,11 +208,9 @@ class AutotuneResult: logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}") # Match kernel_cache behavior: use host source for tvm_ffi, otherwise wrapped kernel if kernel.execution_backend == "tvm_ffi": - self._safe_write_file(host_kernel_path, "w", - lambda f: f.write(kernel.adapter.get_host_source())) + self._safe_write_file(host_kernel_path, "w", lambda f: f.write(kernel.adapter.get_host_source())) else: - self._safe_write_file(host_kernel_path, "w", - lambda f: f.write(kernel.adapter.get_kernel_source())) + self._safe_write_file(host_kernel_path, "w", lambda f: f.write(kernel.adapter.get_kernel_source())) except Exception as e: logger.error(f"Error saving wrapped kernel source code to disk: {e}") @@ -237,12 +232,10 @@ class AutotuneResult: py_src_path = src_lib_path.replace(".cubin", ".py") if verbose: logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}") - self._safe_write_file(kernel_py_path, "wb", - lambda f: f.write(self._load_binary(py_src_path))) + self._safe_write_file(kernel_py_path, "wb", lambda f: f.write(self._load_binary(py_src_path))) if verbose: logger.debug(f"Saving kernel library to file: {kernel_lib_path}") - self._safe_write_file(kernel_lib_path, "wb", - lambda f: f.write(self._load_binary(src_lib_path))) + self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path))) elif kernel.execution_backend == "tvm_ffi": executable = kernel.adapter.executable if verbose: @@ -252,8 +245,7 @@ class AutotuneResult: src_lib_path = kernel.adapter.libpath if verbose: logger.debug(f"Saving kernel library to file: {kernel_lib_path}") - self._safe_write_file(kernel_lib_path, "wb", - lambda f: f.write(self._load_binary(src_lib_path))) + self._safe_write_file(kernel_lib_path, "wb", lambda f: f.write(self._load_binary(src_lib_path))) except Exception as e: logger.error(f"Error saving kernel library to disk: {e}") @@ -370,14 +362,12 @@ class AutotuneResult: # save best config (atomic) if verbose: logger.debug(f"Saving best config to file: {path / BEST_CONFIG_PATH}") - self._safe_write_file( - str(path / BEST_CONFIG_PATH), "w", lambda f: json.dump(self.config, f)) + self._safe_write_file(str(path / BEST_CONFIG_PATH), "w", lambda f: json.dump(self.config, f)) # save function (atomic) if verbose: logger.debug(f"Saving function to file: {path / FUNCTION_PATH}") - self._safe_write_file( - str(path / FUNCTION_PATH), "wb", lambda f: cloudpickle.dump(self.func, f)) + self._safe_write_file(str(path / FUNCTION_PATH), "wb", lambda f: cloudpickle.dump(self.func, f)) # save ref latency (atomic) if verbose: @@ -385,10 +375,13 @@ class AutotuneResult: self._safe_write_file( str(path / LATENCY_PATH), "w", - lambda f: json.dump({ - "latency": self.latency, - "ref_latency": self.ref_latency, - }, f), + lambda f: json.dump( + { + "latency": self.latency, + "ref_latency": self.ref_latency, + }, + f, + ), ) # save kernel @@ -403,8 +396,8 @@ class AutotuneResult: # Normalize target and resolve 
execution backend for loading from tilelang.utils.target import determine_target as _determine_target from tilelang.jit.execution_backend import resolve_execution_backend - norm_target = Target(_determine_target(compile_args.target)) if isinstance( - compile_args.target, str) else compile_args.target + + norm_target = Target(_determine_target(compile_args.target)) if isinstance(compile_args.target, str) else compile_args.target requested_backend = compile_args.execution_backend resolved_backend = resolve_execution_backend(requested_backend, norm_target) # load best config diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py index 9b2fca2c39b2a7027dfbef46331170b591713ceb..5bbdc48a4e63dd95e3caa8756508552c71cf0389 100644 --- a/tilelang/autotuner/tuner.py +++ b/tilelang/autotuner/tuner.py @@ -3,6 +3,7 @@ This module provides functionality for auto-tuning tilelang programs, including JIT compilation and performance optimization through configuration search. """ + from __future__ import annotations from dataclasses import dataclass @@ -14,7 +15,8 @@ from tvm.tir import PrimFunc, Var from tvm.target import Target import inspect from functools import partial -from typing import (Callable, Generic, Literal, Any, TypeVar) +from typing import Callable, Generic, Literal, Any, TypeVar + # Python 3.9 compatibility for ParamSpec try: from typing import ParamSpec @@ -74,8 +76,8 @@ def _init_logger_handlers(): global _logger_handlers_initialized if _logger_handlers_initialized: return - formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s') - file_handler = logging.FileHandler('autotuner.log', mode='w') + formatter = logging.Formatter("%(asctime)s %(levelname)s:%(message)s") + file_handler = logging.FileHandler("autotuner.log", mode="w") file_handler.setLevel(logging.DEBUG) file_handler.setFormatter(formatter) console_handler = logging.StreamHandler(sys.stdout) @@ -87,8 +89,7 @@ def _init_logger_handlers(): def get_available_cpu_count() -> int: - """Gets the number of CPU cores available to the current process. - """ + """Gets the number of CPU cores available to the current process.""" try: cpu_count = len(os.sched_getaffinity(0)) except AttributeError: @@ -107,6 +108,7 @@ class AutoTuner: fn: The function to be auto-tuned. configs: List of configurations to try during auto-tuning. """ + compile_args = CompileArgs() profile_args = ProfileArgs() @@ -137,14 +139,15 @@ class AutoTuner: """ return cls(kernel, configs) - def set_compile_args(self, - out_idx: list[int] | int | None = None, - target: Literal['auto', 'cuda', 'hip', 'metal'] = 'auto', - execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", - "torch"] = "auto", - target_host: str | Target = None, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None): + def set_compile_args( + self, + out_idx: list[int] | int | None = None, + target: Literal["auto", "cuda", "hip", "metal"] = "auto", + execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto", + target_host: str | Target = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + ): """Set compilation arguments for the auto-tuner. 
Args: @@ -161,6 +164,7 @@ class AutoTuner: # Normalize target to a concrete TVM Target and resolve execution backend t = Target(determine_target(target)) from tilelang.jit.execution_backend import resolve_execution_backend + resolved_backend = resolve_execution_backend(execution_backend, t) self.compile_args = CompileArgs( @@ -169,23 +173,26 @@ class AutoTuner: execution_backend=resolved_backend, target_host=target_host, verbose=verbose, - pass_configs=pass_configs) + pass_configs=pass_configs, + ) return self - def set_profile_args(self, - warmup: int = 25, - rep: int = 100, - timeout: int = 30, - supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto, - ref_prog: Callable = None, - supply_prog: Callable = None, - rtol: float = 1e-2, - atol: float = 1e-2, - max_mismatched_ratio: float = 0.01, - skip_check: bool = False, - manual_check_prog: Callable = None, - cache_input_tensors: bool = False): + def set_profile_args( + self, + warmup: int = 25, + rep: int = 100, + timeout: int = 30, + supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto, + ref_prog: Callable = None, + supply_prog: Callable = None, + rtol: float = 1e-2, + atol: float = 1e-2, + max_mismatched_ratio: float = 0.01, + skip_check: bool = False, + manual_check_prog: Callable = None, + cache_input_tensors: bool = False, + ): """Set profiling arguments for the auto-tuner. Args: @@ -209,9 +216,7 @@ class AutoTuner: # the `supply_prog` will be ignored and the `get_autotune_inputs` will be used instead. if get_autotune_inputs() is not None: if supply_prog is not None: - logger.warning( - "`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context." - ) + logger.warning("`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context.") supply_prog = lambda _: get_autotune_inputs() # noqa: E731 self.profile_args = ProfileArgs( @@ -226,13 +231,13 @@ class AutoTuner: cache_input_tensors=cache_input_tensors, warmup=warmup, rep=rep, - timeout=timeout) + timeout=timeout, + ) # If a custom `supply_prog` is provided, the profiler's `supply_type` setting # becomes ineffective. The custom supply program will be used instead. if supply_prog is not None and supply_type != tilelang.TensorSupplyType.Auto: - logger.warning("Ignoring `supply_type` passed to `set_profile_args` because " - "`supply_prog` is not None.") + logger.warning("Ignoring `supply_type` passed to `set_profile_args` because `supply_prog` is not None.") return self @@ -241,10 +246,8 @@ class AutoTuner: self._kernel_parameters = k_parameters self._function_parameters = f_parameters - def generate_cache_key(self, parameters: dict[str, Any], - extra_parameters: dict[str, Any]) -> AutotuneResult | None: - """Generate a cache key for the auto-tuning process. 
- """ + def generate_cache_key(self, parameters: dict[str, Any], extra_parameters: dict[str, Any]) -> AutotuneResult | None: + """Generate a cache key for the auto-tuning process.""" def _normalize_param(value): if isinstance(value, Var): @@ -315,8 +318,9 @@ class AutoTuner: if var_name in parameters: continue # Cell content must be serializable - assert isinstance(cell.cell_contents, (int, float, str, bool, type(None))), \ + assert isinstance(cell.cell_contents, (int, float, str, bool, type(None))), ( f"Cell contents {cell.cell_contents} is not serializable: {type(cell.cell_contents)}" + ) extra_parameters[var_name] = cell.cell_contents if isinstance(self.configs, Callable): @@ -328,8 +332,10 @@ class AutoTuner: if env.is_cache_enabled() and not env.is_autotune_cache_disabled(): # First check in-memory cache if key in self._memory_cache: - logger.warning("Found kernel in memory cache. For better performance," \ - " consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel.") + logger.warning( + "Found kernel in memory cache. For better performance," + " consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel." + ) return self._memory_cache[key] # Then check disk cache @@ -369,7 +375,6 @@ class AutoTuner: # This encapsulates the logic of using either a custom supply program (`supply_prog`) # or the default profiler input generation (`profiler._get_inputs`). def get_input_tensors_supply(with_output: bool): - def func(): if supply_prog is not None: return supply_prog(profiler._get_params(with_output=with_output)) @@ -387,8 +392,7 @@ class AutoTuner: self.jit_input_tensors = jit_input_tensors_supply() else: # check if the cached tensors are compatible with the current configuration - assert len(params) == len( - self.jit_input_tensors), "len(params) != len(self.jit_input_tensors)" + assert len(params) == len(self.jit_input_tensors), "len(params) != len(self.jit_input_tensors)" for p, c in zip(params, self.jit_input_tensors): if not isinstance(c, torch.Tensor): # skip non-tensor inputs checking @@ -397,8 +401,8 @@ class AutoTuner: # Check tensor compatibility using generator expression def shape_equal(a, b): return all( - a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var) - for a_dim, b_dim in zip(a.shape, b.shape)) + a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var) for a_dim, b_dim in zip(a.shape, b.shape) + ) if p.dtype != c.dtype or not shape_equal(p, c): logger.warning( @@ -409,7 +413,8 @@ class AutoTuner: "To ensure fresh, compatible inputs are generated for every trial " "you can disable caching by setting:\n" " `cache_input_tensors=False`\n" - "within your `.set_compile_args(...)` call.\n") + "within your `.set_compile_args(...)` call.\n" + ) # otherwise, regenerate the input tensors for safety self.jit_input_tensors = jit_input_tensors_supply() break @@ -418,24 +423,16 @@ class AutoTuner: if (not skip_check) and (ref_prog is not None): if manual_check_prog is not None: - profiler.manual_assert_close( - ref_prog, - input_tensors=self.jit_input_tensors, - manual_check_prog=manual_check_prog) + profiler.manual_assert_close(ref_prog, input_tensors=self.jit_input_tensors, manual_check_prog=manual_check_prog) else: profiler.assert_allclose( - ref_prog, - input_tensors=self.jit_input_tensors, - rtol=rtol, - atol=atol, - max_mismatched_ratio=max_mismatched_ratio) - latency = profiler.do_bench( - warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors) + ref_prog, input_tensors=self.jit_input_tensors, rtol=rtol, 
atol=atol, max_mismatched_ratio=max_mismatched_ratio + ) + latency = profiler.do_bench(warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors) if self.ref_latency_cache is None and ref_prog is not None: self.ref_input_tensors = ref_input_tensors_supply() - self.ref_latency_cache = profiler.do_bench( - ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors) + self.ref_latency_cache = profiler.do_bench(ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors) return latency, self.ref_latency_cache @@ -469,17 +466,14 @@ class AutoTuner: # Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple if any(key in top_config for key, _ in key_kwargs_tuple) or any( - check_tunable_argument_value(key, self._function_parameters, key_args_tuple) - for key in tunable_arguments): + check_tunable_argument_value(key, self._function_parameters, key_args_tuple) for key in tunable_arguments + ): logger.warning( f"Tunable parameters {tunable_arguments} already provided during auto-tuning. Skipping compilation and using direct JIT" ) # compile the kernel with the provided parameters jit_kernel = self.jit_compile() - autotuner_result = AutotuneResult( - libcode=jit_kernel.get_kernel_source(), - func=jit_kernel.prim_func, - kernel=jit_kernel) + autotuner_result = AutotuneResult(libcode=jit_kernel.get_kernel_source(), func=jit_kernel.prim_func, kernel=jit_kernel) self._memory_cache[key] = autotuner_result return autotuner_result # get the cpu count @@ -489,9 +483,7 @@ class AutoTuner: max_cpu_count = int(env.TILELANG_AUTO_TUNING_MAX_CPU_COUNT) if cpu_counts > 0: num_workers = min(cpu_counts, available_cpu_count) - logger.info( - f"Auto-tuning with {cpu_counts} CPU counts, {available_cpu_count} CPUs available, {num_workers} CPUs will be used" - ) + logger.info(f"Auto-tuning with {cpu_counts} CPU counts, {available_cpu_count} CPUs available, {num_workers} CPUs will be used") else: num_workers = max(1, int(available_cpu_count * cpu_utilizations)) logger.info( @@ -509,7 +501,6 @@ class AutoTuner: future_to_index = {} def cuda_device_wrapper(func, device): - def inner(**config_arg): torch.cuda.set_device(device) return func(**config_arg) @@ -532,18 +523,14 @@ class AutoTuner: future_to_index[future] = i results_with_configs = [] - for future in tqdm( - concurrent.futures.as_completed(futures), - total=len(futures), - desc="Compiling configurations"): + for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Compiling configurations"): idx = future_to_index[future] config = config_args[idx] try: result = future.result() results_with_configs.append((result, config)) except Exception as e: - logger.debug( - f"Compilation failed for config {config} at index {idx} with error: {e}") + logger.debug(f"Compilation failed for config {config} at index {idx} with error: {e}") continue ref_latency = None @@ -556,14 +543,10 @@ class AutoTuner: # latency, ref_latency = target_fn(jit_kernel) latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel) except TimeoutException: - logger.warning( - f"A timeout occurred while testing config {config}, checkout autotuner.log for more details" - ) + logger.warning(f"A timeout occurred while testing config {config}, checkout autotuner.log for more details") continue except Exception: - logger.warning( - f"An error occurred while testing config {config}, checkout autotuner.log for more details" - ) + logger.warning(f"An error occurred while testing config {config}, 
checkout autotuner.log for more details") logger.debug(f"Error: {traceback.format_exc()}") continue @@ -578,8 +561,7 @@ class AutoTuner: pool.shutdown() if best_kernel is None: - error_msg = ("Auto-tuning failed: No configuration successfully " - "compiled and passed benchmarking/validation.") + error_msg = "Auto-tuning failed: No configuration successfully compiled and passed benchmarking/validation." logger.error(error_msg) raise RuntimeError(error_msg) @@ -595,7 +577,8 @@ class AutoTuner: ref_latency=ref_latency, libcode=best_kernel.get_kernel_source(), func=best_kernel.prim_func, - kernel=best_kernel) + kernel=best_kernel, + ) if self.compile_args.execution_backend in ("torch"): logger.warning("DLPack backend does not support cache saving to disk.") @@ -617,8 +600,8 @@ class AutoTuner: return self.run() -_P = ParamSpec('_P') -_T = TypeVar('_T') +_P = ParamSpec("_P") +_T = TypeVar("_T") @dataclass @@ -643,8 +626,9 @@ class AutoTuneImpl(Generic[_P, _T]): self._tuner_cache = {} def get_tunner(self): - autotuner = AutoTuner( - self.jit_impl.func, configs=self.configs).set_profile_args( + autotuner = ( + AutoTuner(self.jit_impl.func, configs=self.configs) + .set_profile_args( supply_type=self.supply_type, ref_prog=self.ref_prog, supply_prog=self.supply_prog, @@ -654,7 +638,8 @@ class AutoTuneImpl(Generic[_P, _T]): skip_check=self.skip_check, manual_check_prog=self.manual_check_prog, cache_input_tensors=self.cache_input_tensors, - ).set_compile_args( + ) + .set_compile_args( out_idx=self.jit_impl.out_idx, execution_backend=self.jit_impl.execution_backend, target=self.jit_impl.target, @@ -662,6 +647,7 @@ class AutoTuneImpl(Generic[_P, _T]): verbose=self.jit_impl.verbose, pass_configs=self.jit_impl.pass_configs, ) + ) autotuner.run = partial(autotuner.run, self.warmup, self.rep, self.timeout) return autotuner @@ -753,16 +739,13 @@ def autotune( # This is the new public interface if callable(func): # Case 1: Used as @autotune (func_or_out_idx is the function, others are defaults) # This is a placeholder for a real auto tuner implementation - raise ValueError( - "Use tilelang.autotune to decorate func without arguments is not supported yet.") + raise ValueError("Use tilelang.autotune to decorate func without arguments is not supported yet.") elif isinstance(func, PrimFunc): raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.") else: def decorator(impl): - assert isinstance( - impl, JITImpl - ), "The @autotune decorator can only be applied to @tilelang.jit decorated instances." + assert isinstance(impl, JITImpl), "The @autotune decorator can only be applied to @tilelang.jit decorated instances." 
return AutoTuneImpl( jit_impl=impl, configs=configs, diff --git a/tilelang/cache/__init__.py b/tilelang/cache/__init__.py index 144c27299b020d037dde1a291c3fb29bd16b46da..18ac847bf4adbad368b8f853cb152287415cdb9f 100644 --- a/tilelang/cache/__init__.py +++ b/tilelang/cache/__init__.py @@ -1,4 +1,5 @@ """The cache utils with class and database persistence - Init file""" + from __future__ import annotations from typing import Literal @@ -18,8 +19,7 @@ def cached( *args, target: str | Target = "auto", target_host: str | Target = None, - execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] - | None = "auto", + execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] | None = "auto", verbose: bool | None = False, pass_configs: dict | None = None, compile_flags: list[str] | str | None = None, @@ -36,7 +36,8 @@ def cached( execution_backend=execution_backend, verbose=verbose, pass_configs=pass_configs, - compile_flags=compile_flags) + compile_flags=compile_flags, + ) def clear_cache(): @@ -47,9 +48,11 @@ def clear_cache(): RuntimeError: Always raised to warn users to clear the cache manually. """ cache_dir = env.TILELANG_CACHE_DIR - raise RuntimeError("tilelang.clear_cache() is disabled because deleting the cache directory " - "is dangerous. If you accept the risk, remove it manually with " - f"`rm -rf '{cache_dir}'`.") + raise RuntimeError( + "tilelang.clear_cache() is disabled because deleting the cache directory " + "is dangerous. If you accept the risk, remove it manually with " + f"`rm -rf '{cache_dir}'`." + ) if env.TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"): diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py index 74ecb2788ce278f72cb6659a591564c1140a565a..4fbe2dce5519e493978180af62ea23f17ae9b4dc 100644 --- a/tilelang/cache/kernel_cache.py +++ b/tilelang/cache/kernel_cache.py @@ -1,4 +1,5 @@ """The cache utils with class and database persistence - KernelCache Class""" + from __future__ import annotations import json @@ -97,9 +98,7 @@ class KernelCache: "version": __version__, "func": sha256(func_binary).hexdigest(), # Use SHA256 to generate hash key "out_idx": (tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx]), - "args_repr": tuple( - repr(arg) for arg in args - ), # Use repr to serialize arguments, may need more robust serialization + "args_repr": tuple(repr(arg) for arg in args), # Use repr to serialize arguments, may need more robust serialization "target": str(target), "target_host": str(target_host) if target_host else None, "execution_backend": execution_backend, @@ -118,8 +117,7 @@ class KernelCache: *args, target: str | Target = "auto", target_host: str | Target = None, - execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", - "torch"] = "auto", + execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto", verbose: bool = False, pass_configs: dict = None, compile_flags: list[str] | str | None = None, @@ -140,6 +138,7 @@ class KernelCache: # Normalize target and resolve execution backend before proceeding from tilelang.utils.target import determine_target as _determine_target from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target + norm_target = Target(_determine_target(target)) if isinstance(target, str) else target requested_backend = execution_backend execution_backend = resolve_execution_backend(requested_backend, norm_target) @@ -180,21 +179,21 @@ class KernelCache: 
with self._lock: # First check in-memory cache if key in self._memory_cache: - self.logger.warning("Found kernel in memory cache. For better performance," \ - " consider using `@tilelang.jit` instead of direct kernel caching.") + self.logger.warning( + "Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching." + ) return self._memory_cache[key] if verbose: self.logger.debug(f"Checking disk cache for kernel {func.attrs['global_symbol']}") # Then check disk cache - kernel = self._load_kernel_from_disk(key, norm_target, target_host, out_idx, - execution_backend, pass_configs, compile_flags, - func, verbose) + kernel = self._load_kernel_from_disk( + key, norm_target, target_host, out_idx, execution_backend, pass_configs, compile_flags, func, verbose + ) if kernel is not None: if verbose: - self.logger.debug( - f"Found kernel in disk cache for {func.attrs['global_symbol']}") + self.logger.debug(f"Found kernel in disk cache for {func.attrs['global_symbol']}") # Populate memory cache with disk result self._memory_cache[key] = kernel return kernel @@ -262,11 +261,7 @@ class KernelCache: executable.export_library(temp_path) os.replace(temp_path, path) - def _save_kernel_to_disk(self, - key: str, - kernel: JITKernel, - func: Callable = None, - verbose: bool = False): + def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = None, verbose: bool = False): """ Persists a compiled kernel to disk cache. @@ -292,8 +287,7 @@ class KernelCache: if verbose: self.logger.debug(f"Saving kernel source code to file: {device_kernel_path}") if kernel.kernel_source is not None: - KernelCache._safe_write_file(device_kernel_path, "w", - lambda file: file.write(kernel.kernel_source)) + KernelCache._safe_write_file(device_kernel_path, "w", lambda file: file.write(kernel.kernel_source)) except Exception as e: self.logger.error(f"Error saving kernel source code to disk: {e}") @@ -303,13 +297,9 @@ class KernelCache: if verbose: self.logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}") if self.execution_backend == "tvm_ffi": - KernelCache._safe_write_file( - host_kernel_path, "w", - lambda file: file.write(kernel.adapter.get_host_source())) + KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_host_source())) else: - KernelCache._safe_write_file( - host_kernel_path, "w", - lambda file: file.write(kernel.adapter.get_kernel_source())) + KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_kernel_source())) except Exception as e: self.logger.error(f"Error saving host kernel source code to disk: {e}") @@ -332,9 +322,7 @@ class KernelCache: src_lib_path = src_lib_path.replace(".cubin", ".py") if verbose: self.logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}") - KernelCache._safe_write_file( - kernel_py_path, "wb", - lambda file: file.write(KernelCache._load_binary(src_lib_path))) + KernelCache._safe_write_file(kernel_py_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path))) elif self.execution_backend == "tvm_ffi": executable = kernel.adapter.executable if verbose: @@ -344,9 +332,7 @@ class KernelCache: src_lib_path = kernel.adapter.libpath if verbose: self.logger.debug(f"Saving kernel library to file: {kernel_lib_path}") - KernelCache._safe_write_file( - kernel_lib_path, "wb", - lambda file: file.write(KernelCache._load_binary(src_lib_path))) + KernelCache._safe_write_file(kernel_lib_path, 
"wb", lambda file: file.write(KernelCache._load_binary(src_lib_path))) except Exception as e: self.logger.error(f"Error saving kernel library to disk: {e}") @@ -356,8 +342,7 @@ class KernelCache: params_path = os.path.join(cache_path, PARAMS_PATH) if verbose: self.logger.debug(f"Saving kernel parameters to disk: {params_path}") - KernelCache._safe_write_file(params_path, "wb", - lambda file: cloudpickle.dump(kernel.params, file)) + KernelCache._safe_write_file(params_path, "wb", lambda file: cloudpickle.dump(kernel.params, file)) except Exception as e: self.logger.error(f"Error saving kernel parameters to disk: {e}") @@ -417,8 +402,7 @@ class KernelCache: self.logger.error(f"Error loading kernel source code from disk: {e}") try: if verbose: - self.logger.debug( - f"Loading wrapped kernel source code from file: {host_kernel_path}") + self.logger.debug(f"Loading wrapped kernel source code from file: {host_kernel_path}") with open(host_kernel_path) as f: host_kernel_source = f.read() except Exception as e: diff --git a/tilelang/carver/__init__.py b/tilelang/carver/__init__.py index 4ffd43644da36eb9ece7502fa91a9d0a1a634cbb..f1dfc5b4750d23e679535b11c34b1bfe7ceb6576 100644 --- a/tilelang/carver/__init__.py +++ b/tilelang/carver/__init__.py @@ -1,4 +1,5 @@ """Base infra""" + from .analysis import ( BlockInfo, # noqa: F401 IterInfo, # noqa: F401 diff --git a/tilelang/carver/analysis.py b/tilelang/carver/analysis.py index 96606e790ef37291b84213ab36784a1b4bdb70d3..6ca9168185f2d3149e273cbe2e3517d385eda204 100644 --- a/tilelang/carver/analysis.py +++ b/tilelang/carver/analysis.py @@ -1,4 +1,5 @@ """Analysis on TIR blocks, loops and functions.""" + from __future__ import annotations from typing_extensions import Literal @@ -144,11 +145,13 @@ def normalize_prim_func(sch: tir.Schedule) -> list[BlockInfo] | None: var=iter.var, dom=iter.dom, loop_rv=loop, - ) for loop, iter in zip(loops, iters) + ) + for loop, iter in zip(loops, iters) ], block_rv=block, reduction_block=is_reduction, - )) + ) + ) return blocks @@ -188,8 +191,7 @@ def get_max_shared_memory_per_block(target: Target) -> int: _assert_gpu_target(target) max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None) if max_shared_memory_per_block is None: - raise ValueError( - f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually") + raise ValueError(f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually") return int(max_shared_memory_per_block) @@ -197,13 +199,11 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV: try: block = sch.mod[func_name].body.block except Exception: - raise ValueError(f"The function body is expected to be the root block, but got:\n" - f"{sch.mod[func_name].body}") from None + raise ValueError(f"The function body is expected to be the root block, but got:\n{sch.mod[func_name].body}") from None return sch.get_block(block.name_hint) -def collect_block_iter_vars_used_in_access_region(block: tir.Block, - region: list[ir.Range]) -> set[tir.Var]: +def collect_block_iter_vars_used_in_access_region(block: tir.Block, region: list[ir.Range]) -> set[tir.Var]: """Collect the block iter variables used in the access region of a buffer region.""" tir_vars = set() for expr in region: @@ -251,15 +251,13 @@ def is_broadcast_epilogue( for buffer_region in sch.get(epilogue).reads: if buffer_region.buffer not in write_buffers: continue - tir_vars = collect_block_iter_vars_used_in_access_region( - sch.get(epilogue), buffer_region.region) + 
tir_vars = collect_block_iter_vars_used_in_access_region(sch.get(epilogue), buffer_region.region) if len(tir_vars) < len(epilogue_iters): return True return False -def get_reduction_blocks(sch: tir.Schedule, - blocks: list[tir.schedule.BlockRV]) -> list[tir.schedule.BlockRV]: +def get_reduction_blocks(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV]) -> list[tir.schedule.BlockRV]: # Get the main computation block def is_reduction(block: BlockRV) -> bool: block_stmt = sch.get(block) diff --git a/tilelang/carver/arch/__init__.py b/tilelang/carver/arch/__init__.py index c2bc9c75d54c670e96a2d8bd558a8af955eba5d5..b6cb9e72f7243c46b9d13636e91b37ab887a0625 100644 --- a/tilelang/carver/arch/__init__.py +++ b/tilelang/carver/arch/__init__.py @@ -39,18 +39,18 @@ def auto_infer_current_arch() -> TileDevice: __all__ = [ - 'is_cpu_arch', - 'is_cuda_arch', - 'is_volta_arch', - 'is_ampere_arch', - 'is_ada_arch', - 'is_hopper_arch', - 'is_tensorcore_supported_precision', - 'has_mma_support', - 'is_cdna_arch', - 'is_metal_arch', - 'CUDA', - 'CDNA', - 'METAL', - 'CPU', + "is_cpu_arch", + "is_cuda_arch", + "is_volta_arch", + "is_ampere_arch", + "is_ada_arch", + "is_hopper_arch", + "is_tensorcore_supported_precision", + "has_mma_support", + "is_cdna_arch", + "is_metal_arch", + "CUDA", + "CDNA", + "METAL", + "CPU", ] diff --git a/tilelang/carver/arch/arch_base.py b/tilelang/carver/arch/arch_base.py index 4c8825e8e7af5c76d919abe44feacd3d86d87273..c5e9dfa683c0bb357008b2748ae280b2c52ee04a 100644 --- a/tilelang/carver/arch/arch_base.py +++ b/tilelang/carver/arch/arch_base.py @@ -7,9 +7,7 @@ class TileDevice: self.reg_cap: int = 0 # Register capacity: The amount of register memory available self.smem_cap: int = 0 # Shared memory capacity: The amount of shared memory available self.compute_max_core: int = 0 # The maximum number of computing cores - self.warp_size: int = ( - 0 # The size of a warp, a group of threads that execute instructions in lockstep - ) + self.warp_size: int = 0 # The size of a warp, a group of threads that execute instructions in lockstep self.sm_partition: int = 0 # The number of streaming multiprocessor partitions self.transaction_size: list[int] = [ 0, @@ -21,9 +19,7 @@ class TileDevice: 0, ] # Bandwidth specifications, possibly including peak and sustained rates self.platform: str = "unknown" # The platform or manufacturer of the device - self.compute_capability: str = ( - "unknown" # The compute capability, indicating the feature set and performance level - ) + self.compute_capability: str = "unknown" # The compute capability, indicating the feature set and performance level self.l2_cache_size_bytes: int = 0 # the number of transaction size in bytes self.transaction_size: list[int] = [0, 0] # in bytes diff --git a/tilelang/carver/arch/cdna.py b/tilelang/carver/arch/cdna.py index ec5aa905fe27c6fad7cfbb618819436c6823edf4..5c2d4c4ed6722e2577a2755d2af73ff76eabbf41 100644 --- a/tilelang/carver/arch/cdna.py +++ b/tilelang/carver/arch/cdna.py @@ -9,7 +9,6 @@ def is_cdna_arch(arch: TileDevice) -> bool: class CDNA(TileDevice): - def __init__(self, target: Target | str): if isinstance(target, str): target = tvm.target.Target(target) @@ -33,6 +32,6 @@ class CDNA(TileDevice): __all__ = [ - 'is_cdna_arch', - 'CDNA', + "is_cdna_arch", + "CDNA", ] diff --git a/tilelang/carver/arch/cpu.py b/tilelang/carver/arch/cpu.py index f4643baa0791f7dc621d37863292459ec15cef43..fc18c6c8b35e703c1efe98b4ed1378868823d741 100644 --- a/tilelang/carver/arch/cpu.py +++ b/tilelang/carver/arch/cpu.py @@ -10,7 +10,6 @@ def 
is_cpu_arch(arch: TileDevice) -> bool: # For LLVM Backend, we do not provide the detailed information of the CPU # As the LLVM backend do not required tuning, just maintain the consistency class CPU(TileDevice): - def __init__(self, target: Target): self.target = target device = tvm.runtime.cpu(0) @@ -21,6 +20,6 @@ class CPU(TileDevice): __all__ = [ - 'is_cpu_arch', - 'CPU', + "is_cpu_arch", + "CPU", ] diff --git a/tilelang/carver/arch/cuda.py b/tilelang/carver/arch/cuda.py index 4c7f98dffdab95272325ccefdb9c6fb5dd1830da..2b79b2832b223469f5f3aa8b6fc94d1448acd2c3 100644 --- a/tilelang/carver/arch/cuda.py +++ b/tilelang/carver/arch/cuda.py @@ -78,7 +78,6 @@ hopper_tensorcore_supported = ada_tensorcore_supported # instead of assuming both a and b share the same dtype. # As the tensorcore may supports float8_e4m3 * float8_e5m2 def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool: - if is_volta_arch(arch): return (in_dtype, accum_dtype) in volta_tensorcore_supported elif is_ampere_arch(arch): @@ -92,7 +91,6 @@ def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: Til class TensorInstruction: - def __init__( self, name: str, @@ -104,7 +102,6 @@ class TensorInstruction: class CUDA(TileDevice): - def __init__(self, target: Target | str): if isinstance(target, str): target = tvm.target.Target(target) @@ -148,12 +145,12 @@ class CUDA(TileDevice): __all__ = [ - 'is_cuda_arch', - 'is_volta_arch', - 'is_ampere_arch', - 'is_ada_arch', - 'is_hopper_arch', - 'is_tensorcore_supported_precision', - 'has_mma_support', + "is_cuda_arch", + "is_volta_arch", + "is_ampere_arch", + "is_ada_arch", + "is_hopper_arch", + "is_tensorcore_supported_precision", + "has_mma_support", "CUDA", ] diff --git a/tilelang/carver/arch/driver/cuda_driver.py b/tilelang/carver/arch/driver/cuda_driver.py index c8cc1a38eea874d28db322b889b86a54af188c98..a631276635f6d53df368271a65ab6c84926c1f62 100644 --- a/tilelang/carver/arch/driver/cuda_driver.py +++ b/tilelang/carver/arch/driver/cuda_driver.py @@ -83,8 +83,7 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes") Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes. """ assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb" - shared_mem = get_device_attribute( - cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id) + shared_mem = get_device_attribute(cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id) if format == "bytes": return shared_mem elif format == "kb": diff --git a/tilelang/carver/arch/metal.py b/tilelang/carver/arch/metal.py index 9cd1c4d1e3f42329169bdfc7dcb5dfcb376c36d3..0b76849a7695175a113b14cde3db19605de37006 100644 --- a/tilelang/carver/arch/metal.py +++ b/tilelang/carver/arch/metal.py @@ -8,7 +8,6 @@ def is_metal_arch(arch: TileDevice) -> bool: class METAL(TileDevice): - def __init__(self, target: Target | str): if isinstance(target, str): target = Target(target) @@ -16,6 +15,6 @@ class METAL(TileDevice): __all__ = [ - 'is_metal_arch', - 'METAL', + "is_metal_arch", + "METAL", ] diff --git a/tilelang/carver/common_schedules.py b/tilelang/carver/common_schedules.py index 199f0158cf9c71bdb89f38d5b62fffacdc7b222f..4904b770dd6a5c83e219fea80061ed4f9fb5b46b 100644 --- a/tilelang/carver/common_schedules.py +++ b/tilelang/carver/common_schedules.py @@ -19,6 +19,7 @@ # Modifications Copyright (c) Microsoft. 
# The code below is mostly copied from apache/tvm common_schedules.py in dlight. """Common schedule strategies for TIR.""" + from typing import Callable from tvm import tir diff --git a/tilelang/carver/matmul_analysis.py b/tilelang/carver/matmul_analysis.py index 02a86cc78c375fcf39a3af12a2ad73451867c090..6d27de8253c12cdebc3ba7f50ec66cb702148b63 100644 --- a/tilelang/carver/matmul_analysis.py +++ b/tilelang/carver/matmul_analysis.py @@ -1,5 +1,6 @@ # pylint: disable=missing-docstring, invalid-name """A GEMM schedule rule for GPU operators.""" + from __future__ import annotations from dataclasses import dataclass from enum import Enum @@ -157,8 +158,7 @@ def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Block return block -def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV, - buffer: tir.Buffer) -> int: +def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV, buffer: tir.Buffer) -> int: """traverse to find the arg index from the buffer""" producers = sch.get_producers(main_block) @@ -226,9 +226,7 @@ def make_iter_fusion_index_map( else: fused_iters[trait.kind] = v_i - final_indices: list[tir.PrimExpr] = [ - fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order - ] + final_indices: list[tir.PrimExpr] = [fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order] return tir.IndexMap(input_iters, final_indices, None) @@ -307,8 +305,7 @@ def detect_iter_traits(block: tir.Block) -> tuple[list[IterTrait]] | None: return A_traits, B_traits, C_traits, block_traits -def get_index_map(block: tir.Block, - layout: list[str] | None = None) -> tuple[tir.IndexMap, ...] | None: +def get_index_map(block: tir.Block, layout: list[str] | None = None) -> tuple[tir.IndexMap, ...] 
| None: """Get index maps for the block Parameters @@ -343,10 +340,7 @@ def get_index_map(block: tir.Block, return axes def is_common_reduce(var: Var) -> bool: - for iter_var in block.iter_vars: - if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce: - return True - return False + return any(iter_var.var == var and iter_var.iter_type == IterVar.CommReduce for iter_var in block.iter_vars) def has_common_reduce(var: Var) -> bool: vars = collect_vars_from_expr(var) @@ -384,17 +378,17 @@ def get_index_map(block: tir.Block, if kind == "C": return [IterKind.kIter_S, primary_iter, secondary_iter] else: - return ([IterKind.kIter_S, spatial_iter, reduction_iter] if check_last_trait(region) - else [IterKind.kIter_S, reduction_iter, spatial_iter]) + return ( + [IterKind.kIter_S, spatial_iter, reduction_iter] + if check_last_trait(region) + else [IterKind.kIter_S, reduction_iter, spatial_iter] + ) else: raise ValueError(f"Unknown layout {layout}") - A_index_map = make_iter_fusion_index_map( - A_traits, infer_layout(layout[0], block.reads[0].region, kind="A")) - B_index_map = make_iter_fusion_index_map( - B_traits, infer_layout(layout[1], block.reads[1].region, kind="B")) - C_index_map = make_iter_fusion_index_map( - C_traits, infer_layout(layout[2], block.writes[0].region, kind="C")) + A_index_map = make_iter_fusion_index_map(A_traits, infer_layout(layout[0], block.reads[0].region, kind="A")) + B_index_map = make_iter_fusion_index_map(B_traits, infer_layout(layout[1], block.reads[1].region, kind="B")) + C_index_map = make_iter_fusion_index_map(C_traits, infer_layout(layout[2], block.writes[0].region, kind="C")) matmul_index_map = make_iter_fusion_index_map( block_traits, @@ -429,8 +423,7 @@ def get_dequantize_block(sch, blocks) -> BlockRV | None: has_uint_input = any("uint" in str(region.buffer.dtype) for region in block_stmt.reads) if not has_uint_input: return False - return not (len(block_stmt.writes) != 1 or - "float" not in str(block_stmt.writes[0].buffer.dtype)) + return not (len(block_stmt.writes) != 1 or "float" not in str(block_stmt.writes[0].buffer.dtype)) dequantize_blocks = [block for block in blocks if is_dequantize(block)] return dequantize_blocks[0] if len(dequantize_blocks) == 1 else None @@ -452,8 +445,7 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool: return None axes.extend(undefined_vars(r.min)) # remove trivial axis - trivial_vars = set( - iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent)) + trivial_vars = set(iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent)) axes = [axis for axis in axes if axis not in trivial_vars] # remove duplicate axis axes = [var for i, var in enumerate(axes) if i == 0 or var != axes[i - 1]] @@ -462,8 +454,7 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool: lhs_access_vars = get_access_vars(block_stmt.reads[0].region)[-2:] rhs_access_vars = get_access_vars(block_stmt.writes[0].region)[-2:] is_identity = list(lhs_access_vars) == list(rhs_access_vars) - is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set( - rhs_access_vars) + is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set(rhs_access_vars) return is_identity, is_transpose @@ -491,9 +482,7 @@ def inline_transpose_block(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV] return result_blocks -def normalize_to_matmul(sch: tir.Schedule, - main_block: BlockRV, - layout: list[str] | None = None) -> 
tir.Schedule | None: +def normalize_to_matmul(sch: tir.Schedule, main_block: BlockRV, layout: list[str] | None = None) -> tir.Schedule | None: if layout is None: layout = ["n", "t", "n"] block_stmt = sch.get(main_block) @@ -526,7 +515,7 @@ def get_tensorized_func_and_tags( allow_gemv: bool = False, ) -> tuple[tir.PrimFunc, dict[str, list[int] | int]]: """ - transform function to matmul if necessary (e.g. transform conv2d with im2col) + transform function to matmul if necessary (e.g. transform conv2d with im2col) """ if layout is None: layout = ["a", "a", "a"] @@ -543,10 +532,7 @@ def get_tensorized_func_and_tags( conditions = [] conditions.append(len(block_stmt.reads) == 2) conditions.append(len(block_stmt.writes) == 1) - conditions.append( - len( - collect_block_iter_vars_used_in_access_region(block_stmt, - block_stmt.writes[0].region)) > 0) + conditions.append(len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) > 0) return all(conditions) # step2. transform function to tensorcore matmul (e.g. conv2d with im2col) @@ -592,10 +578,7 @@ def get_tensorized_func_and_tags( return axes def is_common_reduce(var: Var) -> bool: - for iter_var in block_stmt.iter_vars: - if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce: - return True - return False + return any(iter_var.var == var and iter_var.iter_type == IterVar.CommReduce for iter_var in block_stmt.iter_vars) def has_common_reduce(var: Var) -> bool: vars = collect_vars_from_expr(var) @@ -626,7 +609,7 @@ def get_tensorized_func_and_tags( # When the func is a dequantize like ops, we should consider the M require_block_reduce = False # And we only support float16 for now - if (hasattr(func.attrs, "dequantize_info") and in_dtype in ["bfloat16", "float16"]): + if hasattr(func.attrs, "dequantize_info") and in_dtype in ["bfloat16", "float16"]: for arg in func.params: inp_shape = func.buffer_map[arg].shape M = inp_shape[0] @@ -645,9 +628,7 @@ def get_tensorized_func_and_tags( if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70: in_dtype, out_dtype = get_in_out_dtypes(block_stmt) if not is_tensorcore_supported_precision(in_dtype, out_dtype, arch=get_arch(target)): - logger.debug( - f"The input and output dtype ({in_dtype}, {out_dtype})is not supported by tensorcore" - ) + logger.debug(f"The input and output dtype ({in_dtype}, {out_dtype})is not supported by tensorcore") return func, None # reindex and transform functions @@ -676,7 +657,7 @@ def get_tensorized_func_and_tags( else: raise ValueError(f"Unknown IterVar type {iter_type}") - if (isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold): + if isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold: return func, None tags = analysis_tensorcore_tags(sch, main_block, target) return sch.mod["main"], tags @@ -686,8 +667,10 @@ def get_tensorized_func_and_tags( def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32"): from bitblas.tl.mma_layout import ( # pylint: disable=import-outside-toplevel - ldmatrix_32x8_to_shared_16x16_layout, ldmatrix_trans_32x8_to_shared_16x16_layout, - ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b, + ldmatrix_32x8_to_shared_16x16_layout, + ldmatrix_trans_32x8_to_shared_16x16_layout, + ldmatrix_32x16_to_shared_16x32_layout_a, + ldmatrix_32x16_to_shared_16x32_layout_b, ) assert dtype in [ @@ -727,9 +710,7 @@ def get_propagate_map(trans: bool = True, dtype="float16", 
matrix_name="A", inde return ldmatrix_layout(thread_id, local_id) if dtype in ["bfloat16", "float16"]: - ldmatrix_index_map = ( - ldmatrix_trans_permutation_16x16_32x8_16x16 - if trans else ldmatrix_permutation_16x16_32x8_16x16) + ldmatrix_index_map = ldmatrix_trans_permutation_16x16_32x8_16x16 if trans else ldmatrix_permutation_16x16_32x8_16x16 else: ldmatrix_index_map = ldmatrix_permutation_16x32_32x16_32x16 @@ -744,7 +725,6 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde # Ladder weight propagation, which can be used to avoid the ldmatrix # Instructions. def get_ladder_stage3_map(dtype="float16", index_dtype="int32"): - def shared_32x8_to_mma_32x8_layout(i, j): thread_id = (i % 8) * 4 + (j // 2) local_id = (i // 8) * 2 + (j % 2) @@ -837,8 +817,7 @@ def layout_propagate_chain( scaling_factor = 1 for i, j in zip(write.buffer.shape, read.buffer.shape): scaling_factor *= i // j - final_indices = list( - index_map.map_indices(tmp_index_map.map_indices(write_indices))) + final_indices = list(index_map.map_indices(tmp_index_map.map_indices(write_indices))) final_indices[-1] = final_indices[-1] // scaling_factor index_map = IndexMap( write_indices, diff --git a/tilelang/carver/roller/bestfit.py b/tilelang/carver/roller/bestfit.py index b66ceaae7eb0bd9fef216d7a5372fa77a941d45d..ec7817429d8dd2fa79db1a33988086758556f19f 100644 --- a/tilelang/carver/roller/bestfit.py +++ b/tilelang/carver/roller/bestfit.py @@ -2,7 +2,6 @@ class Block: - def __init__(self, start, end, is_free): self.start = start self.end = end @@ -21,7 +20,6 @@ class Block: class BestFit: - def __init__(self, align=32): self.limit = 0 self.list = [] @@ -31,16 +29,14 @@ class BestFit: size = (size + self.align - 1) // self.align * self.align found = None for block in self.list: - if block.is_free and block.size() >= size and (not found or - found.size() > block.size()): + if block.is_free and block.size() >= size and (not found or found.size() > block.size()): found = block if found: found.is_free = False remain = found.size() - size if remain != 0: found.end -= remain - self.list.insert( - self.list.index(found) + 1, Block(found.end, found.end + remain, True)) + self.list.insert(self.list.index(found) + 1, Block(found.end, found.end + remain, True)) return found elif len(self.list) > 0 and self.list[-1].is_free: add = size - self.list[-1].size() diff --git a/tilelang/carver/roller/hint.py b/tilelang/carver/roller/hint.py index 17c69daef8b3fa5f240a9b88a9ff7c3f0911c12a..8fd1fb40652f855fe22ea1271da6af550549200e 100644 --- a/tilelang/carver/roller/hint.py +++ b/tilelang/carver/roller/hint.py @@ -1,4 +1,5 @@ """Hint definition for schedule""" + from tvm import DataType from . 
import PrimFuncNode import numpy as np @@ -60,7 +61,7 @@ class Stride: strided_elem = original_shape else: assert self.ax < len(shape) - strided_elem = np.prod(shape[0:self.ax + 1]) * self.stride + strided_elem = np.prod(shape[0 : self.ax + 1]) * self.stride assert strided_elem >= original_shape return int(strided_elem) @@ -217,7 +218,7 @@ class Hint: return dic @classmethod - def from_dict(cls, dic: dict) -> 'Hint': + def from_dict(cls, dic: dict) -> "Hint": hint = cls() for k, v in dic.items(): setattr(hint, k, v) diff --git a/tilelang/carver/roller/node.py b/tilelang/carver/roller/node.py index f9e38b168a80b7cbbfdb099ee3e0d70a159589f7..3122c7b078ad982ffbc8bc99d62580ac6617474d 100644 --- a/tilelang/carver/roller/node.py +++ b/tilelang/carver/roller/node.py @@ -1,4 +1,5 @@ """PrimFunc Wrapper and Block information Analaysis""" + from __future__ import annotations import tvm @@ -31,7 +32,6 @@ def pre_order_traverse(block_analyzer, blocks, func): class BlockAnalyzer: - def __init__(self, sch) -> None: self.sch: tir.Schedule = sch self.block_infos: list[BlockInfo] = normalize_prim_func(self.sch) @@ -92,7 +92,6 @@ class Edge: class Node: - def __init__(self, tags: dict | None = None, name: str = "Node") -> None: self.name = name if tags is None: @@ -177,7 +176,6 @@ class Node: class PlaceHolderNode(Node): - def __init__(self, name=""): super().__init__(name="PlaceHolder_" + name) @@ -189,11 +187,7 @@ class PlaceHolderNode(Node): class PrimFuncNode(Node): - - def __init__(self, - prim_func: PrimFunc, - tags: dict | None = None, - name: str = "PrimFuncNode") -> None: + def __init__(self, prim_func: PrimFunc, tags: dict | None = None, name: str = "PrimFuncNode") -> None: super().__init__(tags, name=name) self.prim_func = self._specialize_func(prim_func) self.sch: tir.Schedule = tir.Schedule(self.prim_func) @@ -227,7 +221,7 @@ class PrimFuncNode(Node): for dst_id, n in enumerate(inputs): if isinstance(n, Node): n = (n, 0) - assert (len(n) == 2) + assert len(n) == 2 src_node, src_id = n[0], n[1] edge = Edge(src_node, self, src_id, dst_id) self._in_edges.append(edge) @@ -338,9 +332,8 @@ class PrimFuncNode(Node): if rstep is None: rstep = {} shape = { - self.block_analyzer.get_output_buffers(block)[0].name: [ - tvm.arith.ConstIntBound(0, val - 1) for val in tile - ] for block in self.schedule_stages + self.block_analyzer.get_output_buffers(block)[0].name: [tvm.arith.ConstIntBound(0, val - 1) for val in tile] + for block in self.schedule_stages } return self.ana.infer(shape, rstep, targets) @@ -356,10 +349,7 @@ class PrimFuncNode(Node): results.append(shapes[arg.name]) continue # should not exceed original shape - trimmed_shape = [ - self.extent_wrapper(i) - for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape))) - ] + trimmed_shape = [self.extent_wrapper(i) for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape)))] results.append(trimmed_shape) return results @@ -380,10 +370,8 @@ class PrimFuncNode(Node): propagate_shape = shapes[arg.name] buffer_shape = args[i].shape if len(buffer_shape) > len(propagate_shape): - buffer_shape = buffer_shape[-len(propagate_shape):] - trimmed_shape = [ - self.extent_wrapper(j) for j in list(map(min, zip(propagate_shape, buffer_shape))) - ] + buffer_shape = buffer_shape[-len(propagate_shape) :] + trimmed_shape = [self.extent_wrapper(j) for j in list(map(min, zip(propagate_shape, buffer_shape)))] results.append(trimmed_shape) return results @@ -412,10 +400,7 @@ class PrimFuncNode(Node): def get_reduce_inputs_dtype(self): if 
self.reduction_block is None: return {} - return { - b.name: tvm.DataType(b.dtype) - for b in self.block_analyzer.get_input_buffers(self.reduction_block) - } + return {b.name: tvm.DataType(b.dtype) for b in self.block_analyzer.get_input_buffers(self.reduction_block)} @functools.lru_cache def infer_tensorcore_axis(self) -> tuple[int]: @@ -425,8 +410,7 @@ class PrimFuncNode(Node): C_ax_m, C_ax_n = self.get_tag("tensorcore_config") wmma_m, wmma_n, wmma_k = [16, 16, 16] # just for testing, any number is ok - output_buffer_shape = ( - self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape) + output_buffer_shape = self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape valid_region = [] for region in output_buffer_shape: if region.value == 1: @@ -438,8 +422,7 @@ class PrimFuncNode(Node): def get_cl_shapes(c_ax_m, c_ax_n, num_nvalid_regions): spatial_dim = self.get_space_dim() - assert len(valid_region) == len( - spatial_dim), f" {valid_region} mismatch with {spatial_dim}" + assert len(valid_region) == len(spatial_dim), f" {valid_region} mismatch with {spatial_dim}" cl_shapes = [1] * len(spatial_dim) cl_shapes[c_ax_m - num_nvalid_regions] = wmma_m cl_shapes[c_ax_n - num_nvalid_regions] = wmma_n @@ -467,9 +450,11 @@ class PrimFuncNode(Node): shapes, _ = self.propagate(shape, rstep) def is_broadcast_pattern(buffer, output_buffer): - return (buffer in self.args and - len(shapes[output_buffer.name]) > len(shapes[buffer.name]) and - np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name])) + return ( + buffer in self.args + and len(shapes[output_buffer.name]) > len(shapes[buffer.name]) + and np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name]) + ) def is_after_reduce_stage(block): if not self.reduction_block: @@ -491,8 +476,8 @@ class PrimFuncNode(Node): output_buffer = self.block_analyzer.get_output_buffers(block)[0] for buffer in self.block_analyzer.get_input_buffers(block): cache = buffer.name not in cached_tensor and ( - is_broadcast_pattern(buffer, output_buffer) or - self.block_analyzer.get_block_info(block).is_reduction()) + is_broadcast_pattern(buffer, output_buffer) or self.block_analyzer.get_block_info(block).is_reduction() + ) if not cache: continue cached_tensor.append(buffer.name) @@ -500,8 +485,7 @@ class PrimFuncNode(Node): continue # cache after reduce op can often reuse buffer in reduce stage if buffer.name in stride_map: - num_elem = stride_map[buffer.name].compute_elements_from_shape( - shapes[buffer.name]) + num_elem = stride_map[buffer.name].compute_elements_from_shape(shapes[buffer.name]) else: num_elem = np.prod(shapes[buffer.name]) buffer_len = num_elem * int((tvm.DataType(buffer.dtype).bits + 7) // 8) @@ -514,7 +498,6 @@ class PrimFuncNode(Node): class OutputNode(Node): - def __init__(self, node, id=0): super().__init__(name="OutputNode") # connect node and output node @@ -549,15 +532,16 @@ def topo_order(list_of_nodes) -> list[Node]: input_ready_count[dst_node] = len(dst_node.inputs) list_of_nodes.append(dst_node) input_ready_count[dst_node] -= 1 - assert (input_ready_count[dst_node] >= 0) + assert input_ready_count[dst_node] >= 0 if input_ready_count[dst_node] == 0: ready.append(dst_node) - assert (len(list_of_nodes) == len(output_list)) + assert len(list_of_nodes) == len(output_list) return output_list def find_topo_sort_priority(output_node_list) -> list[Node]: import sys + sys.setrecursionlimit(10000) def topo_sort_get_layer(node, topo_layer): @@ -576,9 +560,7 @@ def find_topo_sort_priority(output_node_list) -> 
list[Node]: if node in visited: return visited.add(node) - ordered_input_nodes = sorted([edge.src_node for edge in node.inputs], - key=lambda n: topo_layer[n], - reverse=True) + ordered_input_nodes = sorted([edge.src_node for edge in node.inputs], key=lambda n: topo_layer[n], reverse=True) for n in ordered_input_nodes: topo_sort_dfs(n, visited, topo_order) topo_order.append(node) @@ -591,7 +573,6 @@ def find_topo_sort_priority(output_node_list) -> list[Node]: def find_topo_sort(output_node_list) -> list[Node]: - def topo_sort_dfs(node, visited, topo_order): if node in visited: return diff --git a/tilelang/carver/roller/policy/default.py b/tilelang/carver/roller/policy/default.py index 161df27a7547a505faa35201dccb509f09d8c457..d09216e1ceb4ce7eb02a4643d18fac4990a01a34 100644 --- a/tilelang/carver/roller/policy/default.py +++ b/tilelang/carver/roller/policy/default.py @@ -1,4 +1,5 @@ """Policy for cuda core schedule""" + from __future__ import annotations import functools import math @@ -36,20 +37,14 @@ class DefaultPolicy: self.rasterization = NoRasterization() @classmethod - def from_prim_func(cls, - func: tvm.tir.PrimFunc, - arch: TileDevice, - tags: dict | None = None, - name: str = "PrimFuncNode"): + def from_prim_func(cls, func: tvm.tir.PrimFunc, arch: TileDevice, tags: dict | None = None, name: str = "PrimFuncNode"): return cls(arch, tags)._init_with_prim_func(func, name) @classmethod def from_output_nodes(cls, nodes: list[OutputNode], arch: TileDevice, tags: dict | None = None): return cls(arch, tags)._init_with_output_nodes(nodes) - def _init_with_prim_func(self, - func: tvm.tir.PrimFunc, - name: str = "PrimFuncNode") -> DefaultPolicy: + def _init_with_prim_func(self, func: tvm.tir.PrimFunc, name: str = "PrimFuncNode") -> DefaultPolicy: if func is not None and isinstance(func, tvm.tir.PrimFunc): self.func = func self.prim_func_node = PrimFuncNode(self.func, tags=self.tags, name=name) @@ -60,9 +55,7 @@ class DefaultPolicy: return self def _init_with_output_nodes(self, output_nodes: list[OutputNode]): - self.ordered_nodes = list( - filter(lambda n: not n.is_placeholder() and not n.is_output(), - find_topo_sort(output_nodes))) + self.ordered_nodes = list(filter(lambda n: not n.is_placeholder() and not n.is_output(), find_topo_sort(output_nodes))) for node in self.ordered_nodes: node.update_tags(self.tags) @@ -102,13 +95,14 @@ class DefaultPolicy: def dfs_smem_tile(self, init_tile, rstep_map) -> Iterable[TileDict]: _steps = [get_all_factors(n) for n in self.output_nodes[0].get_space_dim()] - steps = [step[step.index(t):] for step, t in zip(_steps, init_tile)] + steps = [step[step.index(t) :] for step, t in zip(_steps, init_tile)] for i in range(len(steps)): added = list( filter( lambda s: s < steps[i][-1] and s > steps[i][0] and s not in steps[i], [2, 4, 8, 16, 32], - )) + ) + ) steps[i].extend(added) steps[i] = sorted(steps[i]) visited_tiles = {} @@ -190,10 +184,7 @@ class DefaultPolicy: """ tile_map = {} for node in self.output_nodes: - tile_map[node] = [ - tile[i] * node.get_space_dim()[i] // self.output_nodes[0].get_space_dim()[i] - for i in range(len(tile)) - ] + tile_map[node] = [tile[i] * node.get_space_dim()[i] // self.output_nodes[0].get_space_dim()[i] for i in range(len(tile))] return tile_map def compute_workload_per_item(self, output_tile) -> float: @@ -304,8 +295,7 @@ class DefaultPolicy: score = 0 shape = node.propagate_inputs(tile, rstep=rstep) for i, input_buffer in enumerate(node.input_buffers): - read_transaction_elements = self.arch.transaction_size[1] // ( - 
(node.get_buffer_dtype(input_buffer).bits + 7) // 8) + read_transaction_elements = self.arch.transaction_size[1] // ((node.get_buffer_dtype(input_buffer).bits + 7) // 8) score += sim( int(coalesced_factor(shape[i], input_buffer.shape)), read_transaction_elements, @@ -380,17 +370,13 @@ class DefaultPolicy: return None return max(candidates, key=lambda x: x[1])[0] - cur_rstep_id = { - k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis - } + cur_rstep_id = {k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis} new_rstep_map = rstep_map.copy() while True: new_rstep_id = _enlarge(cur_rstep_id) if new_rstep_id is None: break - new_rstep_map[node] = { - k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis - } + new_rstep_map[node] = {k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis} old_rstep_map = td.rstep_map td.rstep_map = new_rstep_map smem_usage, _ = self._compute_shared_memory_usage(td) @@ -434,15 +420,14 @@ class DefaultPolicy: if edge.src_node.is_placeholder(): nbytes = (edge.src_node.get_dtype().bits + 7) // 8 read_transaction_elements = self.arch.transaction_size[1] // nbytes - traffic += coalesced_tensor_shape(input_shapes[i], edge.src_node.get_shape(), - read_transaction_elements) * nbytes + traffic += coalesced_tensor_shape(input_shapes[i], edge.src_node.get_shape(), read_transaction_elements) * nbytes for edge in node.outputs: if edge.dst_node.is_output(): nbytes = (edge.src_node.get_dtype().bits + 7) // 8 write_transaction_elements = self.arch.transaction_size[0] // nbytes - traffic += coalesced_tensor_shape(output_shapes[edge.src_id], - node.get_shape(edge.src_id), - write_transaction_elements) * nbytes + traffic += ( + coalesced_tensor_shape(output_shapes[edge.src_id], node.get_shape(edge.src_id), write_transaction_elements) * nbytes + ) return traffic, op_tile_map @@ -487,10 +472,7 @@ class DefaultPolicy: cached_tensors_map = {} def can_free(node, out_id): - for edge in node.outputs: - if edge.src_id == out_id and edge.dst_node not in processed: - return False - return True + return all(not (edge.src_id == out_id and edge.dst_node not in processed) for edge in node.outputs) for node in self.ordered_nodes: node_internal_bytes, cached_tensors_map[node] = self.infer_node_smem_usage(td, node) @@ -528,9 +510,7 @@ class DefaultPolicy: Tuple[Dict, Dict] A tuple of dictionaries containing the output strides and tensor strides. 
""" - output_strides = { - int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers) - } + output_strides = {int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)} tensor_strides = {} return output_strides, tensor_strides @@ -551,8 +531,7 @@ class DefaultPolicy: output_strides_map = {} tensor_strides_map = {} for node in self.ordered_nodes: - output_strides_map[node], tensor_strides_map[node] = self.compute_node_stride_map( - node, td) + output_strides_map[node], tensor_strides_map[node] = self.compute_node_stride_map(node, td) td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map def compute_tile_dict(self, output_tile: list[int], rstep_map) -> TileDict: @@ -582,9 +561,7 @@ class DefaultPolicy: output_shape = self.output_nodes[0].get_space_dim() td.grid_size = int(np.prod([(y + x - 1) // x for x, y in zip(output_tile, output_shape)])) # estimated reg usage - reg_usage = int(2 * max([ - np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 for node in self.ordered_nodes - ])) + reg_usage = int(2 * max([np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 for node in self.ordered_nodes])) if reg_usage > self.arch.reg_cap: td.valid = False return td @@ -609,13 +586,10 @@ class DefaultPolicy: for node in self.ordered_nodes: if np.prod(td.get_tile(node)) == 0: return False - node_grid_size = np.prod([ - (y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim()) - ]) + node_grid_size = np.prod([(y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim())]) if node_grid_size != td.grid_size: return False - if (hasattr(node, "reduce_op") and node.reduce_op is not None and - len(node.reduce_op.axis) == len(td.output_tile)): + if hasattr(node, "reduce_op") and node.reduce_op is not None and len(node.reduce_op.axis) == len(td.output_tile): for i, tile_extent in enumerate(td.output_tile): if node.reduce_op.axis[i].dom.extent % tile_extent: return False @@ -639,23 +613,22 @@ class DefaultPolicy: node_space_sizes = [int(np.prod(td.get_tile(node))) for node in self.ordered_nodes] max_block_size = functools.reduce(math.gcd, node_space_sizes) - if max_block_size < self.arch.warp_size * self.arch.sm_partition and max_block_size == min( - node_space_sizes): - node_reduce_sizes = [ - int(np.prod(list(td.get_rstep(node).values()))) for node in self.ordered_nodes - ] + if max_block_size < self.arch.warp_size * self.arch.sm_partition and max_block_size == min(node_space_sizes): + node_reduce_sizes = [int(np.prod(list(td.get_rstep(node).values()))) for node in self.ordered_nodes] total_sizes = [x * y for x, y in zip(node_space_sizes, node_reduce_sizes)] max_possible_size = functools.reduce(math.gcd, total_sizes) possible_block_sizes = list( filter( lambda x: x % max_block_size == 0 and x <= 1024, get_all_factors(max_possible_size), - )) + ) + ) possible_block_sizes = list( filter( # either be a factor of space or cover fully cover the space lambda x: all([x % s == 0 or s % x == 0 for s in node_space_sizes]), possible_block_sizes, - )) + ) + ) factor_ordered = sorted(possible_block_sizes, key=self.score_block_size) return factor_ordered else: @@ -821,8 +794,7 @@ class DefaultPolicy: vectorize_result = {} for tensor, shape in shapes.items(): for v in vectorize_sizes: - if (is_shape_aligned(shape, block_size * v) and is_cont(shape, v) and - is_type_allowed(dtypes[tensor], v)): + if is_shape_aligned(shape, block_size * v) and is_cont(shape, v) and is_type_allowed(dtypes[tensor], v): 
vectorize_result[tensor] = v break return vectorize_result diff --git a/tilelang/carver/roller/policy/tensorcore.py b/tilelang/carver/roller/policy/tensorcore.py index 15bad4122f049265594fb6799a9777bf497543e6..86c79ea732a374847fa5ca79e88c4508596a282e 100644 --- a/tilelang/carver/roller/policy/tensorcore.py +++ b/tilelang/carver/roller/policy/tensorcore.py @@ -1,4 +1,5 @@ """Policy for tensorcore schedule""" + from __future__ import annotations import tvm import numpy as np @@ -13,7 +14,6 @@ logger = logging.getLogger(__name__) class TensorCorePolicy(DefaultPolicy): - # this is the trick for wmma. # However, for int8 mma, the wmma_k should be 32. wmma_k: int = 16 @@ -70,9 +70,9 @@ class TensorCorePolicy(DefaultPolicy): A_high_ax = min(A_ax_m, A_ax_k) B_high_ax = min(B_ax_n, B_ax_k) C_high_ax = min(C_ax_m, C_ax_n) - A_stride = Stride(stride=np.prod(AS_shape[A_high_ax + 1:]) + offset, ax=A_high_ax) - B_stride = Stride(stride=np.prod(BS_shape[B_high_ax + 1:]) + offset, ax=B_high_ax) - C_stride = Stride(stride=np.prod(CS_shape[C_high_ax + 1:]) + offset, ax=C_high_ax) + A_stride = Stride(stride=np.prod(AS_shape[A_high_ax + 1 :]) + offset, ax=A_high_ax) + B_stride = Stride(stride=np.prod(BS_shape[B_high_ax + 1 :]) + offset, ax=B_high_ax) + C_stride = Stride(stride=np.prod(CS_shape[C_high_ax + 1 :]) + offset, ax=C_high_ax) return A_stride, B_stride, C_stride def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode): @@ -86,8 +86,7 @@ class TensorCorePolicy(DefaultPolicy): # get reduce input size target_transaction = self.arch.transaction_size[0] * 2 # 512 bytes // type bits - reduce_input_dtype = node.get_buffer_dtype( - node.block_analyzer.get_input_buffers(node.reduction_block)[0]) + reduce_input_dtype = node.get_buffer_dtype(node.block_analyzer.get_input_buffers(node.reduction_block)[0]) basic = (target_transaction * 8) // reduce_input_dtype.bits result = {} @@ -95,7 +94,7 @@ class TensorCorePolicy(DefaultPolicy): iter_name = iter_info.var.name iter_dom = iter_info.dom.extent if iter_dom % 16 > 0: - result[iter_name] = (16 if iter_dom < basic else basic) # for the case of padding + result[iter_name] = 16 if iter_dom < basic else basic # for the case of padding elif iter_dom % basic == 0: result[iter_name] = basic else: @@ -114,7 +113,6 @@ class TensorCorePolicy(DefaultPolicy): return False if _check_small_tile(td): - smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap) rstep_map = td.rstep_map.copy() @@ -127,13 +125,10 @@ class TensorCorePolicy(DefaultPolicy): return rstep def _shared_memory_usage(td: TileDict): - return node.footprint(td.output_tile, new_rstep_map, - td.tensor_strides_map[node]) + return node.footprint(td.output_tile, new_rstep_map, td.tensor_strides_map[node]) def _score(rstep_id): - rstep = { - k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis - } + rstep = {k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis} score = 0 shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep) input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block) @@ -153,18 +148,13 @@ class TensorCorePolicy(DefaultPolicy): return None return max(candidates, key=lambda x: x[1])[0] - cur_rstep_id = { - k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis - } + cur_rstep_id = {k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis} new_rstep_map = rstep_map.copy() while True: new_rstep_id = _enlarge(cur_rstep_id) if new_rstep_id is None: 
break - new_rstep_map = { - k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] - for k in node.raxis - } + new_rstep_map = {k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis} old_rstep_map = td.rstep_map td.rstep_map = new_rstep_map smem_usage, _ = _shared_memory_usage(td) @@ -173,9 +163,7 @@ class TensorCorePolicy(DefaultPolicy): break else: cur_rstep_id = new_rstep_id - rstep = { - k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis - } + rstep = {k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis} return rstep for node in self.ordered_nodes: @@ -206,11 +194,7 @@ class TensorCorePolicy(DefaultPolicy): return super().get_node_reduce_step_candidates(node) else: # must be a a multiple of wmma_k - return { - k.var.name: [ - x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k) - ] for k in node.raxis - } + return {k.var.name: [x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)] for k in node.raxis} def check_tile_shape_isvalid(self, td: TileDict): for node in self.ordered_nodes: @@ -221,10 +205,7 @@ class TensorCorePolicy(DefaultPolicy): td.tile_map[node][ax_n], ) # check the tile size is valid - wmma_invalid = [ - block_m < wmma_m or block_n < wmma_n - for wmma_m, wmma_n in self.arch.get_avaliable_tensorintrin_shapes() - ] + wmma_invalid = [block_m < wmma_m or block_n < wmma_n for wmma_m, wmma_n in self.arch.get_avaliable_tensorintrin_shapes()] if all(wmma_invalid): return False if any([y % x for x, y in zip(td.tile_map[node], node.get_space_dim())]): @@ -242,13 +223,10 @@ class TensorCorePolicy(DefaultPolicy): return super().compute_node_stride_map(node, td) use_layout = self._can_implement_layout(node, td) - AS_stride, BS_stride, C_stride = self._compute_tc_strides(node, td.get_tile(node), - td.get_rstep(node)) + AS_stride, BS_stride, C_stride = self._compute_tc_strides(node, td.get_tile(node), td.get_rstep(node)) A_stride, B_stride, _ = self._compute_tc_strides(node, td.get_tile(node)) tensor_strides = {} - output_strides = { - int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers) - } + output_strides = {int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)} tensor_strides = {} # when connected to shared input, should use full stride without rstep for i, (_, _) in enumerate(zip([AS_stride, BS_stride], [A_stride, B_stride])): @@ -347,8 +325,7 @@ class TensorCorePolicy(DefaultPolicy): overall_gmem_size_in_bytes: int = 0 for node in self.ordered_nodes: for buffer in node.input_buffers: - overall_gmem_size_in_bytes += ( - int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8) + overall_gmem_size_in_bytes += int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8 return overall_gmem_size_in_bytes < self.arch.l2_cache_size_bytes conditions.append(_check_memory_size()) diff --git a/tilelang/carver/roller/rasterization.py b/tilelang/carver/roller/rasterization.py index ebd1319af6488a052d9c3d454584ac24b57e55ae..ec565a1c7c29cbc1a7c11191fbb8d248a5368d32 100644 --- a/tilelang/carver/roller/rasterization.py +++ b/tilelang/carver/roller/rasterization.py @@ -2,7 +2,6 @@ class Rasterization: - panel_width_ = None def __init__(self) -> None: @@ -18,7 +17,6 @@ class Rasterization: class NoRasterization(Rasterization): - def __init__(self) -> None: super().__init__() diff --git a/tilelang/carver/roller/shape_inference/common.py b/tilelang/carver/roller/shape_inference/common.py index 
c52a170e0137a1027aff8c3808ddead85441666f..c29ae4129831ab6a4255dc054c5b6b36f9bac50c 100644 --- a/tilelang/carver/roller/shape_inference/common.py +++ b/tilelang/carver/roller/shape_inference/common.py @@ -4,9 +4,7 @@ from tvm import arith class Statement: - - def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict, - range_map: OrderedDict): + def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict, range_map: OrderedDict): self.output = output self.dependent_region = dependent_region self.var_map = var_map @@ -18,7 +16,6 @@ def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound): class InputShapeInference: - def __init__(self, deps: list[Statement]): self.deps = deps diff --git a/tilelang/carver/roller/shape_inference/tir.py b/tilelang/carver/roller/shape_inference/tir.py index 618cf9b304ca768b066de366440929793c8ec244..d7b11d6086720f6e24d3dbcca981b455593a6221 100644 --- a/tilelang/carver/roller/shape_inference/tir.py +++ b/tilelang/carver/roller/shape_inference/tir.py @@ -5,7 +5,6 @@ from tvm import arith, tir class Statement: - def __init__(self, block_analyzer, block: BlockRV): self.block_analyzer = block_analyzer self.block = block @@ -21,9 +20,7 @@ class Statement: if len(self.dependent_region[input_name]) != 1: return None indices = self.dependent_region[input_name][0] - iter_map_range = { - _iter.var: _iter.dom for _iter in self.block_analyzer.get_spatial_axis(self.block) - } + iter_map_range = {_iter.var: _iter.dom for _iter in self.block_analyzer.get_spatial_axis(self.block)} iter_map_result = arith.detect_iter_map( indices, iter_map_range, @@ -77,7 +74,6 @@ class TensorDepNode: class DependencyAnalysis: - def __init__(self, deps): self.deps = deps # issue: duplicate name when we have two same ops. 
@@ -112,8 +108,7 @@ class DependencyAnalysis: def traverse_dependencies(self, compute): if isinstance(compute, Statement): - node = self.get_or_create_node( - compute.block_analyzer.get_output_buffers(compute.block)[0].name) + node = self.get_or_create_node(compute.block_analyzer.get_output_buffers(compute.block)[0].name) # Loop through input tensors for input_buffer in compute.block_analyzer.get_input_buffers(compute.block): # Get the input node @@ -167,7 +162,6 @@ class DependencyAnalysis: class InputShapeInference: - def __init__(self, deps: list[Statement]): self.deps = deps self.target_mapping = {} @@ -183,16 +177,11 @@ class InputShapeInference: if targets in self.target_mapping: return self.target_mapping[targets] # should be buffer name instead of block name - name2dep = { - dep.block_analyzer.get_output_buffers(dep.block)[0].name: dep for dep in self.deps - } + name2dep = {dep.block_analyzer.get_output_buffers(dep.block)[0].name: dep for dep in self.deps} mapping = {} input_vars = [] for target in targets: - vars = [ - iter.var - for iter in name2dep[target].block_analyzer.get_spatial_axis(name2dep[target].block) - ] + vars = [iter.var for iter in name2dep[target].block_analyzer.get_spatial_axis(name2dep[target].block)] input_vars.append(vars) mapping[target] = [vars] ana = arith.Analyzer() @@ -221,13 +210,8 @@ class InputShapeInference: mapping[input_name] = [] for indices in indices_list: for region in regions: - vmap = { - k: (tir.Cast(k.dtype, v) if v.dtype != k.dtype else v) - for k, v in zip(ax_vars, indices) - } - region = [ - ana.simplify(tir.stmt_functor.substitute(ax, vmap)) for ax in region - ] + vmap = {k: (tir.Cast(k.dtype, v) if v.dtype != k.dtype else v) for k, v in zip(ax_vars, indices)} + region = [ana.simplify(tir.stmt_functor.substitute(ax, vmap)) for ax in region] if not region_exist_in_list(region, mapping[input_name]): mapping[input_name].append(region) buffers = [] @@ -241,10 +225,7 @@ class InputShapeInference: self.target_mapping[targets] = input_vars, mapping return input_vars, mapping - def infer(self, - shape: dict[str, list[arith.ConstIntBound]], - rstep: dict[str, int] = None, - targets=None): + def infer(self, shape: dict[str, list[arith.ConstIntBound]], rstep: dict[str, int] = None, targets=None): if rstep is None: rstep = {} compute_targets = tuple(shape.keys()) @@ -258,8 +239,7 @@ class InputShapeInference: for ax in self.reduce_axes: # assume the dom.min is always 0, maybe we can extend the IterInfo to include the min value. 
if ax.var.name in rstep: - bound = arith.ConstIntBound( - int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1)) + bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1)) else: bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + ax.dom.extent - 1)) ana.update(ax.var, bound, True) @@ -312,14 +292,11 @@ class InputShapeInference: for name, regions in mapping.items(): region = regions[0] - result[name] = [ - ana.simplify(tir.stmt_functor.substitute(index, vmap)) for index in region - ] + result[name] = [ana.simplify(tir.stmt_functor.substitute(index, vmap)) for index in region] return result def region_exist_in_list(a, list) -> bool: - def expr_is_same(a, b) -> bool: if isinstance(a, tir.IntImm) and isinstance(b, tir.IntImm): return a.value == b.value diff --git a/tilelang/carver/template/base.py b/tilelang/carver/template/base.py index a119c16a71c8101504c3eb4e95aef3a89595e55a..4a699fbc7dc8ec1a649d3b37e9f5bad24586e006 100644 --- a/tilelang/carver/template/base.py +++ b/tilelang/carver/template/base.py @@ -2,7 +2,12 @@ from abc import ABC, abstractmethod # For defining abstract base classes from dataclasses import dataclass, field # For defining data classes from ..arch import ( # Import architecture-related utilities and classes - TileDevice, is_volta_arch, is_ampere_arch, is_cdna_arch, auto_infer_current_arch) + TileDevice, + is_volta_arch, + is_ampere_arch, + is_cdna_arch, + auto_infer_current_arch, +) from ..roller.hint import Hint # Import the Hint class from ..roller.node import OutputNode # Import the OutputNode class from tvm.tir import PrimFunc # Import PrimFunc for handling tensor IR functions @@ -41,7 +46,7 @@ class BaseTemplate(ABC): """ pass - def with_arch(self, arch: TileDevice) -> 'BaseTemplate': + def with_arch(self, arch: TileDevice) -> "BaseTemplate": """ Sets the architecture for this template and returns itself. @@ -109,7 +114,7 @@ class BaseTemplate(ABC): """ raise NotImplementedError("initialize_function is not implemented") - def set_function(self, func: PrimFunc) -> 'BaseTemplate': + def set_function(self, func: PrimFunc) -> "BaseTemplate": """ Sets the function for this template and returns itself. @@ -122,7 +127,7 @@ class BaseTemplate(ABC): self._func = func return self - def set_output_nodes(self, output_nodes: list[OutputNode]) -> 'BaseTemplate': + def set_output_nodes(self, output_nodes: list[OutputNode]) -> "BaseTemplate": """ Sets the output nodes for this template and returns itself. diff --git a/tilelang/carver/template/conv.py b/tilelang/carver/template/conv.py index 9ea89202d658ade3c1dfdc059e5f80f5c809eb0a..c339e589488b4db57637804647c188275f0bc77b 100644 --- a/tilelang/carver/template/conv.py +++ b/tilelang/carver/template/conv.py @@ -28,6 +28,7 @@ class ConvTemplate(BaseTemplate): accum_dtype (str): Data type used for accumulation. with_bias (bool): Whether to add a bias term. """ + # Operation-related configuration parameters N: int # The number of input samples processed simultaneously in a batch. C: int # The number of input feature maps. @@ -69,12 +70,18 @@ class ConvTemplate(BaseTemplate): AssertionError: If N, C, H, W, F, K, S, D, P are not positive integers. 
""" N, C, H, W, F, K, S, D, P = self.N, self.C, self.H, self.W, self.F, self.K, self.S, self.D, self.P - assert (isinstance(N, int) and isinstance(C, int) and isinstance(H, int) and - isinstance(W, int) and isinstance(F, int) and isinstance(K, int) and - isinstance(S, int) and isinstance(D, int) and - isinstance(P, int)), "Only Support Integer Params" - assert (N > 0 and C > 0 and H > 0 and W > 0 and F > 0 and K > 0 and S > 0 and D > 0 and - P > 0), "Params should be positive" + assert ( + isinstance(N, int) + and isinstance(C, int) + and isinstance(H, int) + and isinstance(W, int) + and isinstance(F, int) + and isinstance(K, int) + and isinstance(S, int) + and isinstance(D, int) + and isinstance(P, int) + ), "Only Support Integer Params" + assert N > 0 and C > 0 and H > 0 and W > 0 and F > 0 and K > 0 and S > 0 and D > 0 and P > 0, "Params should be positive" # Load configuration parameters in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype @@ -123,8 +130,10 @@ class ConvTemplate(BaseTemplate): te.if_then_else( te.all(h_in >= 0, h_in < H, w_in >= 0, w_in < W), A[n, h_in, w_in, c].astype(accum_dtype) * B[kh, kw, c, f].astype(accum_dtype), - tir.const(0, accum_dtype)), - axis=[kh, kw, c]) + tir.const(0, accum_dtype), + ), + axis=[kh, kw, c], + ) # Compute convolution result C = te.compute( diff --git a/tilelang/carver/template/flashattention.py b/tilelang/carver/template/flashattention.py index ae1a25402178e725f712ef091d2a1a01c5139ca4..933ab9585405ee7dd790b462775219fb0ccae8b1 100644 --- a/tilelang/carver/template/flashattention.py +++ b/tilelang/carver/template/flashattention.py @@ -9,7 +9,6 @@ from ..utils import get_roller_hints_from_output_nodes, get_tensorized_func_and_ @dataclass class FlashAttentionTemplate(BaseTemplate): - _output_nodes: list[OutputNode] = None # Operation-related configuration parameters @@ -91,10 +90,7 @@ class FlashAttentionTemplate(BaseTemplate): """ A_indices = [b, i, k] B_indices = [b, j, k] - return te.sum( - A[tuple(A_indices)].astype(accum_dtype) * - B[tuple(B_indices)].astype(accum_dtype), - axis=k) + return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k) # Compute matrix multiplication result C = te.compute( diff --git a/tilelang/carver/template/gemv.py b/tilelang/carver/template/gemv.py index cdcc78d089455d2ed701d48493565a8fc764cf18..e7962f6ad76205f19df8df1676705ed6e6037d25 100644 --- a/tilelang/carver/template/gemv.py +++ b/tilelang/carver/template/gemv.py @@ -50,9 +50,8 @@ class GEMVTemplate(BaseTemplate): N, K = self.N, self.K # Ensure M, N, K are valid positive integers - assert (isinstance(M, int) and isinstance(N, int) and - isinstance(K, int)), "Only Support Integer M, N, K" - assert (M > 0 and N > 0 and K > 0), "M, N, K should be positive" + assert isinstance(M, int) and isinstance(N, int) and isinstance(K, int), "Only Support Integer M, N, K" + assert M > 0 and N > 0 and K > 0, "M, N, K should be positive" # Load configuration parameters trans_B = self.trans_B @@ -86,9 +85,7 @@ class GEMVTemplate(BaseTemplate): """ A_indices = [i, k] B_indices = [k, j] if not trans_B else [j, k] - return te.sum( - A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), - axis=k) + return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k) # Compute matrix multiplication result C = te.compute( diff --git a/tilelang/carver/template/general_reduce.py b/tilelang/carver/template/general_reduce.py index 
a8da5fd6cebc02071cd9b98c3e10c7801fb7a259..b7a55157c2586afa6c6ddcaf51f80c6c985544f1 100644 --- a/tilelang/carver/template/general_reduce.py +++ b/tilelang/carver/template/general_reduce.py @@ -9,15 +9,13 @@ from ..utils import get_roller_hints_from_func @dataclass class GeneralReductionTemplate(BaseTemplate): - # OP Related Config structure: str | list[str] = None shape: list[int] = None dtype: str = "float16" def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]: - roller_hints = get_roller_hints_from_func( - self._func, arch=arch, topk=topk, allow_gemv=False) + roller_hints = get_roller_hints_from_func(self._func, arch=arch, topk=topk, allow_gemv=False) return roller_hints def initialize_function(self) -> None: @@ -38,9 +36,9 @@ class GeneralReductionTemplate(BaseTemplate): spatial_axes = [] reduce_axes = [] for i, axis_type in enumerate(self.structure): - if axis_type.upper() == 'S': + if axis_type.upper() == "S": spatial_axes.append((i, self.shape[i])) - elif axis_type.upper() == 'R': + elif axis_type.upper() == "R": reduce_axes.append((i, self.shape[i])) else: raise ValueError(f"Unrecognized axis type '{axis_type}', only 'S'/'R' allowed.") @@ -90,7 +88,7 @@ class GeneralReductionTemplate(BaseTemplate): # Walk through the structure in order for axis_type in self.structure: - if axis_type.upper() == 'S': + if axis_type.upper() == "S": # use the next spatial_indices item full_index.append(spatial_indices[spatial_iter]) spatial_iter += 1 diff --git a/tilelang/carver/template/matmul.py b/tilelang/carver/template/matmul.py index 653ddab3ea8bbdcfc5d081e13f71138710744a13..57c92beb75d870cf17c491427241b1b9926ffe3d 100644 --- a/tilelang/carver/template/matmul.py +++ b/tilelang/carver/template/matmul.py @@ -65,9 +65,8 @@ class MatmulTemplate(BaseTemplate): M, N, K = self.M, self.N, self.K # Ensure M, N, K are valid positive integers - assert (isinstance(M, int) and isinstance(N, int) and - isinstance(K, int)), "Only Support Integer M, N, K" - assert (M > 0 and N > 0 and K > 0), "M, N, K should be positive" + assert isinstance(M, int) and isinstance(N, int) and isinstance(K, int), "Only Support Integer M, N, K" + assert M > 0 and N > 0 and K > 0, "M, N, K should be positive" # Load configuration parameters trans_A, trans_B = self.trans_A, self.trans_B @@ -101,9 +100,7 @@ class MatmulTemplate(BaseTemplate): """ A_indices = [i, k] if not trans_A else [k, i] # Adjust indexing if A is transposed B_indices = [k, j] if not trans_B else [j, k] # Adjust indexing if B is transposed - return te.sum( - A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), - axis=k) + return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k) # Compute matrix multiplication result C = te.compute( diff --git a/tilelang/carver/utils.py b/tilelang/carver/utils.py index cedb7547a0cee3eb976ad23802627d5fe4f22e78..67db89e39178fb8e1f1848a2e2dd95928cb4d6be 100644 --- a/tilelang/carver/utils.py +++ b/tilelang/carver/utils.py @@ -26,11 +26,9 @@ def get_rasterization_code(pannel_width: int = 8) -> str: """ -def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule, - arch: TileDevice, - topk: int = 10, - tensorcore_only: bool = False, - allow_gemv: bool = False) -> list[Hint] | None: +def get_roller_hints_from_func( + func_or_module: tir.PrimFunc | IRModule, arch: TileDevice, topk: int = 10, tensorcore_only: bool = False, allow_gemv: bool = False +) -> list[Hint] | None: func = None if isinstance(func_or_module, 
tir.PrimFunc): func = func_or_module @@ -44,8 +42,7 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule, roller_hints = None if tensorcore_only: try: - tensorized_func, tags = get_tensorized_func_and_tags( - func, arch.target, allow_gemv=allow_gemv) + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target, allow_gemv=allow_gemv) except Exception as e_msg: logger.debug("Get tensorized func and tags failed: ", e_msg) tags = None @@ -58,8 +55,7 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule, policy = DefaultPolicy.from_prim_func(func=func, arch=arch) tensorized_func = None try: - tensorized_func, tags = get_tensorized_func_and_tags( - func, arch.target, allow_gemv=allow_gemv) + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target, allow_gemv=allow_gemv) except Exception as e_msg: logger.debug("Get tensorized func and tags failed: ", e_msg) tags = None @@ -69,10 +65,9 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule, return roller_hints -def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode], - arch: TileDevice, - topk: int = 10, - extra_tags: list[str] | None = None) -> list[Hint] | None: +def get_roller_hints_from_output_nodes( + output_nodes: list[OutputNode], arch: TileDevice, topk: int = 10, extra_tags: list[str] | None = None +) -> list[Hint] | None: assert isinstance(output_nodes, list), "The input should be a list of functions." lints = [] @@ -80,8 +75,7 @@ def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode], policy = TensorCorePolicy.from_output_nodes(output_nodes, arch=arch, tags=None) lints = policy.emit_config(topk) except Exception as e_msg: - logger.debug(f"Generate hints from output nodes failed: {e_msg}", - "fallback to default policy") + logger.debug(f"Generate hints from output nodes failed: {e_msg}", "fallback to default policy") if len(lints) == 0: policy = DefaultPolicy.from_output_nodes(output_nodes, arch=arch, tags=None) @@ -92,7 +86,6 @@ def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode], def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc: if not isinstance(ir_module, IRModule): raise ValueError("Not supported type: ", type(ir_module)) - assert len(ir_module.get_global_vars()) == 1, ( - "The optimized module should only have one global variable for default schedule.") + assert len(ir_module.get_global_vars()) == 1, "The optimized module should only have one global variable for default schedule." func = list(ir_module.functions.values())[0] return func diff --git a/tilelang/contrib/cc.py b/tilelang/contrib/cc.py index 87d943ab39b147a5aaa897cce9e5c1cda0dfd2cd..7dc459770b06452f453a76769586e63f3b6fd57e 100644 --- a/tilelang/contrib/cc.py +++ b/tilelang/contrib/cc.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
"""Util to invoke C/C++ compilers in the system.""" + import functools import os import shutil @@ -30,8 +31,7 @@ from tvm.contrib import utils as _utils def _is_linux_like(): - return (sys.platform == "darwin" or sys.platform.startswith("linux") or - sys.platform.startswith("freebsd")) + return sys.platform == "darwin" or sys.platform.startswith("linux") or sys.platform.startswith("freebsd") def _is_windows_like(): @@ -90,7 +90,7 @@ def get_cplus_compiler(): def is_darwin(): - return platform.system() == 'Darwin' + return platform.system() == "Darwin" def create_shared(output, objects, options=None, cc=None, cwd=None, ccache_env=None): @@ -287,11 +287,7 @@ create_shared.output_format = "so" if sys.platform != "win32" else "dll" create_shared.get_target_triple = get_target_by_dump_machine(os.environ.get("CXX", get_cc())) -def cross_compiler(compile_func, - options=None, - output_format=None, - get_target_triple=None, - add_files=None): +def cross_compiler(compile_func, options=None, output_format=None, get_target_triple=None, add_files=None): """Create a cross compiler function by specializing compile_func with options. This function can be used to construct compile functions that @@ -363,13 +359,7 @@ def cross_compiler(compile_func, return _fcompile -def _linux_compile(output, - objects, - options, - compile_cmd, - cwd=None, - ccache_env=None, - compile_shared=False): +def _linux_compile(output, objects, options, compile_cmd, cwd=None, ccache_env=None, compile_shared=False): cmd = [compile_cmd] if compile_cmd != "nvcc": if compile_shared or output.endswith(".so") or output.endswith(".dylib"): @@ -430,15 +420,15 @@ def _windows_compile(output, objects, options, cwd=None, ccache_env=None): raise ValueError("ccache not found") try: - proc = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, env=env) + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, env=env) (out, _) = proc.communicate() except FileNotFoundError: - raise RuntimeError("Can not find the LLVM clang for Windows clang.exe)." - "Make sure it's installed" - " and the installation directory is in the %PATH% environment " - "variable. Prebuilt binaries can be found at: https://llvm.org/") \ - from None + raise RuntimeError( + "Can not find the LLVM clang for Windows clang.exe)." + "Make sure it's installed" + " and the installation directory is in the %PATH% environment " + "variable. Prebuilt binaries can be found at: https://llvm.org/" + ) from None if proc.returncode != 0: msg = "Compilation error:\n" msg += " ".join(cmd) + "\n" diff --git a/tilelang/contrib/dlpack.py b/tilelang/contrib/dlpack.py index 6772fe11a07634dc6b8e3baba2d6107c25c6b6d5..d80f0fdbc39302ec687893fb6d92c25406a657c8 100644 --- a/tilelang/contrib/dlpack.py +++ b/tilelang/contrib/dlpack.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
"""Wrapping functions to bridge frameworks with DLPack support to TVM""" + from tvm import runtime @@ -45,12 +46,8 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func): def adapt_tensor(arg): if isinstance(arg, tensor_type): - if arg.dtype in { - torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, - torch.float8_e5m2fnuz - }: - return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view( - arg.shape, dtype=float8_dtype_map[arg.dtype]) + if arg.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}: + return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(arg.shape, dtype=float8_dtype_map[arg.dtype]) return runtime.from_dlpack(to_dlpack_func(arg)) return arg diff --git a/tilelang/contrib/hipcc.py b/tilelang/contrib/hipcc.py index 4e3c9a5c3b16759a6df9d9e825456ad31e10e690..7b7f9f9479e54dcd03d0b95779ce65fef59e448f 100644 --- a/tilelang/contrib/hipcc.py +++ b/tilelang/contrib/hipcc.py @@ -16,12 +16,7 @@ from tvm.base import py_str from tvm.contrib.rocm import get_rocm_arch, find_rocm_path -def compile_hip(code, - target_format="hsaco", - arch=None, - options=None, - path_target=None, - verbose=False): +def compile_hip(code, target_format="hsaco", arch=None, options=None, path_target=None, verbose=False): """Compile HIP code with hipcc. Parameters @@ -61,7 +56,7 @@ def compile_hip(code, file_target = path_target if path_target else temp_target cmd = ["hipcc"] - cmd += ["-O3", '-c'] + cmd += ["-O3", "-c"] if isinstance(arch, str): cmd += [f"--offload-arch={arch}"] if target_format == "hsaco": diff --git a/tilelang/contrib/nvcc.py b/tilelang/contrib/nvcc.py index 0e6a19ba1014d00461da40d3c4c456840d998509..36df6c875e198e2770fc3a449f0bc652fc3bb75e 100644 --- a/tilelang/contrib/nvcc.py +++ b/tilelang/contrib/nvcc.py @@ -1,6 +1,7 @@ # pylint: disable=invalid-name # modified from apache tvm python/tvm/contrib/nvcc.py """Utility to invoke nvcc compiler in the system""" + from __future__ import annotations import os @@ -18,12 +19,7 @@ from tvm.base import py_str from tvm.contrib import utils -def compile_cuda(code, - target_format="ptx", - arch=None, - options=None, - path_target=None, - verbose=False): +def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None, verbose=False): """Compile cuda code with NVCC from env. Parameters @@ -67,7 +63,7 @@ def compile_cuda(code, temp_target = temp.relpath(f"{file_name}.{target_format}") pass_context = tvm.get_global_func("transform.GetCurrentPassContext")() - kernels_output_dir = (pass_context.config.get("cuda.kernels_output_dir", None)) + kernels_output_dir = pass_context.config.get("cuda.kernels_output_dir", None) if kernels_output_dir is not None: if not os.path.isdir(kernels_output_dir): os.makedirs(kernels_output_dir) @@ -114,10 +110,7 @@ def compile_cuda(code, print(py_str(out)) if proc.returncode != 0: - msg = f"{code}\n" \ - f"Compilation error:\n" \ - f"{py_str(out)}\n" \ - f"Command: {' '.join(cmd)}\n" + msg = f"{code}\nCompilation error:\n{py_str(out)}\nCommand: {' '.join(cmd)}\n" raise RuntimeError(msg) with open(file_target, "rb") as f: @@ -165,6 +158,7 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str] # (e.g., multiple "-gencode" pairs or repeated "-Xcompiler" entries). 
if compile_flags: import shlex + for flag in compile_flags: # Split each string like a shell would, preserving quoted args tokens = shlex.split(flag) if isinstance(flag, str) else [str(flag)] @@ -172,9 +166,7 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str] return options -def get_ptx_from_source(code: str, - compile_flags: list[str] | None = None, - verbose: bool = False) -> str: +def get_ptx_from_source(code: str, compile_flags: list[str] | None = None, verbose: bool = False) -> str: """ Compile CUDA C++ source to PTX using NVCC and return as text. @@ -212,9 +204,7 @@ def _find_tool(name: str) -> str | None: return None -def get_sass_from_source(code: str, - compile_flags: list[str] | None = None, - verbose: bool = False) -> str: +def get_sass_from_source(code: str, compile_flags: list[str] | None = None, verbose: bool = False) -> str: """ Compile CUDA C++ source to CUBIN and disassemble to SASS. @@ -246,9 +236,7 @@ def get_sass_from_source(code: str, cand_nvdisasm = _find_tool("nvdisasm") cand_cuobjdump = _find_tool("cuobjdump") if not cand_nvdisasm and not cand_cuobjdump: - raise RuntimeError( - "Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH." - ) + raise RuntimeError("Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH.") last_err: str | None = None try: # Attempt nvdisasm first @@ -268,8 +256,7 @@ def get_sass_from_source(code: str, return text last_err = f"{tool_name} rc={proc.returncode}, output:\n{text}" # If we reach here, all attempts failed - raise RuntimeError(f"SASS disassembly failed. Tried tools: " - f"{', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}") + raise RuntimeError(f"SASS disassembly failed. Tried tools: {', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}") finally: with contextlib.suppress(Exception): os.remove(cubin_path) @@ -438,8 +425,7 @@ def get_target_compute_version(target=None): if tvm.cuda(0).exist: return tvm.cuda(0).compute_version - raise ValueError("No CUDA architecture was specified or GPU detected." - "Try specifying it by adding '-arch=sm_xx' to your target.") + raise ValueError("No CUDA architecture was specified or GPU detected.Try specifying it by adding '-arch=sm_xx' to your target.") def parse_compute_version(compute_version) -> tuple[int, int]: @@ -524,7 +510,8 @@ def have_tensorcore(compute_version=None, target=None): warnings.warn( "Tensorcore will be disabled due to no CUDA architecture specified." "Try specifying it by adding '-arch=sm_xx' to your target.", - stacklevel=2) + stacklevel=2, + ) return False compute_version = target.attrs["arch"] # Compute version will be in the form "sm_{major}{minor}" diff --git a/tilelang/contrib/nvrtc.py b/tilelang/contrib/nvrtc.py index b691155497c6db8c7093168d037f71fd19705061..105c518198b66aa1816d67b006d66c4d11b8f3cc 100644 --- a/tilelang/contrib/nvrtc.py +++ b/tilelang/contrib/nvrtc.py @@ -11,11 +11,13 @@ def get_nvrtc_version() -> tuple[int, int]: return (major, minor) -def compile_cuda(code: str, - target_format: Literal["ptx", "cubin"] = "ptx", - arch: int | None = None, - options: str | list[str] | None = None, - verbose: bool = False) -> bytearray: +def compile_cuda( + code: str, + target_format: Literal["ptx", "cubin"] = "ptx", + arch: int | None = None, + options: str | list[str] | None = None, + verbose: bool = False, +) -> bytearray: """Compile cuda code with NVRTC. 
Parameters @@ -43,8 +45,7 @@ def compile_cuda(code: str, if arch is None: # If None, then it will use `tvm.target.Target.current().arch`. # Target arch could be a str like "80", "90", "90a", etc. - major, minor = parse_compute_version( - get_target_compute_version(Target.current(allow_none=True))) + major, minor = parse_compute_version(get_target_compute_version(Target.current(allow_none=True))) arch = major * 10 + minor prefix = "compute" if target_format == "ptx" else "sm" suffix = "a" if arch >= 90 else "" @@ -77,8 +78,7 @@ def compile_cuda(code: str, compile_result = nvrtc.nvrtcCompileProgram(program, len(options_bytes), options_bytes)[0] if compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS: - msg = f"{code}\n" \ - f"Compilation error:\n" + msg = f"{code}\nCompilation error:\n" if verbose: result, log_size = nvrtc.nvrtcGetProgramLogSize(program) assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get program log size: {result}" @@ -105,7 +105,6 @@ def compile_cuda(code: str, assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get PTX: {result}" # Destroy handler - assert nvrtc.nvrtcDestroyProgram( - program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to destroy program: {result}" + assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to destroy program: {result}" return result_bytes diff --git a/tilelang/contrib/rocm.py b/tilelang/contrib/rocm.py index 4a57c3c64401095cbd87e8b1485bc8b8216a47f1..f3b92e54d09ba27fb672bf6a15993b2d337c0074 100644 --- a/tilelang/contrib/rocm.py +++ b/tilelang/contrib/rocm.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Utility for ROCm backend""" + # ruff: noqa import re import subprocess @@ -255,9 +256,11 @@ def get_rocm_arch(rocm_path="/opt/rocm"): gpu_arch = match.group(1) return gpu_arch except subprocess.CalledProcessError: - print(f"Unable to execute rocminfo command, \ + print( + f"Unable to execute rocminfo command, \ please ensure ROCm is installed and you have an AMD GPU on your system.\ - using default {gpu_arch}.") + using default {gpu_arch}." 
+ ) return gpu_arch diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index 7abdfb92acd474ea67ace13754e92d682f2bf30e..9932d522e4798b93b45f5cef4eea7db70c715eb2 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -1,4 +1,5 @@ """The compiler for TL programs.""" + from __future__ import annotations import os @@ -28,14 +29,13 @@ def is_cpu_device_backend(target: Target): def has_device_kernel_launch(attrs) -> bool: """Check if the attributes indicate a device kernel launch.""" - return bool(attrs and "calling_conv" in attrs and - attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH) + return bool(attrs and "calling_conv" in attrs and attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH) def is_device_call_c_device(func: tir.PrimFunc): attrs = func.attrs calling_conv = attrs.get("calling_conv", CallingConv.DEFAULT) - is_cpacked = (calling_conv == CallingConv.C_PACKED_FUNC) + is_cpacked = calling_conv == CallingConv.C_PACKED_FUNC # Check if it's a C target if "target" in attrs and attrs["target"].kind.name == "c" and not is_cpacked: @@ -141,16 +141,16 @@ def extrac_params(func: tir.PrimFunc) -> list[KernelParam]: if var in func.buffer_map: tensor_types.append(KernelParam.from_buffer(func.buffer_map[var])) else: - if var.dtype == 'handle': + if var.dtype == "handle": raise ValueError( - f'Handle parameter {var} must be mapped to a buffer.\n' - f'Please use T.tensor({var.name}, shape=..., dtype=...) to map it to a buffer.') + f"Handle parameter {var} must be mapped to a buffer.\n" + f"Please use T.tensor({var.name}, shape=..., dtype=...) to map it to a buffer." + ) tensor_types.append(KernelParam.from_var(var)) return tensor_types def canon_target_host(target: str | Target, target_host: str | Target | None): - if not target_host: target_host = "llvm" if tvm.runtime.enabled("llvm") else "c" @@ -195,11 +195,9 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> device_mod = tilelang.transform.LowerIntrin()(device_mod) device_mod = tir.transform.Simplify()(device_mod) if target.kind.name == "cuda": - device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")( - device_mod, target) + device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")(device_mod, target) elif target.kind.name == "hip": - device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")( - device_mod, target) + device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")(device_mod, target) elif target.kind.name == "c": device_mod = tvm.ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target) elif target.kind.name == "llvm": @@ -222,12 +220,12 @@ def lower( enable_host_codegen=False, enable_device_compile=False, ) -> CompiledArtifact: - ''' - enable_host_codegen: whether to enable host codegen, default is False, as we have our - own host codegen implementation in jit. - enable_device_compile: whether to enable device codegen, default is False, as we have our - own device codegen implementation in jit. - ''' + """ + enable_host_codegen: whether to enable host codegen, default is False, as we have our + own host codegen implementation in jit. + enable_device_compile: whether to enable device codegen, default is False, as we have our + own device codegen implementation in jit. 
+ """ mod = func_or_mod params = None @@ -259,14 +257,11 @@ def lower( host_mod = tir.transform.Filter(_is_host_call)(mod) device_mod = tir.transform.Filter(_is_device_call)(mod) - codegen_mod = device_codegen( - device_mod, target) if enable_device_compile else device_codegen_without_compile( - device_mod, target) + codegen_mod = device_codegen(device_mod, target) if enable_device_compile else device_codegen_without_compile(device_mod, target) if enable_host_codegen: host_mod = host_codegen(host_mod, target_host) host_mod.import_module(codegen_mod) - return CompiledArtifact( - host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod) + return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod) return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source()) diff --git a/tilelang/engine/param.py b/tilelang/engine/param.py index de3c979ea1a55934dda9ea9959b7f8ffc080f006..1abf66a5fcb676c1e64406d60bd8935e60adc174 100644 --- a/tilelang/engine/param.py +++ b/tilelang/engine/param.py @@ -1,4 +1,5 @@ """The profiler and convert to torch utils""" + from __future__ import annotations from dataclasses import dataclass @@ -14,6 +15,7 @@ class KernelParam: Represents parameters for a kernel operation, storing dtype and shape information. Used to describe tensor or scalar parameters in TVM/PyTorch interop. """ + dtype: torch.dtype # PyTorch data type of the parameter shape: list[int | Var] # List of dimensions, can be integers or TVM variables @@ -109,6 +111,7 @@ class CompiledArtifact: Represents a compiled kernel artifact containing both host and device code. Stores all necessary components for kernel execution in the TVM runtime. """ + host_mod: tvm.IRModule # Host-side TVM IR module for managing kernel execution device_mod: tvm.IRModule # Device-side TVM IR module containing the actual kernel code params: list[KernelParam] # List of parameters (tensors/scalars) used by the kernel diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index cd205a6db7da0bbb8ecfb8d6c779c5a2dc260abd..cef3d9a2e0455c98df2eb3e17b72a0e125aec021 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -6,8 +6,7 @@ from tilelang.transform import PassContext from tilelang.contrib.nvcc import have_tma, is_hopper -def allow_warp_specialized(pass_ctx: PassContext | None = None, - target: Target | None = None) -> bool: +def allow_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool: # avoid circular import from tilelang.jit.adapter.utils import is_cuda_target @@ -19,8 +18,7 @@ def allow_warp_specialized(pass_ctx: PassContext | None = None, return not disable_warp_specialized -def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None, - target: Target | None = None) -> bool: +def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() if not have_tma(target): @@ -47,12 +45,10 @@ def allow_global_thread_synchronization(pass_ctx: PassContext | None = None) -> return enable_global_thread_sync -def should_enable_aggressive_merge(pass_ctx: PassContext | None = None, - target: Target | None = None) -> bool: +def should_enable_aggressive_merge(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() - enable_aggressive_merge = bool( - 
pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False)) + enable_aggressive_merge = bool(pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False)) if allow_warp_specialized(pass_ctx=pass_ctx, target=target): # This is a workaround to avoid the bug in the MergeSharedMemoryAllocations pass # when warp specialization is enabled, as different warp threads may access different @@ -88,7 +84,7 @@ def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]: return ["txt", "png", "pdf", "svg"] if "," in formats_str: - formats_list = [f.strip() for f in formats_str.split(',')] + formats_list = [f.strip() for f in formats_str.split(",")] else: formats_list = [formats_str] @@ -257,9 +253,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: # MergeSharedMemoryAllocations must be applied after SplitHostDevice # because the merged allocation site is at the beginning of each device function enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target) - mod = tilelang.transform.MergeSharedMemoryAllocations( - enable_aggressive_merge=enable_aggressive_merge)( - mod) + mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod) mod = tilelang.transform.ThreadSync("shared")(mod) mod = tilelang.transform.ThreadSync("shared.dyn")(mod) # Inject PTX async copy must behind the thread sync pass diff --git a/tilelang/env.py b/tilelang/env.py index ce27aba9c117765ad387a3c5dfacbafb625252de..0583cd4cf6c7fd8cc4f8ea3640cefb3bb5bce302 100644 --- a/tilelang/env.py +++ b/tilelang/env.py @@ -10,36 +10,34 @@ from dataclasses import dataclass logger = logging.getLogger(__name__) # SETUP ENVIRONMENT VARIABLES -CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path") +CUTLASS_NOT_FOUND_MESSAGE = "CUTLASS is not installed or found in the expected path" ", which may lead to compilation bugs when utilize tilelang backend." -COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = ( - "Composable Kernel is not installed or found in the expected path") +COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = "Composable Kernel is not installed or found in the expected path" ", which may lead to compilation bugs when utilize tilelang backend." -TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path") +TL_TEMPLATE_NOT_FOUND_MESSAGE = "TileLang is not installed or found in the expected path" ", which may lead to compilation bugs when utilize tilelang backend." -TVM_LIBRARY_NOT_FOUND_MESSAGE = ("TVM is not installed or found in the expected path") +TVM_LIBRARY_NOT_FOUND_MESSAGE = "TVM is not installed or found in the expected path" TL_ROOT = os.path.dirname(os.path.abspath(__file__)) # Only expose the internal lib directory to sys.path to avoid shadowing # common top-level module names (e.g., utils, analysis) from user projects. -TL_LIBS = [os.path.join(TL_ROOT, 'lib')] +TL_LIBS = [os.path.join(TL_ROOT, "lib")] TL_LIBS = [i for i in TL_LIBS if os.path.exists(i)] DEV = False -THIRD_PARTY_ROOT = os.path.join(TL_ROOT, '3rdparty') +THIRD_PARTY_ROOT = os.path.join(TL_ROOT, "3rdparty") if not os.path.exists(THIRD_PARTY_ROOT): DEV = True tl_dev_root = os.path.dirname(TL_ROOT) - dev_lib_root = os.path.join(tl_dev_root, 'build') + dev_lib_root = os.path.join(tl_dev_root, "build") # In dev builds, place artifacts under build/lib and point search path there # to avoid adding the entire build root to sys.path. 
- TL_LIBS = [os.path.join(dev_lib_root, 'lib'), os.path.join(dev_lib_root, 'tvm')] - THIRD_PARTY_ROOT = os.path.join(tl_dev_root, '3rdparty') - logger.warning(f'Loading tilelang libs from dev root: {dev_lib_root}') + TL_LIBS = [os.path.join(dev_lib_root, "lib"), os.path.join(dev_lib_root, "tvm")] + THIRD_PARTY_ROOT = os.path.join(tl_dev_root, "3rdparty") + logger.warning(f"Loading tilelang libs from dev root: {dev_lib_root}") -assert TL_LIBS and all( - os.path.exists(i) for i in TL_LIBS), f'tilelang lib root do not exists: {TL_LIBS}' +assert TL_LIBS and all(os.path.exists(i) for i in TL_LIBS), f"tilelang lib root do not exists: {TL_LIBS}" for lib in TL_LIBS: if lib not in sys.path: @@ -52,7 +50,7 @@ def _find_cuda_home() -> str: Adapted from https://github.com/pytorch/pytorch/blob/main/torch/utils/cpp_extension.py """ # Guess #1 - cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") if cuda_home is None: # Guess #2 nvcc_path = shutil.which("nvcc") @@ -70,15 +68,15 @@ def _find_cuda_home() -> str: else: # Guess #3 - if sys.platform == 'win32': - cuda_homes = glob.glob('C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*') - cuda_home = '' if len(cuda_homes) == 0 else cuda_homes[0] + if sys.platform == "win32": + cuda_homes = glob.glob("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*") + cuda_home = "" if len(cuda_homes) == 0 else cuda_homes[0] else: # Linux/macOS - if os.path.exists('/usr/local/cuda'): - cuda_home = '/usr/local/cuda' - elif os.path.exists('/opt/nvidia/hpc_sdk/Linux_x86_64'): - cuda_home = '/opt/nvidia/hpc_sdk/Linux_x86_64' + if os.path.exists("/usr/local/cuda"): + cuda_home = "/usr/local/cuda" + elif os.path.exists("/opt/nvidia/hpc_sdk/Linux_x86_64"): + cuda_home = "/opt/nvidia/hpc_sdk/Linux_x86_64" # Validate found path if cuda_home is None or not os.path.exists(cuda_home): @@ -89,13 +87,13 @@ def _find_cuda_home() -> str: def _find_rocm_home() -> str: """Find the ROCM install path.""" - rocm_home = os.environ.get('ROCM_PATH') or os.environ.get('ROCM_HOME') + rocm_home = os.environ.get("ROCM_PATH") or os.environ.get("ROCM_HOME") if rocm_home is None: rocmcc_path = shutil.which("hipcc") if rocmcc_path is not None: rocm_home = os.path.dirname(os.path.dirname(rocmcc_path)) else: - rocm_home = '/opt/rocm' + rocm_home = "/opt/rocm" if not os.path.exists(rocm_home): rocm_home = None return rocm_home if rocm_home is not None else "" @@ -104,6 +102,7 @@ def _find_rocm_home() -> str: # Cache control class CacheState: """Class to manage global kernel caching state.""" + _enabled = True @classmethod @@ -230,13 +229,11 @@ class Environment: TILELANG_TMP_DIR = EnvVar("TILELANG_TMP_DIR", os.path.join(TILELANG_CACHE_DIR.get(), "tmp")) # Kernel Build options - TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION", - "1") # print kernel name on compile + TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION", "1") # print kernel name on compile TILELANG_DISABLE_CACHE = EnvVar( - "TILELANG_DISABLE_CACHE", - "0") # disable kernel cache, usually for unit testing / debugging, high priority - TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", - "0") # DEPRECATED! clear cache automatically if set + "TILELANG_DISABLE_CACHE", "0" + ) # disable kernel cache, usually for unit testing / debugging, high priority + TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0") # DEPRECATED! 
clear cache automatically if set # Kernel selection options # Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1 @@ -244,12 +241,9 @@ class Environment: # Auto-tuning settings TILELANG_AUTO_TUNING_DISABLE_CACHE = EnvVar("TILELANG_AUTO_TUNING_DISABLE_CACHE", "0") - TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES", - "0.9") # percent of CPUs used - TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS", - "-1") # -1 means auto - TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT", - "-1") # -1 means no limit + TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES", "0.9") # percent of CPUs used + TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS", "-1") # -1 means auto + TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT", "-1") # -1 means no limit # TVM integration SKIP_LOADING_TILELANG_SO = EnvVar("SKIP_LOADING_TILELANG_SO", "0") @@ -323,18 +317,18 @@ def prepend_pythonpath(path): if env.TVM_IMPORT_PYTHON_PATH is not None: prepend_pythonpath(env.TVM_IMPORT_PYTHON_PATH) else: - tvm_path = os.path.join(THIRD_PARTY_ROOT, 'tvm', 'python') + tvm_path = os.path.join(THIRD_PARTY_ROOT, "tvm", "python") assert os.path.exists(tvm_path), tvm_path if tvm_path not in sys.path: prepend_pythonpath(tvm_path) env.TVM_IMPORT_PYTHON_PATH = tvm_path # By default, the built TVM-related libraries are stored in TL_LIBS. if os.environ.get("TVM_LIBRARY_PATH") is None: - os.environ['TVM_LIBRARY_PATH'] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS) + os.environ["TVM_LIBRARY_PATH"] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS) # Initialize CUTLASS paths if os.environ.get("TL_CUTLASS_PATH", None) is None: - cutlass_inc_path = os.path.join(THIRD_PARTY_ROOT, 'cutlass', 'include') + cutlass_inc_path = os.path.join(THIRD_PARTY_ROOT, "cutlass", "include") if os.path.exists(cutlass_inc_path): os.environ["TL_CUTLASS_PATH"] = env.CUTLASS_INCLUDE_DIR = cutlass_inc_path else: @@ -342,7 +336,7 @@ if os.environ.get("TL_CUTLASS_PATH", None) is None: # Initialize COMPOSABLE_KERNEL paths if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None: - ck_inc_path = os.path.join(THIRD_PARTY_ROOT, 'composable_kernel', 'include') + ck_inc_path = os.path.join(THIRD_PARTY_ROOT, "composable_kernel", "include") if os.path.exists(ck_inc_path): os.environ["TL_COMPOSABLE_KERNEL_PATH"] = env.COMPOSABLE_KERNEL_INCLUDE_DIR = ck_inc_path else: diff --git a/tilelang/intrinsics/mfma_layout.py b/tilelang/intrinsics/mfma_layout.py index 183ba646f9421eb2318242c6989ad41132eeba0d..38959649467cdfd9decd2fd73c3b4c46e8868ea0 100644 --- a/tilelang/intrinsics/mfma_layout.py +++ b/tilelang/intrinsics/mfma_layout.py @@ -4,7 +4,7 @@ import tilelang.language as T def shared_16x4_to_local_64x1_layout_A(i, j): - thread_id = (j * 16 + i) + thread_id = j * 16 + i return thread_id, convert(0) @@ -15,7 +15,7 @@ def thread_id_shared_access_64x1_to_16x4_layout_A(thread_id, local_id): def shared_4x16_to_local_64x1_layout_B(i, j): - thread_id = (i * 16 + j) + thread_id = i * 16 + j return thread_id, convert(0) @@ -27,7 +27,7 @@ def thread_id_shared_access_64x1_to_4x16_layout_B(thread_id, local_id): def shared_16x16_to_local_64x4_layout_C(i, j): thread_id = j + (i // 4) * 16 - local = (i % 4) + local = i % 4 return thread_id, local @@ -45,7 +45,7 @@ def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id): def shared_16x16_to_local_64x4_layout_A(i, j): thread_id = i + 16 * (j // 4) - 
local = (j % 4) + local = j % 4 return thread_id, local @@ -57,7 +57,7 @@ def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id): def shared_16x16_to_local_64x4_layout_B(i, j): thread_id = j + (i // 4) * 16 - local = (i % 4) + local = i % 4 return thread_id, local @@ -87,7 +87,7 @@ def thread_id_shared_access_64x8_to_16x32_layout_A(thread_id, local_id): def shared_16x32_to_local_64x8_layout_A(i, j): thread_id = i + 16 * (j // 8) - local = (j % 8) + local = j % 8 return thread_id, local @@ -99,7 +99,7 @@ def thread_id_shared_access_64x8_to_16x32_layout_B(thread_id, local_id): def shared_16x32_to_local_64x8_layout_B(i, j): thread_id = j + (i // 8) * 16 - local = (i % 8) + local = i % 8 return thread_id, local @@ -111,7 +111,7 @@ def thread_id_shared_access_64x16_to_16x64_layout_A(thread_id, local_id): def shared_16x64_to_local_64x16_layout_A(i, j): thread_id = i + 16 * (j // 16) - local = (j % 16) + local = j % 16 return thread_id, local @@ -123,7 +123,7 @@ def thread_id_shared_access_64x16_to_16x64_layout_B(thread_id, local_id): def shared_16x64_to_local_64x16_layout_B(i, j): thread_id = i + 16 * (j // 16) - local = (j % 16) + local = j % 16 return thread_id, local diff --git a/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/intrinsics/mfma_macro_generator.py index 618a99811623dbfeeeb16bb9970807a771c9bd58..1e97bd0f20e20cef80d97c3165e9856c424804ef 100644 --- a/tilelang/intrinsics/mfma_macro_generator.py +++ b/tilelang/intrinsics/mfma_macro_generator.py @@ -6,7 +6,7 @@ from tvm import tir from tvm.ir import Range from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad from tvm.runtime import convert -from .utils import (mfma_store_index_map) +from .utils import mfma_store_index_map from typing import Literal, Callable from tilelang.utils import is_fragment @@ -101,7 +101,7 @@ class MatrixCoreIntrinEmitter: self.warp_rows = warp_row_tiles // self.micro_size_x self.warp_cols = warp_col_tiles // self.micro_size_y self.reduce_k = reduce_k - self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k) + self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k self.num_elems_per_byte = num_elems_per_byte self.thread_var = thread_var @@ -132,12 +132,7 @@ class MatrixCoreIntrinEmitter: def _initialize_mfma_prefix(self, k_dim=16): in_dtype, out_dtype = self.a_dtype, self.accum_dtype M_DIM, N_DIM = self.M_DIM, self.N_DIM - out_dtype_abbrv = { - "float16": "f16", - "float32": "f32", - "int8": "i8", - "int32": "i32" - }[out_dtype] + out_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}[out_dtype] in_dtype_abbrv = { "bfloat16": "bf16", @@ -176,7 +171,6 @@ class MatrixCoreIntrinEmitter: self.b_preshuffle = b_preshuffle def get_ldmatrix_index_map(self, is_b=False): - k_dim = self.k_dim * self.k_pack transposed = self.a_transposed if not is_b else self.b_transposed if k_dim == 4: @@ -184,28 +178,42 @@ class MatrixCoreIntrinEmitter: reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A if is_b: index_map = shared_16x4_to_local_64x1_layout_A if transposed else shared_4x16_to_local_64x1_layout_B - reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A if transposed else thread_id_shared_access_64x1_to_4x16_layout_B + reverse_index_map = ( + thread_id_shared_access_64x1_to_16x4_layout_A if transposed else thread_id_shared_access_64x1_to_4x16_layout_B + ) elif k_dim == 16: index_map = shared_16x16_to_local_64x4_layout_B if transposed else shared_16x16_to_local_64x4_layout_A 
- reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_B if transposed else thread_id_shared_access_64x4_to_16x16_layout_A + reverse_index_map = ( + thread_id_shared_access_64x4_to_16x16_layout_B if transposed else thread_id_shared_access_64x4_to_16x16_layout_A + ) if is_b: index_map = shared_16x16_to_local_64x4_layout_A if transposed else shared_16x16_to_local_64x4_layout_B - reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B + reverse_index_map = ( + thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B + ) elif k_dim == 32: index_map = shared_16x32_to_local_64x8_layout_B if transposed else shared_16x32_to_local_64x8_layout_A - reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A + reverse_index_map = ( + thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A + ) if is_b: index_map = shared_16x32_to_local_64x8_layout_A if transposed else shared_16x32_to_local_64x8_layout_B - reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B + reverse_index_map = ( + thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B + ) elif k_dim == 64: index_map = shared_16x64_to_local_64x16_layout_B if transposed else shared_16x64_to_local_64x16_layout_A - reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A + reverse_index_map = ( + thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A + ) if is_b: index_map = shared_16x64_to_local_64x16_layout_A if transposed else shared_16x64_to_local_64x16_layout_B - reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B + reverse_index_map = ( + thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B + ) else: raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently") @@ -227,14 +235,12 @@ class MatrixCoreIntrinEmitter: else: return self.thread_var - def extract_thread_binding(self, - thread_id, - is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: - ''' - is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) - which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] - Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)] - ''' + def extract_thread_binding(self, thread_id, is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: + """ + is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) + which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] + Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)] + """ WARP_SIZE = self.WARP_SIZE block_row_warps = self.block_row_warps block_col_warps = self.block_col_warps @@ -244,16 +250,18 @@ class MatrixCoreIntrinEmitter: is_m_first = self.is_m_first if is_m_first: - lane_id, warp_n, warp_m = thread_id % WARP_SIZE, ( - thread_id // - WARP_SIZE) % block_col_warps, (thread_id // - (WARP_SIZE * block_col_warps)) % block_row_warps, + lane_id, 
warp_n, warp_m = ( + thread_id % WARP_SIZE, + (thread_id // WARP_SIZE) % block_col_warps, + (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps, + ) return lane_id, warp_n, warp_m else: - lane_id, warp_m, warp_n = thread_id % WARP_SIZE, ( - thread_id // - WARP_SIZE) % block_row_warps, (thread_id // - (WARP_SIZE * block_row_warps)) % block_col_warps, + lane_id, warp_m, warp_n = ( + thread_id % WARP_SIZE, + (thread_id // WARP_SIZE) % block_row_warps, + (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps, + ) return lane_id, warp_n, warp_m def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0): @@ -287,18 +295,14 @@ class MatrixCoreIntrinEmitter: for i in T.serial(warp_rows): for local_id in T.vectorized(k_pack * local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) - l, r = (rk * chunk + ki * (k_pack * micro_size_k), - warp_m * warp_row_tiles + i * micro_size_x) - A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, - A_base1 + r + col] + l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x) + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] else: for i in T.serial(warp_rows): for local_id in T.vectorized(k_pack * local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) - l, r = (warp_m * warp_row_tiles + i * micro_size_x, - rk * chunk + ki * (k_pack * micro_size_k)) - A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, - A_base1 + r + col] + l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k)) + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) @@ -337,8 +341,7 @@ class MatrixCoreIntrinEmitter: warp_n * warp_col_tiles + j * micro_size_y, rk * chunk + ki * (k_pack * micro_size_k), ) - B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, - B_base1 + r + col] + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] else: for j in T.serial(warp_cols): @@ -348,16 +351,11 @@ class MatrixCoreIntrinEmitter: rk * chunk + ki * (k_pack * micro_size_k), warp_n * warp_col_tiles + j * micro_size_y, ) - B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, - B_base1 + r + col] + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) - def mfma(self, - A_local_buf: Buffer, - B_local_buf: Buffer, - C_local_buf: Buffer, - k_inner: PrimExpr | None = 0): + def mfma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0): warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_a = self.local_size_a @@ -421,14 +419,13 @@ class MatrixCoreIntrinEmitter: for local_id in T.vectorized(local_size_out): row, col = T.meta_var(mfma_store_index_map(tx, local_id)) if C_buf_dims == 2: - C_buf[(warp_m * warp_rows + i) * M_DIM + row, - (warp_n * warp_cols + j) * N_DIM + - col] = C_local_buf[i * (warp_cols * local_size_out) + - j * local_size_out + local_id] + C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * N_DIM + col] = C_local_buf[ + i * (warp_cols * local_size_out) + j * local_size_out + local_id + ] else: - C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, - col] = 
C_local_buf[i * warp_cols * local_size_out + - j * local_size_out + local_id] + C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[ + i * warp_cols * local_size_out + j * local_size_out + local_id + ] @T.macro def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding): @@ -436,18 +433,17 @@ class MatrixCoreIntrinEmitter: for i, j in T.grid(warp_rows, warp_cols): for local_id in T.vectorized(local_size_out): row, col = T.meta_var(mfma_store_index_map(tx, local_id)) - C_buf[(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, - (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + - col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + - local_id] - - return _warp_stmatrix_global(C_local_buf, C_buf, - thread_binding) if is_global else _warp_stmatrix_shared( - C_local_buf, C_buf, thread_binding) - - def make_mfma_load_layout(self, - local_buf: Buffer, - matrix: Literal["A", "B"] = "A") -> T.Fragment: + C_buf[ + (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col + ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id] + + return ( + _warp_stmatrix_global(C_local_buf, C_buf, thread_binding) + if is_global + else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding) + ) + + def make_mfma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment: """ Create a layout function for storing MFMA results into a fragment buffer. @@ -468,6 +464,7 @@ class MatrixCoreIntrinEmitter: If `local_buf` is not detected to be a fragment buffer. """ from tilelang.utils import is_fragment + assert matrix in ["A", "B"], "matrix should be either A or B" matrix_is_a: bool = matrix == "A" matrix_is_b: bool = matrix == "B" @@ -506,11 +503,9 @@ class MatrixCoreIntrinEmitter: transform_func: Callable = None if matrix_is_a: - transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( - j, i) + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) elif matrix_is_b: - transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( - j, i) + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) else: raise ValueError(f"Unsupported matrix {matrix}") @@ -543,8 +538,7 @@ class MatrixCoreIntrinEmitter: return local_id base_fragment = T.Fragment( - [micro_size_s, micro_size_r * - self.k_pack] if is_sr_axis_order else [micro_size_r * self.k_pack, micro_size_s], + [micro_size_s, micro_size_r * self.k_pack] if is_sr_axis_order else [micro_size_r * self.k_pack, micro_size_s], forward_thread_fn=forward_thread, forward_index_fn=forward_index, ) @@ -558,31 +552,19 @@ class MatrixCoreIntrinEmitter: replicate = block_col_warps if matrix_is_a else block_row_warps if is_sr_axis_order: - warp_fragment = base_fragment.repeat([warp_s, warp_r], - repeat_on_thread=False, - lower_dim_first=False) + warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False) if matrix_is_a: - block_fragment = warp_fragment.repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) elif matrix_is_b: - block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True) + block_fragment = 
warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") else: - warp_fragment = base_fragment.repeat([warp_r, warp_s], - repeat_on_thread=False, - lower_dim_first=True) + warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True) if matrix_is_a: - block_fragment = warp_fragment.repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) elif matrix_is_b: - block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True) + block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") @@ -686,7 +668,6 @@ class MatrixCoreIntrinEmitter: class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): - def __init__( self, a_dtype: str = "float16", @@ -792,20 +773,20 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): rk * (chunk // micro_size_k) + ki, warp_m * warp_rows + i, ) - A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, - col] + A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col] else: print(self.a_preshuffle) for i in T.serial(warp_rows): for local_id in T.vectorized(k_pack * local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = (warp_m * warp_rows + i, rk * (chunk // micro_size_k) + ki) - A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, - col] + A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col] - return _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding, - rk) if is_global else _warp_ldmatrix_a_shared( - A_local_buf, A_buf, ki, thread_binding, rk) + return ( + _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding, rk) + if is_global + else _warp_ldmatrix_a_shared(A_local_buf, A_buf, ki, thread_binding, rk) + ) def ldmatrix_b(self, B_local_buf, B_buf, ki, rk=0, pid_m=None, pid_n=None): warp_cols = self.warp_cols @@ -867,8 +848,7 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): warp_n * warp_cols + j, rk * (chunk // micro_size_k) + ki, ) - B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, - col] + B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col] else: for j in T.serial(warp_cols): for local_id in T.vectorized(k_pack * local_size_b): @@ -877,9 +857,10 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): rk * (chunk // micro_size_k) + ki, warp_n * warp_cols + j, ) - B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, - col] + B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col] - return _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding, - rk) if is_global else _warp_ldmatrix_b_shared( - B_local_buf, B_buf, ki, thread_binding, rk) + return ( + _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding, rk) + if is_global + else _warp_ldmatrix_b_shared(B_local_buf, B_buf, ki, thread_binding, rk) + ) diff --git a/tilelang/intrinsics/mma_layout.py b/tilelang/intrinsics/mma_layout.py index f49b59569649b36d10b92d729fd7655e03d3cd09..2eb575f0ca1fc3e070e3e9402439dcfdd936131e 100644 --- 
a/tilelang/intrinsics/mma_layout.py +++ b/tilelang/intrinsics/mma_layout.py @@ -153,14 +153,14 @@ def mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id): def mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id): """ - groupID = %laneid >> 2 - threadID_in_group = %laneid % 4 + groupID = %laneid >> 2 + threadID_in_group = %laneid % 4 - row = groupID for ai where 0 <= i < 2 || 4 <= i < 6 - groupID + 8 Otherwise + row = groupID for ai where 0 <= i < 2 || 4 <= i < 6 + groupID + 8 Otherwise - col = (threadID_in_group * 2) + (i & 0x1) for ai where i < 4 - (threadID_in_group * 2) + (i & 0x1) + 8 for ai where i >= 4 + col = (threadID_in_group * 2) + (i & 0x1) for ai where i < 4 + (threadID_in_group * 2) + (i & 0x1) + 8 for ai where i >= 4 """ row = (thread_id // 4) + 8 * (local_id % 4 // 2) col = (thread_id % 4) * 2 + (local_id % 2) + 8 * (local_id // 4) @@ -175,13 +175,13 @@ def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id): def mma_load_b_32x8_to_shared_16x16_layout(thread_id, local_id): """ - groupID = %laneid >> 2 - threadID_in_group = %laneid % 4 + groupID = %laneid >> 2 + threadID_in_group = %laneid % 4 - row = (threadID_in_group * 2) + (i & 0x1) for bi where i < 2 - (threadID_in_group * 2) + (i & 0x1) + 8 for bi where i >= 2 + row = (threadID_in_group * 2) + (i & 0x1) for bi where i < 2 + (threadID_in_group * 2) + (i & 0x1) + 8 for bi where i >= 2 - col = groupID + col = groupID """ col = (thread_id % 4) * 2 + ((local_id % 4) % 2) + ((local_id % 4) // 2) * 8 row = (thread_id // 4) + 8 * (local_id // 4) diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index 5811eb534d02af1272f731b84da0ef0ae8ce1b6a..28afdb2912341ccc3066cbd76351e644f8fd1241 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -191,6 +191,7 @@ class TensorCoreIntrinEmitter: def get_store_index_map(self, inverse: bool = False) -> IndexMap: from .utils import mma_store_index_map, mma_store_index_map_fp64 + warp_size, local_size_c = self.WARP_SIZE, self.local_size_out if DataType(self.accum_dtype).bits == 64: index_map = IndexMap.from_func(mma_store_index_map_fp64, index_dtype="int32") @@ -201,10 +202,7 @@ class TensorCoreIntrinEmitter: inverse_index_map = index_map.inverse([warp_size, local_size_c]) return inverse_index_map - def extract_thread_binding( - self, - thread_id: PrimExpr, - is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: + def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: """ is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] @@ -233,11 +231,7 @@ class TensorCoreIntrinEmitter: ) return lane_id, warp_n, warp_m - def ldmatrix_a(self, - A_local_buf: Buffer, - A_shared_buf: Buffer | BufferRegion, - ki: PrimExpr, - rk: PrimExpr | None = 0): + def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0): # Fast path for fp64: no ldmatrix support, do direct per-lane loads if DataType(self.a_dtype).bits == 64: warp_row_tiles = self.warp_row_tiles @@ -324,9 +318,7 @@ class TensorCoreIntrinEmitter: for i in T.serial(warp_rows): # Assign A_shared_buf_elem wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k - A_shared_buf_elem = A_buf[A_base0 + wk, - A_base1 + wi] if a_transposed else A_buf[A_base0 + 
wi, - A_base1 + wk] + A_shared_buf_elem = A_buf[A_base0 + wk, A_base1 + wi] if a_transposed else A_buf[A_base0 + wi, A_base1 + wk] if ldmatrix_available: T.ptx_ldmatrix( @@ -343,20 +335,13 @@ class TensorCoreIntrinEmitter: for j in T.serial(local_size_a): mi, mk = mma_load_layout(tx, j) if a_transposed: - A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wk + mk, - A_base1 + wi + mi] + A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wk + mk, A_base1 + wi + mi] else: - A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi, - A_base1 + wk + mk] + A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk] return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk) - def ldmatrix_b(self, - B_local_buf: Buffer, - B_shared_buf: Buffer | BufferRegion, - ki: PrimExpr, - rk: PrimExpr | None = 0): - + def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0): # Fast path for fp64: no ldmatrix support, do direct per-lane loads if DataType(self.b_dtype).bits == 64: warp_col_tiles = self.warp_col_tiles @@ -411,7 +396,7 @@ class TensorCoreIntrinEmitter: B_base0 = B_region.region[-2].min B_base1 = B_region.region[-1].min B_stride_last = B_buf.shape[-1] - replicate_b = (self.n_dim == 16) + replicate_b = self.n_dim == 16 # ldmatrix cannot be used for int8 + trans case. ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed) @@ -448,9 +433,7 @@ class TensorCoreIntrinEmitter: ) if ldmatrix_available: - B_shared_buf_elem = B_buf[B_base0 + wi, - B_base1 + wk] if b_transposed else B_buf[B_base0 + wk, - B_base1 + wi] + B_shared_buf_elem = B_buf[B_base0 + wi, B_base1 + wk] if b_transposed else B_buf[B_base0 + wk, B_base1 + wi] T.ptx_ldmatrix( b_dtype, @@ -469,19 +452,13 @@ class TensorCoreIntrinEmitter: for j in T.serial(local_size_b): mi, mk = mma_load_layout(tx, j) if b_transposed: - B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, - B_base1 + wk + mk] + B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk] else: - B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, - B_base1 + wi + mi] + B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi] return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) - def mma(self, - A_local_buf: Buffer, - B_local_buf: Buffer, - C_local_buf: Buffer, - k_inner: PrimExpr | None = 0): + def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0): warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_a = self.local_size_a @@ -492,7 +469,7 @@ class TensorCoreIntrinEmitter: accum_dtype = self.accum_dtype accum_dtype_abbrv = self.accum_dtype_abbrv mma_prefix = self.mma_prefix - replicate_b = (self.n_dim == 16) + replicate_b = self.n_dim == 16 a_is_fragment = is_fragment(A_local_buf) b_is_fragment = is_fragment(B_local_buf) @@ -532,8 +509,7 @@ class TensorCoreIntrinEmitter: B_local_buf.data, b_local_stride + j * local_size_b + lift(local_size_b) // 2, C_local_buf.data, - i * warp_cols * local_size_out + j * local_size_out + - lift(local_size_out) // 2, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, T.bool(False), # saturate ) @@ -568,14 +544,13 @@ class TensorCoreIntrinEmitter: local_id = local_id_o * 2 + local_id_i row, col = T.meta_var(mma_store_index_map(tx, local_id)) if C_buf_dims == 2: - C_buf[(warp_m * warp_rows + i) * M_DIM + row, - (warp_n * warp_cols + j) * n_dim + - col] = 
C_local_buf[i * (warp_cols * local_size_out) + - j * local_size_out + local_id] + C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * n_dim + col] = C_local_buf[ + i * (warp_cols * local_size_out) + j * local_size_out + local_id + ] else: - C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, - col] = C_local_buf[i * (warp_cols * local_size_out) + - j * local_size_out + local_id] + C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[ + i * (warp_cols * local_size_out) + j * local_size_out + local_id + ] @T.macro def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding): @@ -588,15 +563,15 @@ class TensorCoreIntrinEmitter: C_buf[ (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, (pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col, - ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + - local_id] + ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id] - return (_warp_stmatrix_global(C_local_buf, C_buf, thread_binding) - if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)) + return ( + _warp_stmatrix_global(C_local_buf, C_buf, thread_binding) + if is_global + else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding) + ) - def make_mma_load_layout(self, - local_buf: Buffer, - matrix: Literal["A", "B"] = "A") -> T.Fragment: + def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment: """ Create a layout function for storing MMA results into a fragment buffer. This layout is used in conjunction with `inverse_mma_store_layout` to @@ -619,6 +594,7 @@ class TensorCoreIntrinEmitter: If `local_buf` is not detected to be a fragment buffer. """ from tilelang.utils import is_fragment + assert matrix in ["A", "B"], "matrix should be either A or B" matrix_is_a: bool = matrix == "A" matrix_is_b: bool = matrix == "B" @@ -655,11 +631,9 @@ class TensorCoreIntrinEmitter: # so the b matrix expected a transposed basic layout transform_func: Callable = None if matrix_is_a: - transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( - j, i) + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) elif matrix_is_b: - transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( - j, i) + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) else: raise ValueError(f"Unsupported matrix {matrix}") @@ -706,31 +680,19 @@ class TensorCoreIntrinEmitter: replicate = block_col_warps if matrix_is_a else block_row_warps if is_sr_axis_order: - warp_fragment = base_fragment.repeat([warp_s, warp_r], - repeat_on_thread=False, - lower_dim_first=False) + warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False) if matrix_is_a: - block_fragment = warp_fragment.repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) elif matrix_is_b: - block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True) + block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") else: - warp_fragment = base_fragment.repeat([warp_r, warp_s], - 
repeat_on_thread=False, - lower_dim_first=True) + warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True) if matrix_is_a: - block_fragment = warp_fragment.repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) elif matrix_is_b: - block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True) + block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") @@ -761,8 +723,7 @@ class TensorCoreIntrinEmitter: from tilelang.utils import is_fragment shape = local_buf.shape - assert is_fragment( - local_buf), f"local_buf {local_buf} must be a fragment, but got {local_buf.scope()}" + assert is_fragment(local_buf), f"local_buf {local_buf} must be a fragment, but got {local_buf.scope()}" inverse_mma_store_layout = self.get_store_index_map(inverse=True) micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y @@ -954,10 +915,12 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): ".b16", A_local_buf.data, i * local_size_a, - T.address_of(A_shared_buf[ - warp_m * warp_row_tiles + i * micro_size_x, - rk * chunk + ki * micro_size_k, - ]), + T.address_of( + A_shared_buf[ + warp_m * warp_row_tiles + i * micro_size_x, + rk * chunk + ki * micro_size_k, + ] + ), get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), ) elif transform_kind_a == TransformKind.InterWarpTransform: @@ -1019,10 +982,8 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): warp_m * warp_rows + j, rk * (chunk // micro_size_k) + ki, ) - rii, rjj = (tx * local_size_a + - local_id) // micro_size_k, (tx * local_size_a + local_id) % ( - micro_size_k) - A_local_buf[j * local_size_a + local_id] = (A_shared_buf[ri, rj, rii, rjj]) + rii, rjj = (tx * local_size_a + local_id) // micro_size_k, (tx * local_size_a + local_id) % (micro_size_k) + A_local_buf[j * local_size_a + local_id] = A_shared_buf[ri, rj, rii, rjj] else: raise ValueError("Unsupported TransformKind for Input A") @@ -1131,12 +1092,11 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): warp_n * warp_cols + j, rk * (chunk // micro_size_k) + ki, ) - rii, rjj = (tx * local_size_dequantize + - local_id) // (micro_size_k // num_elems_per_byte), ( - tx * local_size_dequantize + local_id) % ( - micro_size_k // num_elems_per_byte) - B_local_buf[j * local_size_dequantize + local_id] = ( - B_shared_buf[ri, rj, rii, rjj]) + rii, rjj = ( + (tx * local_size_dequantize + local_id) // (micro_size_k // num_elems_per_byte), + (tx * local_size_dequantize + local_id) % (micro_size_k // num_elems_per_byte), + ) + B_local_buf[j * local_size_dequantize + local_id] = B_shared_buf[ri, rj, rii, rjj] else: raise ValueError("Unsupported TransformKind for Input B") @@ -1195,7 +1155,6 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter): - def mma(self, A_local_buf, B_local_buf, C_local_buf): warp_rows = self.warp_rows warp_cols = self.warp_cols @@ -1298,9 +1257,7 @@ class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter): class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWithLadderTransform): - def mma(self, A_local_buf, B_local_buf, 
C_local_buf): - warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_a = self.local_size_a diff --git a/tilelang/intrinsics/mma_sm70_layout.py b/tilelang/intrinsics/mma_sm70_layout.py index e7a57da767c658e6544febac28d2456f4d7e5386..8029234414710af6e923b9d68e0df9a4fc9f4801 100644 --- a/tilelang/intrinsics/mma_sm70_layout.py +++ b/tilelang/intrinsics/mma_sm70_layout.py @@ -17,10 +17,8 @@ def shared_16x4_to_mma_b_32x4_layout_trans(row, col, rep): def mma_32x8_to_shared_16x16_layout_fp32(thread_id, local_id): - row = (thread_id % 2) + ( - (local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8 - col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id % - 2) + (local_id // 4) * 8 + row = (thread_id % 2) + ((local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8 + col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id % 2) + (local_id // 4) * 8 return row, col @@ -31,7 +29,7 @@ def mma_32x8_to_shared_16x16_layout_fp16(thread_id, local_id): def mma_load_a_32x4_to_shared_16x4_layout(thread_id, local_id): - row = (thread_id % 4) + (4 * (((thread_id // 16 + thread_id % 16 // 4 * 2)) % 4)) + row = (thread_id % 4) + (4 * ((thread_id // 16 + thread_id % 16 // 4 * 2) % 4)) col = local_id return row, col diff --git a/tilelang/intrinsics/mma_sm70_macro_generator.py b/tilelang/intrinsics/mma_sm70_macro_generator.py index 782480816fce71478f220aa51a0e9f0a678bec9e..3186adb2afc58898b2dc6e56bd97cd18ce78748f 100644 --- a/tilelang/intrinsics/mma_sm70_macro_generator.py +++ b/tilelang/intrinsics/mma_sm70_macro_generator.py @@ -147,18 +147,15 @@ class TensorCoreIntrinEmitter: def get_store_index_map(self, inverse: bool = False) -> IndexMap: warp_size, local_size_c = self.WARP_SIZE, self.local_size_out index_map = IndexMap.from_func( - mma_32x8_to_shared_16x16_layout_fp32 - if self.accum_dtype == "float32" else mma_32x8_to_shared_16x16_layout_fp16, - index_dtype="int32") + mma_32x8_to_shared_16x16_layout_fp32 if self.accum_dtype == "float32" else mma_32x8_to_shared_16x16_layout_fp16, + index_dtype="int32", + ) if not inverse: return index_map inverse_index_map = index_map.inverse([warp_size, local_size_c]) return inverse_index_map - def extract_thread_binding( - self, - thread_id: PrimExpr, - is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: + def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: """ is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] @@ -187,11 +184,7 @@ class TensorCoreIntrinEmitter: ) return lane_id, warp_n, warp_m - def ldmatrix_a(self, - A_local_buf: Buffer, - A_shared_buf: Buffer | BufferRegion, - ki: PrimExpr, - rk: PrimExpr | None = 0): + def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0): warp_row_tiles = self.warp_row_tiles warp_rows = self.warp_rows chunk = self.chunk @@ -231,11 +224,7 @@ class TensorCoreIntrinEmitter: return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk) - def ldmatrix_b(self, - B_local_buf: Buffer, - B_shared_buf: Buffer | BufferRegion, - ki: PrimExpr, - rk: PrimExpr | None = 0): + def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer | BufferRegion, ki: PrimExpr, rk: PrimExpr | None = 0): warp_col_tiles = self.warp_col_tiles warp_cols = self.warp_cols chunk = self.chunk @@ -274,20 
+263,14 @@ class TensorCoreIntrinEmitter: for j in T.vectorized(local_size_b): if b_transposed: mi, mk = mma_load_layout(tx, j) - B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, - B_base1 + wk + mk] + B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk] else: mk, mi = mma_load_layout(tx, j) - B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, - B_base1 + wi + mi] + B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi] return _warp_ldmatrix_b(B_local_buf, B_region, ki, thread_binding, rk) - def mma(self, - A_local_buf: Buffer, - B_local_buf: Buffer, - C_local_buf: Buffer, - k_inner: PrimExpr | None = 0): + def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0): warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_a = self.local_size_a @@ -326,9 +309,7 @@ class TensorCoreIntrinEmitter: return _warp_mma(A_local_buf, B_local_buf, C_local_buf) - def make_mma_load_layout(self, - local_buf: Buffer, - matrix: Literal["A", "B"] = "A") -> T.Fragment: + def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment: """ Create a layout function for storing MMA results into a fragment buffer. This layout is used in conjunction with `inverse_mma_store_layout` to @@ -351,6 +332,7 @@ class TensorCoreIntrinEmitter: If `local_buf` is not detected to be a fragment buffer. """ from tilelang.utils import is_fragment + assert matrix in ["A", "B"], "matrix should be either A or B" matrix_is_a: bool = matrix == "A" matrix_is_b: bool = matrix == "B" @@ -383,11 +365,9 @@ class TensorCoreIntrinEmitter: # so the b matrix expected a transposed basic layout transform_func: Callable = None if matrix_is_a: - transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( - j, i) + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) elif matrix_is_b: - transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_rs_b( - i, j) + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_rs_b(i, j) else: raise ValueError(f"Unsupported matrix {matrix}") @@ -413,9 +393,8 @@ class TensorCoreIntrinEmitter: return lane_id, local_id base_fragment = T.Fragment( - [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], - forward_fn=forward, - replicate=2) + [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], forward_fn=forward, replicate=2 + ) warp_rows, warp_cols = self.warp_rows, self.warp_cols chunk = self.chunk @@ -426,31 +405,19 @@ class TensorCoreIntrinEmitter: replicate = block_col_warps if matrix_is_a else block_row_warps if is_sr_axis_order: - warp_fragment = base_fragment.repeat([warp_s, warp_r], - repeat_on_thread=False, - lower_dim_first=False) + warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False) if matrix_is_a: - block_fragment = warp_fragment.repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) elif matrix_is_b: - block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True) + block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, 
lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") else: - warp_fragment = base_fragment.repeat([warp_r, warp_s], - repeat_on_thread=False, - lower_dim_first=True) + warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True) if matrix_is_a: - block_fragment = warp_fragment.repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + block_fragment = warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) elif matrix_is_b: - block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True) + block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") diff --git a/tilelang/intrinsics/mma_sp_layout.py b/tilelang/intrinsics/mma_sp_layout.py index bae86bf45c2a99dd15f3f789aea3bb71f4eb0900..58034e7fdba90cf6bc8408db6e3ad8c10ce0661a 100644 --- a/tilelang/intrinsics/mma_sp_layout.py +++ b/tilelang/intrinsics/mma_sp_layout.py @@ -72,56 +72,47 @@ def get_logical_id_32bit(thread_id: int) -> int: return (thread_id // 4) * 2 + (thread_id % 4) % 2 -def metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id: int, - local_id: int) -> tuple[int, int]: +def metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id: int, local_id: int) -> tuple[int, int]: logical_id = get_logical_id_32bit(thread_id) row = logical_id // 4 + local_id * 8 col = logical_id % 4 return row, col -def metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id: int, - local_id: int) -> tuple[int, int]: +def metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id: int, local_id: int) -> tuple[int, int]: logical_id = get_logical_id_32bit(thread_id) row = logical_id // 2 + local_id * 8 col = logical_id % 2 return row, col -def metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(thread_id: int, - local_id: int) -> tuple[int, int]: - return metadata_8bit_load_32x4_to_shared_16x4_layout_32bit( - thread_id, local_id) # same mapping for 16bit and 32bit +def metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]: + return metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id, local_id) # same mapping for 16bit and 32bit -def metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(thread_id: int, - local_id: int) -> tuple[int, int]: - return metadata_16bit_load_32x2_to_shared_16x2_layout_32bit( - thread_id, local_id) # same mapping for 16bit and 32bit +def metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]: + return metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id, local_id) # same mapping for 16bit and 32bit def get_logical_id_8bit(thread_id: int) -> int: return thread_id -def metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(thread_id: int, - local_id: int) -> tuple[int, int]: +def metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]: logical_id = get_logical_id_8bit(thread_id) row = logical_id // 2 + local_id * 8 col = (logical_id % 4) // 2 * 4 + local_id return row, col -def metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(thread_id: int, - local_id: int) -> tuple[int, int]: +def metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]: logical_id = get_logical_id_8bit(thread_id) row = 
logical_id // 2 + local_id * 8 col = (logical_id % 4) // 2 * 2 + local_id return row, col -def metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(thread_id: int, - local_id: int) -> tuple[int, int]: +def metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]: # local_id is always 0 logical_id = get_logical_id_8bit(thread_id) row = logical_id // 4 + (logical_id % 2) * 8 diff --git a/tilelang/intrinsics/mma_sp_macro_generator.py b/tilelang/intrinsics/mma_sp_macro_generator.py index 629d95d997db09caaf238a72c75e3e6d8e636f13..ea7aa8992683611c29920e925234a52ac52ad9e8 100644 --- a/tilelang/intrinsics/mma_sp_macro_generator.py +++ b/tilelang/intrinsics/mma_sp_macro_generator.py @@ -190,8 +190,7 @@ class SparseTensorCoreIntrinEmitter: def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): self.local_size_a = (m_dim * k_dim) // warp_size // self.SPARSE_FACTOR - self.local_size_e = ( - m_dim * k_dim) // self.e_factor // warp_size * self.E_REPLICATE_FACTOR[self.a_dtype] + self.local_size_e = (m_dim * k_dim) // self.e_factor // warp_size * self.E_REPLICATE_FACTOR[self.a_dtype] self.local_size_b = (n_dim * k_dim) // warp_size self.local_size_out = (m_dim * n_dim) // warp_size @@ -257,10 +256,7 @@ class SparseTensorCoreIntrinEmitter: inverse_index_map = index_map.inverse([warp_size, local_size_c]) return inverse_index_map - def extract_thread_binding( - self, - thread_id: PrimExpr, - is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: + def extract_thread_binding(self, thread_id: PrimExpr, is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: """ is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] @@ -330,8 +326,7 @@ class SparseTensorCoreIntrinEmitter: for i in T.serial(warp_rows): # Assign A_shared_buf_elem - wi, wk = warp_m * warp_row_tiles + i * micro_size_x, ( - rk * warp_k + ki * micro_size_k) // self.SPARSE_FACTOR + wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.SPARSE_FACTOR A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk] if ldmatrix_available: @@ -348,10 +343,9 @@ class SparseTensorCoreIntrinEmitter: else: for j in T.serial(local_size_a): mi, mk = mma_load_layout(tx, j) - A_local_buf[i * local_size_a + - j] = A_shared_buf[wk + mk, wi + - mi] if a_transposed else A_shared_buf[wi + mi, - wk + mk] + A_local_buf[i * local_size_a + j] = ( + A_shared_buf[wk + mk, wi + mi] if a_transposed else A_shared_buf[wi + mi, wk + mk] + ) return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) @@ -412,14 +406,10 @@ class SparseTensorCoreIntrinEmitter: tx, _, warp_m = self.extract_thread_binding(thread_binding) for i in T.serial(warp_rows): # Assign E_shared_buf_elem - wi, wk = warp_m * warp_row_tiles + i * micro_size_x, ( - rk * warp_k + ki * micro_size_k) // self.e_factor + wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (rk * warp_k + ki * micro_size_k) // self.e_factor for j in T.serial(local_size_e): mi, mk = mma_load_layout(tx, j) - E_local_buf[i * local_size_e + - j] = E_shared_buf[wk + mk, - wi + mi] if trans else E_shared_buf[wi + mi, - wk + mk] + E_local_buf[i * local_size_e + j] = E_shared_buf[wk + mk, wi + mi] if trans else E_shared_buf[wi + mi, wk + mk] return _warp_ldmatrix_e(E_local_buf, E_shared_buf, ki, thread_binding, rk) @@ -433,7 +423,7 @@ class SparseTensorCoreIntrinEmitter: 
b_dtype = self.b_dtype b_transposed = self.b_transposed thread_binding = self.get_thread_binding() - replicate_b = (self.n_dim == 16) + replicate_b = self.n_dim == 16 # ldmatrix cannot be used for int8 + trans case. ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed) @@ -470,8 +460,7 @@ class SparseTensorCoreIntrinEmitter: ) if ldmatrix_available: - B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, - wi] + B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, wi] if replicate_b: T.ptx_ldmatrix( @@ -493,9 +482,7 @@ class SparseTensorCoreIntrinEmitter: B_local_buf.data, i * local_size_b + lift(local_size_b) // 2, T.address_of(B_shared_buf_elem), - get_ldmatrix_offset_b("B", tx, - lift(local_size_b) // 2, stride, b_dtype, - b_transposed), + get_ldmatrix_offset_b("B", tx, lift(local_size_b) // 2, stride, b_dtype, b_transposed), ) else: T.ptx_ldmatrix( @@ -514,19 +501,13 @@ class SparseTensorCoreIntrinEmitter: # must be transposed. for j in T.serial(local_size_b): mi, mk = mma_load_layout(tx, j) - B_local_buf[i * local_size_b + - j] = B_shared_buf[wi + mi, wk + - mk] if b_transposed else B_shared_buf[wk + mk, - wi + mi] + B_local_buf[i * local_size_b + j] = ( + B_shared_buf[wi + mi, wk + mk] if b_transposed else B_shared_buf[wk + mk, wi + mi] + ) return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) - def mma_sp(self, - A_local_buf: Buffer, - E_local_buf: Buffer, - B_local_buf: Buffer, - C_local_buf: Buffer, - k_inner: PrimExpr = 0): + def mma_sp(self, A_local_buf: Buffer, E_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr = 0): warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_a = self.local_size_a @@ -538,7 +519,7 @@ class SparseTensorCoreIntrinEmitter: accum_dtype = self.accum_dtype accum_dtype_abbrv = self.accum_dtype_abbrv mma_prefix = self.mma_prefix - replicate_b = (self.n_dim == 16) + replicate_b = self.n_dim == 16 a_is_fragment = is_fragment(A_local_buf) e_is_fragment = is_fragment(E_local_buf) @@ -584,8 +565,7 @@ class SparseTensorCoreIntrinEmitter: B_local_buf.data, b_local_stride + j * local_size_b + lift(local_size_b) // 2, C_local_buf.data, - i * warp_cols * local_size_out + j * local_size_out + - lift(local_size_out) // 2, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, E_local_buf.data, # metadata e_local_stride + i * local_size_e, # metadata offset self.SPARSE_SELECTOR, # sparse_selector @@ -623,14 +603,13 @@ class SparseTensorCoreIntrinEmitter: local_id = local_id_o * 2 + local_id_i row, col = T.meta_var(mma_store_index_map(tx, local_id)) if C_buf_dims == 2: - C_buf[(warp_m * warp_rows + i) * M_DIM + row, - (warp_n * warp_cols + j) * n_dim + - col] = C_local_buf[i * (warp_cols * local_size_out) + - j * local_size_out + local_id] + C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * n_dim + col] = C_local_buf[ + i * (warp_cols * local_size_out) + j * local_size_out + local_id + ] else: - C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, - col] = C_local_buf[i * (warp_cols * local_size_out) + - j * local_size_out + local_id] + C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[ + i * (warp_cols * local_size_out) + j * local_size_out + local_id + ] @T.macro def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding): @@ -643,15 +622,15 @@ class SparseTensorCoreIntrinEmitter: C_buf[ (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, (pid_n 
* BLOCK_N + warp_n * warp_cols + j) * n_dim + col, - ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + - local_id] + ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id] - return (_warp_stmatrix_global(C_local_buf, C_buf, thread_binding) - if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)) + return ( + _warp_stmatrix_global(C_local_buf, C_buf, thread_binding) + if is_global + else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding) + ) - def make_mma_load_layout(self, - local_buf: Buffer, - matrix: Literal["A", "B"] = "A") -> T.Fragment: + def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A") -> T.Fragment: """ Create a layout function for storing MMA results into a fragment buffer. This layout is used in conjunction with `inverse_mma_store_layout` to @@ -674,6 +653,7 @@ class SparseTensorCoreIntrinEmitter: If `local_buf` is not detected to be a fragment buffer. """ from tilelang.utils import is_fragment + assert matrix in ["A", "B"], "matrix should be either A or B" matrix_is_a: bool = matrix == "A" matrix_is_b: bool = matrix == "B" @@ -710,11 +690,9 @@ class SparseTensorCoreIntrinEmitter: # so the b matrix expected a transposed basic layout transform_func: Callable = None if matrix_is_a: - transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( - j, i) + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) elif matrix_is_b: - transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( - j, i) + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) else: raise ValueError(f"Unsupported matrix {matrix}") @@ -747,7 +725,8 @@ class SparseTensorCoreIntrinEmitter: return local_id base_fragment = T.Fragment( - [micro_size_s, micro_size_r // 2 if matrix_is_a else micro_size_r] if is_sr_axis_order + [micro_size_s, micro_size_r // 2 if matrix_is_a else micro_size_r] + if is_sr_axis_order else [micro_size_r // 2 if matrix_is_a else micro_size_r, micro_size_s], forward_thread_fn=forward_thread, forward_index_fn=forward_index, @@ -762,31 +741,19 @@ class SparseTensorCoreIntrinEmitter: replicate = block_col_warps if matrix_is_a else block_row_warps if is_sr_axis_order: - warp_fragment = base_fragment.repeat([warp_s, warp_r], - repeat_on_thread=False, - lower_dim_first=False) + warp_fragment = base_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False) if matrix_is_a: - block_fragment = warp_fragment.repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + block_fragment = warp_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) elif matrix_is_b: - block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=True) + block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") else: - warp_fragment = base_fragment.repeat([warp_r, warp_s], - repeat_on_thread=False, - lower_dim_first=True) + warp_fragment = base_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True) if matrix_is_a: - block_fragment = warp_fragment.repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True).replicate(replicate) + block_fragment = 
warp_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True).replicate(replicate) elif matrix_is_b: - block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=True) + block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], repeat_on_thread=True, lower_dim_first=True) else: raise ValueError(f"Unsupported matrix type {matrix}") diff --git a/tilelang/intrinsics/tcgen05_macro_generator.py b/tilelang/intrinsics/tcgen05_macro_generator.py index 966f4dc494fcb6825d78212ea736a1929f013ee2..26208d6ce0c1492f6743a701f4236b19f014fbc9 100644 --- a/tilelang/intrinsics/tcgen05_macro_generator.py +++ b/tilelang/intrinsics/tcgen05_macro_generator.py @@ -88,9 +88,22 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): is_m_first: bool = False, thread_var: Var | None = None, ): - super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps, - block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k, - num_elems_per_byte, is_m_first, thread_var) + super().__init__( + a_dtype, + b_dtype, + accum_dtype, + a_transposed, + b_transposed, + block_row_warps, + block_col_warps, + warp_row_tiles, + warp_col_tiles, + chunk, + reduce_k, + num_elems_per_byte, + is_m_first, + thread_var, + ) def _assign_a_shared_layout(self, layout: Layout): self.a_shared_layout = layout @@ -137,13 +150,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): else: raise ValueError(f"Unsupported swizzle mode: {layout}") - def tcgen05mma(self, - A_buf: Buffer, - B_buf: Buffer, - C_local_buf: Buffer, - mbar, - clear_accum: PrimExpr = False): - + def tcgen05mma(self, A_buf: Buffer, B_buf: Buffer, C_local_buf: Buffer, mbar, clear_accum: PrimExpr = False): if is_tensor_memory(A_buf): return self.tcgen05mma_rs(A_buf, B_buf, C_local_buf, clear_accum) @@ -164,22 +171,20 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): elems_in_bits = DataType(self.a_dtype).bits elems_in_bytes = elems_in_bits // 8 a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes - b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none( - ) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes + b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes accum_dtype_in_bits = DataType(accum_dtype).bits meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim) if len(meta) != 5: raise ValueError( f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, " - f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}") + f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}" + ) atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta) # by default, we utilize non-swizzle layout offset - a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * - elems_in_bytes) - a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * - elems_in_bytes) + a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes) + a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes) if not a_swizzle_mode.is_none(): # swizzle mode doesn't require LBO/SBO to be 1 @@ -202,11 +207,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): else: a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems - b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * - elems_in_bytes) - 
b_stride_byte_offset = (8 * k_dim * - elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else - (8 * 8 * elems_in_bytes)) + b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes)) if not b_swizzle_mode.is_none(): # swizzle mode doesn't require LBO/SBO to be 1 # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset @@ -312,21 +314,26 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): for ki in T.unroll(0, (k_dim // micro_size_k)): scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1)) A_elem_offset = ( - ki % ak_atom_size - ) * micro_size_k + i * atom_m * a_swizzle_atom_elems + ( - ki // ak_atom_size - ) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * atom_m * k_dim + ki * a_swizzle_atom_elems * micro_size_k + (ki % ak_atom_size) * micro_size_k + + i * atom_m * a_swizzle_atom_elems + + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems + if a_is_k_major + else i * atom_m * k_dim + ki * a_swizzle_atom_elems * micro_size_k + ) - B_elem_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + ( - ki % bk_atom_size - ) * micro_size_k + j * atom_n * b_swizzle_atom_elems if b_is_k_major else ( - ki * b_swizzle_atom_elems * micro_size_k + j * atom_n * - (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)) + B_elem_offset = ( + (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + + (ki % bk_atom_size) * micro_size_k + + j * atom_n * b_swizzle_atom_elems + if b_is_k_major + else ( + ki * b_swizzle_atom_elems * micro_size_k + j * atom_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1) + ) + ) A_byte_offset = A_elem_offset * elems_in_bytes B_byte_offset = B_elem_offset * elems_in_bytes - C_offset = (i * n_dim + j * tmem_col_step - ) * accum_dtype_in_bits // 32 # 32 bits per tmem bank + C_offset = (i * n_dim + j * tmem_col_step) * accum_dtype_in_bits // 32 # 32 bits per tmem bank T.ptx_tcgen05_mma_ss( a_dtype_abbrv, @@ -373,8 +380,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): """ assert is_tensor_memory(tmem_buf), "tmem_buf must reside in tensor memory (shared.tmem)" if len(tmem_buf.shape) != 2: - raise ValueError( - f"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape {tmem_buf.shape}") + raise ValueError(f"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape {tmem_buf.shape}") m = int(tmem_buf.shape[0]) n = int(tmem_buf.shape[1]) @@ -382,14 +388,13 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): meta = self.get_tcgen5_mma_meta(m, n, k) if len(meta) != 5: - raise ValueError(f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, " - f"A dtype={self.a_dtype}, accum dtype={self.accum_dtype}") + raise ValueError( + f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}" + ) atom_m, atom_n, _, _, _ = (int(x) for x in meta) if m % atom_m != 0 or n % atom_n != 0: - raise ValueError( - f"Invalid TCGEN5MMA store layout for shape ({m}, {n}) with atoms ({atom_m}, {atom_n})" - ) + raise ValueError(f"Invalid TCGEN5MMA store layout for shape ({m}, {n}) with atoms ({atom_m}, {atom_n})") def forward(i: PrimExpr, j: PrimExpr): atom_idx = (i // atom_m) + (j // atom_n) * (m // atom_m) @@ -422,11 +427,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): return Layout([m, n], forward) def get_tcgen5_mma_meta(self, m: int, n: int, k: int): - return _ffi_api.get_tcgen5_mma_meta( - 
int(m), int(n), int(k), DataType(self.a_dtype), DataType(self.accum_dtype)) + return _ffi_api.get_tcgen5_mma_meta(int(m), int(n), int(k), DataType(self.a_dtype), DataType(self.accum_dtype)) - def get_tcgen5_instr_desc(self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool, - b_is_k_major: bool, scale_in_a: int, scale_in_b: int) -> PrimExpr: + def get_tcgen5_instr_desc( + self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool, b_is_k_major: bool, scale_in_a: int, scale_in_b: int + ) -> PrimExpr: desc = _ffi_api.get_tcgen5_instr_desc( atom_m, atom_n, diff --git a/tilelang/intrinsics/utils.py b/tilelang/intrinsics/utils.py index 7fc9bab134052bf6ac9dca197e9b1269732e1402..fb24a4add2f157eb296f21b4dad185d83ae6e243 100644 --- a/tilelang/intrinsics/utils.py +++ b/tilelang/intrinsics/utils.py @@ -10,7 +10,7 @@ from .mma_layout import ( mma_store_32x8_to_shared_16x16_layout, mma_store_32x2_to_shared_8x8_layout_fp64, ) -from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m) +from .mfma_layout import thread_id_shared_access_64x4_to_16x16_layout_C_n_m from .mma_layout import get_swizzle_layout # noqa: F401 from .mma_layout import make_mma_swizzle_layout # noqa: F401 diff --git a/tilelang/intrinsics/wgmma_macro_generator.py b/tilelang/intrinsics/wgmma_macro_generator.py index 51a90fba10c60d92c5145d4ef2fd318bca26e911..483b6e7315bd7ba89cb2553f66079d0e16d7bd6a 100644 --- a/tilelang/intrinsics/wgmma_macro_generator.py +++ b/tilelang/intrinsics/wgmma_macro_generator.py @@ -15,9 +15,11 @@ from tilelang.layout import ( make_linear_layout, ) from tvm.runtime import convert -from tilelang.intrinsics.mma_layout import (shared_16x8_to_mma_32x4_layout_sr_a, - shared_16x16_to_mma_32x8_layout_sr_a, - shared_16x32_to_mma_32x16_layout_sr_a) +from tilelang.intrinsics.mma_layout import ( + shared_16x8_to_mma_32x4_layout_sr_a, + shared_16x16_to_mma_32x8_layout_sr_a, + shared_16x32_to_mma_32x16_layout_sr_a, +) lift = convert @@ -96,9 +98,22 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): is_m_first: bool | None = False, thread_var: Var | None = None, ): - super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps, - block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k, - num_elems_per_byte, is_m_first, thread_var) + super().__init__( + a_dtype, + b_dtype, + accum_dtype, + a_transposed, + b_transposed, + block_row_warps, + block_col_warps, + warp_row_tiles, + warp_col_tiles, + chunk, + reduce_k, + num_elems_per_byte, + is_m_first, + thread_var, + ) self._initialize_wgmma_prefix(self.n_dim) def _assign_a_shared_layout(self, layout: Layout): @@ -112,12 +127,12 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): def _initialize_wgmma_prefix(self, n_dim: int = 16): inst_m, inst_n = 64, gcd(self.warp_col_tiles, 256) assert inst_n % 8 == 0, ( - f"inst_n must be a multiple of 8, got {inst_n} " - f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})") + f"inst_n must be a multiple of 8, got {inst_n} (block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})" + ) # Validate inst_n: Hopper WGMMA supports n in [8, 256] and multiple of 8 assert 8 <= inst_n <= 256, ( - f"inst_n must be within [8, 256], got {inst_n} " - f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})") + f"inst_n must be within [8, 256], got {inst_n} (block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})" + ) # 256 bits per instruction inst_k = 256 // DataType(self.a_dtype).bits 
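# Illustrative sketch, not part of the diff: the `inst_k = 256 // DataType(self.a_dtype).bits`
# line above encodes that one WGMMA step consumes 256 bits of operand data along K, so the
# per-instruction K extent scales inversely with the A dtype width. Assuming only that
# `tvm.DataType` exposes the `.bits` attribute used by the surrounding code:
from tvm import DataType

for dtype in ("float16", "bfloat16", "int8"):
    # fp16/bf16 operands give inst_k = 16, 8-bit operands give inst_k = 32
    print(dtype, "-> inst_k =", 256 // DataType(dtype).bits)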
self.wgmma_inst_m = inst_m @@ -160,13 +175,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): else: raise ValueError(f"Unsupported swizzle mode: {layout}") - def wgmma(self, - A_region: BufferRegion, - B_region: BufferRegion, - C_region: BufferRegion, - clear_accum: PrimExpr = False, - wg_wait: int = 0): - + def wgmma( + self, A_region: BufferRegion, B_region: BufferRegion, C_region: BufferRegion, clear_accum: PrimExpr = False, wg_wait: int = 0 + ): if is_fragment(A_region): return self.wgmma_rs(A_region, B_region, C_region, clear_accum, wg_wait) @@ -195,16 +206,13 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): elems_in_bytes = elems_in_bits // 8 a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes - b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none( - ) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes + b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes accum_bits = DataType(accum_dtype).bits accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32 # by default, we utilize non-swizzle layout offset - a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * - elems_in_bytes) - a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * - elems_in_bytes) + a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes) + a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes) if not a_swizzle_mode.is_none(): # swizzle mode doesn't require LBO/SBO to be 1 @@ -220,19 +228,15 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): if a_m_axis_atoms <= 1: a_leading_byte_offset = 0 else: - a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * ( - a_swizzle_mode.swizzle_byte_size() // elems_in_bytes) + a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * (a_swizzle_mode.swizzle_byte_size() // elems_in_bytes) if a_m_axis_atoms <= 1: a_stride_byte_offset = 8 * elems_in_bytes * m_dim else: a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems - b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * - elems_in_bytes) - b_stride_byte_offset = (8 * k_dim * - elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else - (8 * 8 * elems_in_bytes)) + b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes)) if not b_swizzle_mode.is_none(): # swizzle mode doesn't require LBO/SBO to be 1 # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset @@ -275,12 +279,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): desc_a = T.alloc_wgmma_desc() desc_b = T.alloc_wgmma_desc() - T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode, - int(a_leading_byte_offset >> 4), - int(a_stride_byte_offset >> 4)) - T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, - int(b_leading_byte_offset >> 4), - int(b_stride_byte_offset >> 4)) + T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4)) + T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) T.warpgroup_fence_operand(C_buf, num_regs=accum_regs) T.warpgroup_arrive() @@ -291,21 +291,41 @@ 
class TensorCoreIntrinEmitter(MMAIntrinEmitter): warp_i = (warp_m // 4) * num_inst_m + i warp_j = warp_n * num_inst_n + j A_offset = ( - ki % ak_atom_size - ) * micro_size_k + warp_i * 64 * a_swizzle_atom_elems + ( - ki // ak_atom_size - ) * m_dim * a_swizzle_atom_elems if a_is_k_major else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k - B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + ( - ki % bk_atom_size - ) * micro_size_k + warp_j * wgmma_inst_n * b_swizzle_atom_elems if b_is_k_major else ( - ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n * - (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)) + (ki % ak_atom_size) * micro_size_k + + warp_i * 64 * a_swizzle_atom_elems + + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems + if a_is_k_major + else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k + ) + B_offset = ( + (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + + (ki % bk_atom_size) * micro_size_k + + warp_j * wgmma_inst_n * b_swizzle_atom_elems + if b_is_k_major + else ( + ki * b_swizzle_atom_elems * micro_size_k + + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1) + ) + ) C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as an unit - T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major, - a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data, - (A_offset * elems_in_bytes) >> 4, desc_b.data, - (B_offset * elems_in_bytes) >> 4, C_buf.data, C_offset, - scale_out, scale_in_a, scale_in_b) + T.ptx_wgmma_ss( + accum_dtype, + wgmma_prefix, + a_is_k_major, + b_is_k_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + desc_a.data, + (A_offset * elems_in_bytes) >> 4, + desc_b.data, + (B_offset * elems_in_bytes) >> 4, + C_buf.data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, + ) T.warpgroup_commit_batch() if wg_wait >= 0: @@ -314,12 +334,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): return _warp_mma(A_ptr, B_ptr, C_buf) - def wgmma_rs(self, - A_region: BufferRegion, - B_region: BufferRegion, - C_region: BufferRegion, - clear_accum: PrimExpr = False, - wg_wait: int = 0): + def wgmma_rs( + self, A_region: BufferRegion, B_region: BufferRegion, C_region: BufferRegion, clear_accum: PrimExpr = False, wg_wait: int = 0 + ): local_size_a = self.local_size_a local_size_out = self.local_size_out a_dtype_abbrv = self.a_dtype_abbrv @@ -344,14 +361,10 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): b_is_k_major = self.b_transposed b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout) - b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none( - ) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes - - b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * - elems_in_bytes) - b_stride_byte_offset = (8 * k_dim * - elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else - (8 * 8 * elems_in_bytes)) + b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes + + b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes)) if not b_swizzle_mode.is_none(): # swizzle mode doesn't require LBO/SBO to be 1 # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset @@ -390,9 +403,7 @@ 
class TensorCoreIntrinEmitter(MMAIntrinEmitter): tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) desc_b = T.alloc_wgmma_desc() - T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, - int(b_leading_byte_offset >> 4), - int(b_stride_byte_offset >> 4)) + T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) T.warpgroup_fence_operand(A_buf, num_regs=a_regs) T.warpgroup_fence_operand(C_buf, num_regs=accum_regs) T.warpgroup_arrive() @@ -405,11 +416,15 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): A_offset = ki * warp_rows * local_size_a + i * local_size_a B_offset = ( - ki // bk_atom_size - ) * n_dim * b_swizzle_atom_elems + warp_j * wgmma_inst_n * b_swizzle_atom_elems + ( - ki % bk_atom_size) * micro_size_k if b_is_k_major else ( - ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n * - (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)) + (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + + warp_j * wgmma_inst_n * b_swizzle_atom_elems + + (ki % bk_atom_size) * micro_size_k + if b_is_k_major + else ( + ki * b_swizzle_atom_elems * micro_size_k + + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1) + ) + ) C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as an unit T.ptx_wgmma_rs( accum_dtype, @@ -460,6 +475,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): If `local_buf` is not detected to be a fragment buffer. """ from tilelang.utils import is_fragment + assert matrix in ["A"], "matrix should be A for WGMMA" dtype = self.a_dtype dtype_bits = DataType(dtype).bits @@ -488,8 +504,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): # the layout of mma.sync is row.col. 
# so the b matrix expected a transposed basic layout transform_func: Callable = None - transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( - j, i) + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}" @@ -531,20 +546,12 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): replicate = block_col_warps if is_sr_axis_order: - warp_fragment = base_fragment.repeat([block_s, 1], - repeat_on_thread=True, - lower_dim_first=False).replicate(replicate) - block_fragment = warp_fragment.repeat([warp_s, warp_r], - repeat_on_thread=False, - lower_dim_first=False) + warp_fragment = base_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=False).replicate(replicate) + block_fragment = warp_fragment.repeat([warp_s, warp_r], repeat_on_thread=False, lower_dim_first=False) else: # rs condition, transposed_a matrix - warp_fragment = base_fragment.repeat([1, block_s], - repeat_on_thread=True, - lower_dim_first=False).replicate(replicate) - block_fragment = warp_fragment.repeat([warp_r, warp_s], - repeat_on_thread=False, - lower_dim_first=True) + warp_fragment = base_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=False).replicate(replicate) + block_fragment = warp_fragment.repeat([warp_r, warp_s], repeat_on_thread=False, lower_dim_first=True) return block_fragment diff --git a/tilelang/ir.py b/tilelang/ir.py index 08d4e96cd437385cbaf81e62f317b3c15e8314c3..b4a7de5ebb22fd43d7f8e23813966e72f9d8ca2a 100644 --- a/tilelang/ir.py +++ b/tilelang/ir.py @@ -7,23 +7,19 @@ from tilelang import _ffi_api @tvm_ffi.register_object("tl.Fill") -class Fill(Node, Scriptable): - ... +class Fill(Node, Scriptable): ... @tvm_ffi.register_object("tl.AtomicAdd") -class AtomicAdd(Node, Scriptable): - ... +class AtomicAdd(Node, Scriptable): ... @tvm_ffi.register_object("tl.Copy") -class Copy(Node, Scriptable): - ... +class Copy(Node, Scriptable): ... @tvm_ffi.register_object("tl.Conv2DIm2Col") -class Conv2DIm2ColOp(Node, Scriptable): - ... +class Conv2DIm2ColOp(Node, Scriptable): ... @tvm_ffi.register_object("tl.GemmWarpPolicy") @@ -32,10 +28,8 @@ class GemmWarpPolicy(Node, Scriptable): m_warp: int n_warp: int - def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, - is_wgmma: bool): - _ffi_api.GemmWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, - is_wgmma) + def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, is_wgmma: bool): + _ffi_api.GemmWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, is_wgmma) return self.m_warp, self.n_warp @@ -45,48 +39,38 @@ class GemmSPWarpPolicy(Node, Scriptable): m_warp: int n_warp: int - def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, - is_wgmma: bool, bits: int): - _ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, - is_wgmma, bits) + def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, is_wgmma: bool, bits: int): + _ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, is_wgmma, bits) return self.m_warp, self.n_warp @tvm_ffi.register_object("tl.Gemm") -class Gemm(Node, Scriptable): - ... +class Gemm(Node, Scriptable): ... @tvm_ffi.register_object("tl.GemmSP") -class GemmSP(Node, Scriptable): - ... +class GemmSP(Node, Scriptable): ... 
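The tilelang/ir.py hunks above are formatting-only: the new formatter folds ellipsis-only class bodies onto the class header line. A minimal before/after sketch of the same idiom, using hypothetical class names (plain `object` bases, no TileLang imports) rather than the registered ops from the diff:

    # Before (old multi-line style): the placeholder body sits on its own line.
    class MyStubBefore:
        ...

    # After formatting (the style the hunks above adopt): the ellipsis body is
    # folded onto the header, keeping each FFI-registered stub to one line.
    class MyStubAfter: ...

Both spellings are equivalent at runtime; the change in these hunks is purely presentational.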
@tvm_ffi.register_object("tl.FinalizeReducerOp") -class FinalizeReducerOp(Node, Scriptable): - ... +class FinalizeReducerOp(Node, Scriptable): ... @tvm_ffi.register_object("tl.ParallelOp") -class ParallelOp(Node, Scriptable): - ... +class ParallelOp(Node, Scriptable): ... @tvm_ffi.register_object("tl.ReduceOp") -class ReduceOp(Node, Scriptable): - ... +class ReduceOp(Node, Scriptable): ... @tvm_ffi.register_object("tl.CumSumOp") -class CumSumOp(Node, Scriptable): - ... +class CumSumOp(Node, Scriptable): ... @tvm_ffi.register_object("tl.RegionOp") -class RegionOp(Node, Scriptable): - ... +class RegionOp(Node, Scriptable): ... @tvm_ffi.register_object("tl.ReduceType") -class ReduceType(Node, Scriptable): - ... +class ReduceType(Node, Scriptable): ... diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index 09cbac5ef00fd80a3b52e609eacb6c0474f18e0b..9a5920d7fda269dee0d0fc67d8b78b4f60d9df77 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -3,6 +3,7 @@ This module provides an auto-tuning infrastructure for TileLang (tl) programs. It includes functionality to JIT-compile TileLang programs into a runnable kernel adapter using TVM. """ + from __future__ import annotations from dataclasses import dataclass @@ -39,17 +40,16 @@ from tqdm.auto import tqdm logger = getLogger(__name__) -_P = ParamSpec('_P') -_KP = ParamSpec('_KP') -_T = TypeVar('_T') -_Ret = TypeVar('_Ret') +_P = ParamSpec("_P") +_KP = ParamSpec("_KP") +_T = TypeVar("_T") +_Ret = TypeVar("_Ret") def compile( func: PrimFunc[_KP, _T] = None, out_idx: list[int] | int | None = None, - execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", - "torch"] = "auto", + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto", target: str | Target = "auto", target_host: str | Target | None = None, verbose: bool = False, @@ -83,11 +83,9 @@ def compile( if isinstance(compile_flags, str): compile_flags = [compile_flags] - if hasattr(func, 'out_idx_override'): + if hasattr(func, "out_idx_override"): if func.out_idx_override is not None and out_idx is not None: - raise ValueError( - "Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors" - ) + raise ValueError("Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors") out_idx = func.out_idx_override or out_idx # This path is not a performance critical path, so we can afford to convert the target. 
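The `compile` hunk above tightens the interplay between an explicit `out_idx` argument and a prim_func that already declared its outputs via `T.empty` (surfaced as `out_idx_override`). A standalone sketch of that precedence rule, using a hypothetical helper name and no TileLang imports, not part of the diff:

    def resolve_out_idx(out_idx_override, out_idx):
        # Mirrors the guard shown in the hunk: both sources of output indices
        # must not be set at the same time.
        if out_idx_override is not None and out_idx is not None:
            raise ValueError(
                "Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors"
            )
        # Otherwise the override (if any) wins, falling back to the caller's value.
        return out_idx_override or out_idx

    assert resolve_out_idx(None, [0]) == [0]
    assert resolve_out_idx([1], None) == [1]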
@@ -96,6 +94,7 @@ def compile( # Resolve execution backend (handles aliases, auto, validation per target) requested_backend = execution_backend from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target + execution_backend = resolve_execution_backend(requested_backend, target) if verbose: allowed_now = allowed_backends_for_target(target, include_unavailable=False) @@ -119,17 +118,18 @@ def compile( ) -def par_compile(funcs: Iterable[PrimFunc[_KP, _T]], - out_idx: list[int] | int | None = None, - execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", - "torch"] = "auto", - target: str | Target = "auto", - target_host: str | Target | None = None, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | str | None = None, - num_workers: int = None, - ignore_error: bool = False) -> list[JITKernel[_KP, _T]]: +def par_compile( + funcs: Iterable[PrimFunc[_KP, _T]], + out_idx: list[int] | int | None = None, + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto", + target: str | Target = "auto", + target_host: str | Target | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | str | None = None, + num_workers: int = None, + ignore_error: bool = False, +) -> list[JITKernel[_KP, _T]]: """ Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels. Parameters @@ -151,7 +151,7 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]], Additional keyword arguments to pass to the Compiler PassContext. Refer to `tilelang.transform.PassConfigKey` for supported options. """ - with concurrent.futures.ThreadPoolExecutor(num_workers, 'tl-par-comp') as executor: + with concurrent.futures.ThreadPoolExecutor(num_workers, "tl-par-comp") as executor: futures = [] future_map = {} for i, func in enumerate(funcs): @@ -170,9 +170,9 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]], futures.append(future) results = [... for _ in futures] for future in tqdm( - concurrent.futures.as_completed(futures), - total=len(futures), - desc="Parallel Compiling", + concurrent.futures.as_completed(futures), + total=len(futures), + desc="Parallel Compiling", ): idx = future_map[future] if ignore_error: @@ -189,7 +189,7 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]], @dataclass class JITImpl(Generic[_P, _KP, _T, _Ret]): - ''' + """ Detailed Just-In-Time wrapper for TileLang programs. This dataclass encapsulates the configuration and runtime helpers used by the @@ -256,7 +256,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): PrimFunc and the resulting set is compiled in parallel via the module-level `par_compile` helper. Returns a list of JITKernel objects in the same order as the provided configs. - ''' + """ out_idx: list[int] | int | None execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] @@ -302,10 +302,9 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): assert isinstance(tir, PrimFunc), f"target function must be a PrimFunc but got {type(tir)}" return tir - def par_compile(self, - configs: Iterable[dict[str, Any] | tuple[str, Any]], - num_workers: int = None, - ignore_error: bool = False) -> list[JITKernel[_KP, _T]]: + def par_compile( + self, configs: Iterable[dict[str, Any] | tuple[str, Any]], num_workers: int = None, ignore_error: bool = False + ) -> list[JITKernel[_KP, _T]]: """ Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels. 
Parameters @@ -328,7 +327,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): """ configs = list(configs) funcs = [] - for cfg in tqdm(configs, desc='Elaborating'): + for cfg in tqdm(configs, desc="Elaborating"): if isinstance(cfg, tuple): funcs.append(self.get_tir(*cfg)) elif isinstance(cfg, dict): @@ -345,7 +344,8 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): pass_configs=self.pass_configs, compile_flags=self.compile_flags, num_workers=num_workers, - ignore_error=ignore_error) + ignore_error=ignore_error, + ) def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: func = self.get_tir(*args, **kwargs) @@ -362,25 +362,25 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): if self.debug_root_path: if isinstance(self.func, PrimFunc): - func_name = self.func.attrs['global_symbol'] + func_name = self.func.attrs["global_symbol"] else: - func_name = getattr(self.func, '__name__', 'jit_kernel') - kernel_file = f'tilelang_jit_kernel_{func_name}.c' - program_file = f'tilelang_jit_program_{func_name}.py' + func_name = getattr(self.func, "__name__", "jit_kernel") + kernel_file = f"tilelang_jit_kernel_{func_name}.c" + program_file = f"tilelang_jit_program_{func_name}.py" makedirs(self.debug_root_path, exist_ok=True) - with open(path.join(self.debug_root_path, kernel_file), 'w') as f: + with open(path.join(self.debug_root_path, kernel_file), "w") as f: print(kernel_result.get_kernel_source(), file=f) - with open(path.join(self.debug_root_path, program_file), 'w') as f: + with open(path.join(self.debug_root_path, program_file), "w") as f: print(func.script(), file=f) return kernel_result def parse_cache_key(self, *args: _P.args, **kwargs: _P.kwargs): if isinstance(self.func, PrimFuncCreater): - tune_params = kwargs.pop('__tune_params', {}) + tune_params = kwargs.pop("__tune_params", {}) return self.func.func_annot.parse_key(*args, **kwargs, **tune_params) else: - tune_params = kwargs.pop('__tune_params', {}) + tune_params = kwargs.pop("__tune_params", {}) key_args_tuple = args key_kwargs_tuple = tuple(sorted(kwargs.items())) tuned_key_kwargs_tuple = tuple(sorted(tune_params.items())) @@ -389,34 +389,31 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): def convert_kernel_args(self, *args: _P.args, **kwargs: _P.kwargs): if isinstance(self.func, PrimFuncCreater): - tune_params = kwargs.pop('__tune_params', {}) + tune_params = kwargs.pop("__tune_params", {}) return self.func.func_annot.convert_to_kernel_args(*args, **kwargs, **tune_params) else: - raise NotImplementedError( - "convert_arg_to_kernel_args is only implemented for PrimFuncCreater.") + raise NotImplementedError("convert_arg_to_kernel_args is only implemented for PrimFuncCreater.") def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: # Separate out the tuning parameters from the user's kwargs # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache - return_compile_arguments = kwargs.pop('__return_compile_arguments', False) + return_compile_arguments = kwargs.pop("__return_compile_arguments", False) if return_compile_arguments: - logger.warning( - "`__return_compile_arguments` is deprecated and will be removed in future versions." 
- ) + logger.warning("`__return_compile_arguments` is deprecated and will be removed in future versions.") compile_args = { - 'out_idx': self.out_idx, - 'execution_backend': self.execution_backend, - 'target': self.target, - 'target_host': self.target_host, - 'verbose': self.verbose, - 'pass_configs': self.pass_configs, - 'compile_flags': self.compile_flags, + "out_idx": self.out_idx, + "execution_backend": self.execution_backend, + "target": self.target, + "target_host": self.target_host, + "verbose": self.verbose, + "pass_configs": self.pass_configs, + "compile_flags": self.compile_flags, } return compile_args key = self.parse_cache_key(*args, **kwargs) - tune_params = kwargs.pop('__tune_params', {}) + tune_params = kwargs.pop("__tune_params", {}) kernel = self._kernel_cache.get(key, None) if kernel is None: @@ -434,8 +431,7 @@ ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvr @overload -def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]: - ... +def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]: ... @overload @@ -448,22 +444,22 @@ def jit( verbose: bool = False, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, - compile_flags: list[str] | str | None = None -) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]]: - ... + compile_flags: list[str] | str | None = None, +) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]]: ... def jit( # This is the new public interface - func: Callable[_P, _T] | PrimFunc | None = None, - *, # Indicates subsequent arguments are keyword-only - out_idx: Any = None, - target: str | Target = "auto", - target_host: str | Target = None, - execution_backend: ExecutionBackend = "auto", - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - debug_root_path: str | None = None, - compile_flags: list[str] | str | None = None): + func: Callable[_P, _T] | PrimFunc | None = None, + *, # Indicates subsequent arguments are keyword-only + out_idx: Any = None, + target: str | Target = "auto", + target_host: str | Target = None, + execution_backend: ExecutionBackend = "auto", + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + debug_root_path: str | None = None, + compile_flags: list[str] | str | None = None, +): """ Just-In-Time (JIT) compiler decorator for TileLang functions. @@ -516,7 +512,8 @@ def jit( # This is the new public interface compile_flags=compile_flags, func_source=inspect.getsource(orig_func), signature=inspect.signature(orig_func), - lazy_jit=False) + lazy_jit=False, + ) if func is not None: return decorator(func) @@ -525,8 +522,7 @@ def jit( # This is the new public interface @overload -def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]: - ... +def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]: ... @overload @@ -539,9 +535,8 @@ def lazy_jit( verbose: bool = False, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, - compile_flags: list[str] | str | None = None -) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]: - ... + compile_flags: list[str] | str | None = None, +) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]: ... 
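The decorator form of `jit` defined in this file wraps a PrimFunc-producing Python function in a `JITImpl`; a hedged sketch, where `make_kernel` and its integer parameters are illustrative only:

    from tilelang.jit import jit

    @jit(out_idx=-1, target="auto", execution_backend="auto")
    def make_kernel(M: int, N: int, K: int):
        ...  # must build and return a TileLang PrimFunc (JITImpl.get_tir asserts this)

    kernel = make_kernel(1024, 1024, 1024)  # compiled on first call, then cached via parse_cache_key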
def lazy_jit( @@ -555,7 +550,6 @@ def lazy_jit( debug_root_path: str | None = None, compile_flags: list[str] | str | None = None, ): - if isinstance(compile_flags, str): compile_flags = [compile_flags] @@ -567,7 +561,8 @@ def lazy_jit( verbose=verbose, pass_configs=pass_configs, debug_root_path=debug_root_path, - compile_flags=compile_flags) + compile_flags=compile_flags, + ) def decorator(func: Callable[_P, _T]): pf: PrimFunc[_P, _T] | PrimFuncCreater[_P, _T] = prim_func(func, generator=True) @@ -576,10 +571,7 @@ def lazy_jit( # return compile(pf, **compile_args) # else: return JITImpl( - func=pf, - **compile_args, - func_source=inspect.getsource(pf.orig_func), - signature=inspect.signature(pf.orig_func), - lazy_jit=True) + func=pf, **compile_args, func_source=inspect.getsource(pf.orig_func), signature=inspect.signature(pf.orig_func), lazy_jit=True + ) return decorator(func) if func is not None else decorator diff --git a/tilelang/jit/adapter/base.py b/tilelang/jit/adapter/base.py index 6bd69cff4134c0c8d1a652d9288ca84ec29957c6..3669f9e35c6f0d667d6389adf23a1caf0109b865 100644 --- a/tilelang/jit/adapter/base.py +++ b/tilelang/jit/adapter/base.py @@ -1,4 +1,5 @@ """The profiler and convert to torch utils""" + from __future__ import annotations from abc import ABC, abstractmethod @@ -8,7 +9,6 @@ import torch class BaseKernelAdapter(ABC): - func: Callable | None = None def __init__(self, mod, params: list[KernelParam], result_idx: list[int]) -> None: @@ -24,18 +24,14 @@ class BaseKernelAdapter(ABC): result_idx = [] elif isinstance(result_idx, int): if result_idx > len(params) or result_idx < -len(params): - raise ValueError( - f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}" - ) + raise ValueError(f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}") if result_idx < 0: result_idx = len(params) + result_idx result_idx = [result_idx] elif isinstance(result_idx, list): for i, idx in enumerate(result_idx): if idx >= len(params) or idx < -len(params): - raise ValueError( - f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}" - ) + raise ValueError(f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}") if idx < 0: result_idx[i] = len(params) + idx else: diff --git a/tilelang/jit/adapter/ctypes/adapter.py b/tilelang/jit/adapter/ctypes/adapter.py index e267730582de65db6480f64dbd82742dd2998c28..92af8262e0c611e4532a6bebf05377970999ac28 100644 --- a/tilelang/jit/adapter/ctypes/adapter.py +++ b/tilelang/jit/adapter/ctypes/adapter.py @@ -1,4 +1,5 @@ """The profiler and convert to torch utils""" + from __future__ import annotations import torch from ..base import BaseKernelAdapter @@ -41,18 +42,20 @@ class CtypesKernelAdapter(BaseKernelAdapter): param_dtypes: list[torch.dtype] | None = None # Cache for parameter dtypes param_shapes: list[list] | None = None # Cache for parameter shapes - def __init__(self, - params: list[TensorType], - result_idx: list[int], - target: str, - func_or_mod: tir.PrimFunc | tvm.IRModule, - host_mod: tvm.IRModule | None = None, - device_mod: tvm.IRModule | None = None, - host_kernel_source: str | None = None, - device_kernel_source: str | None = None, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None): + def __init__( + self, + params: list[TensorType], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: 
tvm.IRModule | None = None, + host_kernel_source: str | None = None, + device_kernel_source: str | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): """Initialize the adapter with the given TIR function or module. Args: @@ -109,17 +112,19 @@ class CtypesKernelAdapter(BaseKernelAdapter): self._post_init() @classmethod - def from_database(cls, - params: list[TensorType], - result_idx: list[int], - target: str, - func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str, - device_kernel_source: str, - kernel_lib_path: str, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None): + def from_database( + cls, + params: list[TensorType], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) @@ -175,15 +180,13 @@ class CtypesKernelAdapter(BaseKernelAdapter): if param in buffer_map: buffer = buffer_map[param] for j, shape in enumerate(buffer.shape): - if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and - (shape not in params)): + if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params): dynamic_symbolic_map[shape] = (0, i, j) for i, param in enumerate(params): if param in buffer_map: buffer = buffer_map[param] for j, stride in enumerate(buffer.strides): - if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and - (stride not in params)): + if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params): dynamic_symbolic_map[stride] = (1, i, j) return dynamic_symbolic_map @@ -192,9 +195,7 @@ class CtypesKernelAdapter(BaseKernelAdapter): Converts PyTorch tensor pointers to C void pointers for ctypes interface. 
""" - ctypes_args = [ - ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args - ] + ctypes_args = [ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args] ctypes_args.append(ctypes.c_void_p(stream)) self.lib.call(*ctypes_args) @@ -288,7 +289,7 @@ class CtypesKernelAdapter(BaseKernelAdapter): @property def is_dynamic(self): """Indicates whether the kernel handles dynamic shapes.""" - return (self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0) + return self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0 def get_kernel_source(self, kernel_only: bool = False): """Returns the source code of the compiled kernel.""" diff --git a/tilelang/jit/adapter/cython/adapter.py b/tilelang/jit/adapter/cython/adapter.py index fe8fe5bd91c776deca8b718c0de0b4964df87f24..c456e4dbaa38f6bcd341198f27ca56364ae1fb68 100644 --- a/tilelang/jit/adapter/cython/adapter.py +++ b/tilelang/jit/adapter/cython/adapter.py @@ -1,4 +1,5 @@ """The profiler and convert to torch utils""" + from __future__ import annotations import ctypes import logging @@ -70,17 +71,19 @@ class CythonKernelAdapter(BaseKernelAdapter): # Pass configs for the compiler pass_configs: dict[str, Any] | None = None - def __init__(self, - params: list[KernelParam], - result_idx: list[int], - target: str | Target, - func_or_mod: tir.PrimFunc | tvm.IRModule, - host_mod: tvm.IRModule | None = None, - device_mod: tvm.IRModule | None = None, - device_kernel_source: str | None = None, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None): + def __init__( + self, + params: list[KernelParam], + result_idx: list[int], + target: str | Target, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: tvm.IRModule | None = None, + device_kernel_source: str | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): """Initialize the adapter with the given TIR function or module. 
Args: @@ -130,7 +133,7 @@ class CythonKernelAdapter(BaseKernelAdapter): self.lib.get_last_error.restype = ctypes.c_char_p result = self.lib.init() if result != 0: - error_msg = self.lib.get_last_error().decode('utf-8') + error_msg = self.lib.get_last_error().decode("utf-8") error_msg += f"\n{self.lib_code}" raise RuntimeError(f"Initialization failed: {error_msg}") @@ -145,17 +148,19 @@ class CythonKernelAdapter(BaseKernelAdapter): self._post_init() @classmethod - def from_database(cls, - params: list[TensorType], - result_idx: list[int], - target: str, - func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str, - device_kernel_source: str, - kernel_lib_path: str, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None): + def from_database( + cls, + params: list[TensorType], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) @@ -190,11 +195,10 @@ class CythonKernelAdapter(BaseKernelAdapter): adapter.lib.get_last_error.restype = ctypes.c_char_p result = adapter.lib.init() if result != 0: - error_msg = adapter.lib.get_last_error().decode('utf-8') + error_msg = adapter.lib.get_last_error().decode("utf-8") raise RuntimeError(f"Initialization failed: {error_msg}") - adapter.cython_wrapper = CythonKernelWrapper(adapter.result_idx, adapter.params, - adapter.lib) + adapter.cython_wrapper = CythonKernelWrapper(adapter.result_idx, adapter.params, adapter.lib) adapter.cython_wrapper.set_dynamic_symbolic_map(adapter.dynamic_symbolic_map) adapter.cython_wrapper.set_buffer_dtype_map(adapter.buffer_dtype_map) adapter.cython_wrapper.set_static_shape_map(adapter.static_shape_map) @@ -221,15 +225,13 @@ class CythonKernelAdapter(BaseKernelAdapter): if param in buffer_map: buffer = buffer_map[param] for j, shape in enumerate(buffer.shape): - if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and - (shape not in params)): + if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params): dynamic_symbolic_map[shape] = (0, i, j) for i, param in enumerate(params): if param in buffer_map: buffer = buffer_map[param] for j, stride in enumerate(buffer.strides): - if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and - (stride not in params)): + if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params): dynamic_symbolic_map[stride] = (1, i, j) return dynamic_symbolic_map @@ -259,14 +261,13 @@ class CythonKernelAdapter(BaseKernelAdapter): params = func.params ptr_map = {} for i, param in enumerate(params): - if param.dtype == 'handle': + if param.dtype == "handle": ptr_map[i] = param.name return ptr_map - def _process_static_buffer_infos(self) -> \ - tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]], - dict[tir.Var, tuple[int, list[tuple[int, int]]]], - list[tuple[tir.Var]]]: + def _process_static_buffer_infos( + self, + ) -> tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]], dict[tir.Var, tuple[int, list[tuple[int, int]]]], list[tuple[tir.Var]]]: """Extract information about static shapes from the TIR function. Maps buffer variables to their corresponding static shapes. 
@@ -332,9 +333,7 @@ class CythonKernelAdapter(BaseKernelAdapter): Converts PyTorch tensor pointers to C void pointers for ctypes interface. """ - ctypes_args = [ - ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args - ] + ctypes_args = [ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args] ctypes_args.append(ctypes.c_void_p(stream)) self.lib.call(*ctypes_args) @@ -349,9 +348,7 @@ class CythonKernelAdapter(BaseKernelAdapter): skip_tensor_validation: Whether to skip tensor attributes validation which includes shape, dtype, device, etc. """ - return self.cython_wrapper.forward([*args], - stream=stream, - skip_tensor_validation=skip_tensor_validation) + return self.cython_wrapper.forward([*args], stream=stream, skip_tensor_validation=skip_tensor_validation) return lambda_forward diff --git a/tilelang/jit/adapter/libgen.py b/tilelang/jit/adapter/libgen.py index 208370b05174c4141eb43e477167db8dc2d9e8b4..d67f5b403ee010f3c8b52b5d3c5c802f89e88c3a 100644 --- a/tilelang/jit/adapter/libgen.py +++ b/tilelang/jit/adapter/libgen.py @@ -55,6 +55,7 @@ class LibraryGenerator: verbose = self.verbose if is_cuda_target(target): from tilelang.env import CUTLASS_INCLUDE_DIR + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) # noqa: SIM115 target_arch = get_target_arch(get_target_compute_version(target)) libpath = src.name.replace(".cu", ".so") @@ -65,15 +66,12 @@ class LibraryGenerator: "TL_ENABLE_FAST_MATH", "0.1.7", ) - enable_fast_math = not self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, - True) + enable_fast_math = not self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, True) else: enable_fast_math = self.pass_configs.get(PassConfigKey.TL_ENABLE_FAST_MATH, False) - ptxas_usage_level = self.pass_configs.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL, - None) - verbose_ptxas_output = self.pass_configs.get( - PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False) + ptxas_usage_level = self.pass_configs.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL, None) + verbose_ptxas_output = self.pass_configs.get(PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False) command = [ get_nvcc_compiler(), @@ -102,6 +100,7 @@ class LibraryGenerator: elif is_hip_target(target): from tilelang.env import COMPOSABLE_KERNEL_INCLUDE_DIR + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False) # noqa: SIM115 libpath = src.name.replace(".cpp", ".so") rocm_path = find_rocm_path() @@ -119,6 +118,7 @@ class LibraryGenerator: ] elif is_cpu_target(target): from tilelang.contrib.cc import get_cplus_compiler + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False) # noqa: SIM115 libpath = src.name.replace(".cpp", ".so") @@ -134,9 +134,7 @@ class LibraryGenerator: ] if self.compile_flags: - command += [ - item for flag in self.compile_flags for item in flag.split() if item not in command - ] + command += [item for flag in self.compile_flags for item in flag.split() if item not in command] command += ["-o", libpath] @@ -151,8 +149,7 @@ class LibraryGenerator: raise RuntimeError(f"Compile kernel failed because of {e}") from e if ret.returncode != 0: - raise RuntimeError(f"Compilation Failed! {command}" - f"\n {self.lib_code}") + raise RuntimeError(f"Compilation Failed! 
{command}\n {self.lib_code}") self.srcpath = src.name self.libpath = libpath diff --git a/tilelang/jit/adapter/nvrtc/__init__.py b/tilelang/jit/adapter/nvrtc/__init__.py index faa08c1940c4f60cbf1d32f2d0a439aac8f39458..c8abe8d7789faff0ebcebb2e306aaaba436c4a5f 100644 --- a/tilelang/jit/adapter/nvrtc/__init__.py +++ b/tilelang/jit/adapter/nvrtc/__init__.py @@ -5,22 +5,22 @@ This module provides runtime compilation support using NVIDIA's NVRTC API. import logging -__all__ = [ - 'NVRTCKernelAdapter', 'TLNVRTCSourceWrapper', 'NVRTCLibraryGenerator', 'is_nvrtc_available', - 'check_nvrtc_available' -] +__all__ = ["NVRTCKernelAdapter", "TLNVRTCSourceWrapper", "NVRTCLibraryGenerator", "is_nvrtc_available", "check_nvrtc_available"] logger = logging.getLogger(__name__) # Check if cuda-python is available is_nvrtc_available = False -NVRTC_UNAVAILABLE_MESSAGE = ("cuda-python is not available, NVRTC backend cannot be used. " - "Please install cuda-python via `pip install cuda-python` " - "if you want to use the NVRTC backend.") +NVRTC_UNAVAILABLE_MESSAGE = ( + "cuda-python is not available, NVRTC backend cannot be used. " + "Please install cuda-python via `pip install cuda-python` " + "if you want to use the NVRTC backend." +) try: import cuda.bindings.driver as cuda # noqa: F401 import cuda.bindings.nvrtc as nvrtc # noqa: F401 + is_nvrtc_available = True except ImportError as e: logger.debug(f"cuda-python import failed: {e}") diff --git a/tilelang/jit/adapter/nvrtc/adapter.py b/tilelang/jit/adapter/nvrtc/adapter.py index 4a465d33bfc7250dcff396c51fff8356329f64b1..d222f33a53205a3bf32e070954cd22b93a88e31f 100644 --- a/tilelang/jit/adapter/nvrtc/adapter.py +++ b/tilelang/jit/adapter/nvrtc/adapter.py @@ -27,18 +27,19 @@ class NVRTCKernelAdapter(BaseKernelAdapter): pymodule = None kernels = {} - def __init__(self, - params: list[KernelParam], - result_idx: list[int], - target: str | Target, - func_or_mod: tir.PrimFunc | tvm.IRModule, - host_mod: tvm.IRModule | None = None, - device_mod: tvm.IRModule | None = None, - device_kernel_source: str | None = None, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None): - + def __init__( + self, + params: list[KernelParam], + result_idx: list[int], + target: str | Target, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: tvm.IRModule | None = None, + device_kernel_source: str | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): check_nvrtc_available() self.params = params @@ -92,17 +93,19 @@ class NVRTCKernelAdapter(BaseKernelAdapter): self._post_init() @classmethod - def from_database(cls, - params: list[KernelParam], - result_idx: list[int], - target: str, - func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str, - device_kernel_source: str, - kernel_lib_path: str, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None): + def from_database( + cls, + params: list[KernelParam], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) @@ -183,8 +186,7 @@ class 
NVRTCKernelAdapter(BaseKernelAdapter): return self.host_func def _forward_from_prebuild_lib(self, *args, stream: int | None = None): - """Low-level function to call the compiled CUDA kernel. - """ + """Low-level function to call the compiled CUDA kernel.""" return self.pymodule.call(self.kernels, *args, stream=stream) def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None): diff --git a/tilelang/jit/adapter/nvrtc/libgen.py b/tilelang/jit/adapter/nvrtc/libgen.py index 50a587a527ff6b4603a436dcd7c1f3c7b27a0f06..406cc44d97f8eee513c8c427961abf908289a665 100644 --- a/tilelang/jit/adapter/nvrtc/libgen.py +++ b/tilelang/jit/adapter/nvrtc/libgen.py @@ -13,6 +13,7 @@ Key responsibilities: - Load compiled cubin and extract kernel handles - Manage library lifecycle (load/unload) """ + from __future__ import annotations import importlib import logging @@ -56,6 +57,7 @@ class NVRTCLibraryGenerator(LibraryGenerator): culib: CUDA library handle (CUlibrary) pymodule: Imported Python module containing call() function """ + host_func: str | None = None culib: cuda.CUlibrary | None = None pymodule: ModuleType | None = None @@ -131,10 +133,10 @@ class NVRTCLibraryGenerator(LibraryGenerator): ctx = cuda.cuCtxGetCurrent()[1] if cuda.cuCtxGetApiVersion(ctx)[0] != cuda.CUresult.CUDA_SUCCESS: import torch + torch.cuda.synchronize() - result, self.culib = cuda.cuLibraryLoadFromFile( - bytes(lib_path, "utf-8"), [], [], 0, [], [], 0) + result, self.culib = cuda.cuLibraryLoadFromFile(bytes(lib_path, "utf-8"), [], [], 0, [], [], 0) if result != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError(f"Failed to load library: {lib_path}, error: {result}") @@ -164,7 +166,8 @@ class NVRTCLibraryGenerator(LibraryGenerator): target = self.target verbose = self.verbose if is_cuda_target(target): - from tilelang.env import (CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH) + from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) libpath = src.name.replace(".cu", ".cubin") @@ -195,13 +198,9 @@ class NVRTCLibraryGenerator(LibraryGenerator): f"-D__CUDACC_VER_MAJOR__={__CUDACC_VER_MAJOR__}", ] if self.compile_flags: - options += [ - item for flag in self.compile_flags for item in flag.split() - if item not in options - ] + options += [item for flag in self.compile_flags for item in flag.split() if item not in options] - cubin_bytes = compile_cuda( - self.lib_code, target_format="cubin", options=options, verbose=verbose) + cubin_bytes = compile_cuda(self.lib_code, target_format="cubin", options=options, verbose=verbose) with open(libpath, "wb") as f: f.write(cubin_bytes) @@ -212,8 +211,7 @@ class NVRTCLibraryGenerator(LibraryGenerator): self.libpath = libpath self.pypath = src.name.replace(".cu", ".py") if self.host_func is None: - raise RuntimeError( - "Host function is not set, please call update_host_func() first.") + raise RuntimeError("Host function is not set, please call update_host_func() first.") with open(self.pypath, "w") as f: f.write(self.host_func) else: diff --git a/tilelang/jit/adapter/nvrtc/wrapper.py b/tilelang/jit/adapter/nvrtc/wrapper.py index 7e00050c7e85b50c2f6dd54d5b661d21f5cb6642..3df2b3bfa54c507ab3f8d432c9617bd19cc02288 100644 --- a/tilelang/jit/adapter/nvrtc/wrapper.py +++ b/tilelang/jit/adapter/nvrtc/wrapper.py @@ -12,6 +12,7 @@ Key design: - Dict-based deduplication ensures TMA descriptors created only once - Generates pure Python using cuda.bindings.driver for zero C++ 
dependency """ + from __future__ import annotations from typing import Any, ClassVar @@ -21,8 +22,7 @@ from tvm.tir.stmt_functor import post_order_visit from tilelang import tvm as tvm from tilelang.jit.adapter.wrapper import TLCUDASourceWrapper -from tilelang.jit.adapter.utils import (match_declare_kernel, pythonic_expr, - parse_function_call_args, parse_tma_descriptor_args) +from tilelang.jit.adapter.utils import match_declare_kernel, pythonic_expr, parse_function_call_args, parse_tma_descriptor_args PREDEF_HOST_FUNC_PY = """ from cuda.bindings.driver import ( @@ -235,13 +235,15 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): _generated_host_func: str | None = None - def __init__(self, - scheduled_ir_module: IRModule, - source: str, - target: Target, - device_mod: IRModule | None = None, - host_mod: IRModule | None = None, - pass_configs: dict[str, Any] | None = None): + def __init__( + self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None, + ): """Initialize NVRTC wrapper with compiled IR modules. Args: @@ -303,15 +305,16 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): for param in self.prim_func.params: if param in self.prim_func.buffer_map: buffer = self.prim_func.buffer_map[param] - function_args.append({ - "name": buffer.data.name, - "type": "ctypes.c_void_p", - }) + function_args.append( + { + "name": buffer.data.name, + "type": "ctypes.c_void_p", + } + ) elif isinstance(param, tvm.tir.Var): function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)}) else: - raise ValueError( - f"Parameter {param} is not in the buffer map of the primary function.") + raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.") # Add dynamic symbols as integer arguments for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set: if dyn_sym not in [arg["name"] for arg in function_args]: @@ -359,9 +362,9 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): return (f"{name}.data_ptr()", arg_type) return (name, arg_type) - call_args = parse_function_call_args(declaration, function_args, function_params, - desc_name_map, desc_name_var_map, - transform_nvrtc_arg) + call_args = parse_function_call_args( + declaration, function_args, function_params, desc_name_map, desc_name_var_map, transform_nvrtc_arg + ) for arg_name, arg_type in call_args: if arg_type == "ctypes.c_void_p": @@ -369,26 +372,28 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): break # Store kernel info for second pass - kernel_info_list.append({ - 'function_name': function_name, - 'block_info': block_info, - 'grid_info': grid_info, - 'dynamic_smem_buf': dynamic_smem_buf, - 'call_args': call_args, - 'device_index': device_index, - }) + kernel_info_list.append( + { + "function_name": function_name, + "block_info": block_info, + "grid_info": grid_info, + "dynamic_smem_buf": dynamic_smem_buf, + "call_args": call_args, + "device_index": device_index, + } + ) # Generate TMA descriptor initialization code once for all kernels kernel_launch_code += self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map) # Second pass: generate kernel launch code for each kernel for kernel_info in kernel_info_list: - function_name = kernel_info['function_name'] - block_info = kernel_info['block_info'] - grid_info = kernel_info['grid_info'] - dynamic_smem_buf = kernel_info['dynamic_smem_buf'] - call_args = kernel_info['call_args'] - device_index = 
kernel_info['device_index'] + function_name = kernel_info["function_name"] + block_info = kernel_info["block_info"] + grid_info = kernel_info["grid_info"] + dynamic_smem_buf = kernel_info["dynamic_smem_buf"] + call_args = kernel_info["call_args"] + device_index = kernel_info["device_index"] arg_names = ", ".join([arg[0] for arg in call_args]) arg_types = ", ".join([arg[1] for arg in call_args]) @@ -399,23 +404,26 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): kernel_launch_code += init_l2_persistent_map # Generate kernel launch code - kernel_launch_code += KERNEL_LAUNCH_FUNC_PY.format(function_name, - self._pythonic_expr(grid_info[0]), - self._pythonic_expr(grid_info[1]), - self._pythonic_expr(grid_info[2]), - self._pythonic_expr(block_info[0]), - self._pythonic_expr(block_info[1]), - self._pythonic_expr(block_info[2]), - smem_str, arg_names, arg_types, - device_index) + kernel_launch_code += KERNEL_LAUNCH_FUNC_PY.format( + function_name, + self._pythonic_expr(grid_info[0]), + self._pythonic_expr(grid_info[1]), + self._pythonic_expr(grid_info[2]), + self._pythonic_expr(block_info[0]), + self._pythonic_expr(block_info[1]), + self._pythonic_expr(block_info[2]), + smem_str, + arg_names, + arg_types, + device_index, + ) # Reset L2 persistent map after all kernel execution if has_l2_persistent_map: kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE_PY # Wrap the kernel dispatch logic in an external C function - host_func = PREDEF_HOST_FUNC_PY.format( - repr(list(function_informations.keys())), def_args, kernel_launch_code) + host_func = PREDEF_HOST_FUNC_PY.format(repr(list(function_informations.keys())), def_args, kernel_launch_code) return host_func def generate_l2_persistent_map(self, function_name: str) -> str: @@ -434,23 +442,21 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): if function_name not in self.l2_persistent_map: return "" init_l2_persistent_map = "" - for buffer_name, (hit_ratio, - size_in_bytes) in self.l2_persistent_map[function_name].items(): + for buffer_name, (hit_ratio, size_in_bytes) in self.l2_persistent_map[function_name].items(): # Get persisting_l2_cache_max_size from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size + persisting_l2_cache_max_size = get_persisting_l2_cache_max_size() try: num_bytes = min(size_in_bytes, persisting_l2_cache_max_size) except TypeError: # as size_in_bytes may be a symbolic expression num_bytes = persisting_l2_cache_max_size - init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC_PY.format( - buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes)) + init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC_PY.format(buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes)) return init_l2_persistent_map - def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], - desc_name_var_map: dict[str, tvm.tir.Var]) -> str: + def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], desc_name_var_map: dict[str, tvm.tir.Var]) -> str: """Generate Python code to initialize TMA descriptors. 
TMA (Tensor Memory Accelerator) descriptors are opaque CUDA objects @@ -470,28 +476,43 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): return tma_descriptor_init # Parse TMA descriptor arguments using the common utility - parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, - desc_name_var_map, self._pythonic_expr) + parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, desc_name_var_map, self._pythonic_expr) # Generate Python code from parsed parameters for params in parsed_params: if not params.is_img2col: tma_descriptor_init += TMA_DESC_INIT_FUNC_PY.format( - params.handle_name, params.dtype, params.tensor_rank, params.global_address, + params.handle_name, + params.dtype, + params.tensor_rank, + params.global_address, ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)), ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)), ", ".join(map(lambda x: f"cuuint32_t({x})", params.box_dim)), ", ".join(map(lambda x: f"cuuint32_t({x})", params.element_strides)), - params.interleave, params.swizzle, params.l2_promotion, params.oob_fill) + params.interleave, + params.swizzle, + params.l2_promotion, + params.oob_fill, + ) else: tma_descriptor_init += TMA_IM2COL_DESC_INIT_FUNC_PY.format( - params.handle_name, params.dtype, params.tensor_rank, params.global_address, + params.handle_name, + params.dtype, + params.tensor_rank, + params.global_address, ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)), ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)), - ", ".join(map(lambda x: f"cuuint32_t({x})", - params.element_strides)), ", ".join(params.lower_corner), - ", ".join(params.upper_corner), params.smem_box_channel, params.smem_box_pixel, - params.interleave, params.swizzle, params.l2_promotion, params.oob_fill) + ", ".join(map(lambda x: f"cuuint32_t({x})", params.element_strides)), + ", ".join(params.lower_corner), + ", ".join(params.upper_corner), + params.smem_box_channel, + params.smem_box_pixel, + params.interleave, + params.swizzle, + params.l2_promotion, + params.oob_fill, + ) return tma_descriptor_init @@ -527,17 +548,14 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): def visitor(node, fn=function_name, param_cnt=kernel_params_cnt): nonlocal function_params if isinstance(node, tvm.tir.Call): - if not (hasattr(node, "op") and - node.op == tvm.ir.Op.get("tir.tvm_call_packed")): + if not (hasattr(node, "op") and node.op == tvm.ir.Op.get("tir.tvm_call_packed")): return args = node.args if not args or args[0] != fn: return if len(args) < 1 + param_cnt: - raise AssertionError( - "tvm_call_packed should have at least 1 argument and match device function parameters" - ) - function_params = args[1:1 + param_cnt] + raise AssertionError("tvm_call_packed should have at least 1 argument and match device function parameters") + function_params = args[1 : 1 + param_cnt] post_order_visit(self.host_func.body, visitor) assert function_params is not None, "function_params should not be None" diff --git a/tilelang/jit/adapter/torch/__init__.py b/tilelang/jit/adapter/torch/__init__.py index 2390e3e7c210cdc0ee92307524f168cdf2de36ba..f688993d0e563f1b14a7c4ecbe91ce4a9bd344b3 100644 --- a/tilelang/jit/adapter/torch/__init__.py +++ b/tilelang/jit/adapter/torch/__init__.py @@ -1,3 +1,3 @@ from .metal import MetalKernelAdapter -__all__ = ['MetalKernelAdapter'] +__all__ = ["MetalKernelAdapter"] diff --git a/tilelang/jit/adapter/torch/metal.py b/tilelang/jit/adapter/torch/metal.py index 
0b1bc009813ea61eab5fa904412f17ce7291ede5..4690cf59bda7cd1907c130d5a8d1446466e097b5 100644 --- a/tilelang/jit/adapter/torch/metal.py +++ b/tilelang/jit/adapter/torch/metal.py @@ -12,7 +12,6 @@ from tilelang.engine.param import KernelParam class MetalKernelAdapter(BaseKernelAdapter): - def __init__( self, params: list[KernelParam], @@ -28,10 +27,10 @@ class MetalKernelAdapter(BaseKernelAdapter): ): self.kernel_global_source = kernel_global_source if isinstance(func_or_mod, tir.PrimFunc): - func_name = func_or_mod.attrs['global_symbol'] + func_name = func_or_mod.attrs["global_symbol"] else: func_name = func_or_mod.__name__ - self.kernel_name = func_name + '_kernel' + self.kernel_name = func_name + "_kernel" self.verbose = verbose self.block_info = [1, 1, 1] @@ -39,7 +38,7 @@ class MetalKernelAdapter(BaseKernelAdapter): for var, func in device_mod.functions.items(): assert var.name_hint == self.kernel_name - thread_extent = func.attrs['thread_extent'] + thread_extent = func.attrs["thread_extent"] for tag, extent in thread_extent.items(): if "threadIdx" in tag: self.block_info["xyz".index(tag[-1])] = extent @@ -47,7 +46,7 @@ class MetalKernelAdapter(BaseKernelAdapter): self.grid_info["xyz".index(tag[-1])] = extent break else: - raise AssertionError(f'no kernel with name {func_name}') + raise AssertionError(f"no kernel with name {func_name}") # print(self.block_info, self.grid_info) super().__init__(func_or_mod, result_idx=result_idx, params=params) @@ -55,15 +54,12 @@ class MetalKernelAdapter(BaseKernelAdapter): _kernel = None def _convert_torch_func(self) -> Callable: - if self._kernel is None: - _kernel = getattr(torch.mps.compile_shader(self.kernel_global_source), self.kernel_name) _threads = [x * y for (x, y) in zip(self.block_info, self.grid_info)] @wraps(_kernel) def launcher(*args: torch.Tensor): - return _kernel( *args, threads=_threads, diff --git a/tilelang/jit/adapter/tvm_ffi.py b/tilelang/jit/adapter/tvm_ffi.py index 96b4c85e98592d177f8580c5b549bb68df9f6dd5..8b868645d60235656dea3a3ed896a501c80f4984 100644 --- a/tilelang/jit/adapter/tvm_ffi.py +++ b/tilelang/jit/adapter/tvm_ffi.py @@ -5,6 +5,7 @@ via light-weight callables so that, when the wrapped function is invoked, the execution observes the same stream context as the active Torch code. On non-CUDA builds, the stream/device fall back to 0/CPU semantics. """ + from __future__ import annotations from typing import Callable, Any @@ -31,6 +32,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): - The stream pointer returned is a raw CUDA stream handle compatible with TVM's device API; on CPU or when CUDA is unavailable, we return 0. 
""" + # Class attributes to store compiled kernel information target: str | Target = "cuda" ir_module: tvm.IRModule | None = None @@ -51,19 +53,21 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): dynamic_symbolic_map: dict[tir.Var, tuple[int, int, int]] | None = None # Stream/device functors are inherited from BaseKernelAdapter - def __init__(self, - params: list[KernelParam], - result_idx: list[int], - target: str | Target, - func_or_mod: tir.PrimFunc | tvm.IRModule, - host_mod: tvm.IRModule | None = None, - device_mod: tvm.IRModule | None = None, - rt_mod: tvm.runtime.Module | None = None, - host_kernel_source: str | None = None, - device_kernel_source: str | None = None, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None): + def __init__( + self, + params: list[KernelParam], + result_idx: list[int], + target: str | Target, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: tvm.IRModule | None = None, + rt_mod: tvm.runtime.Module | None = None, + host_kernel_source: str | None = None, + device_kernel_source: str | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): """Initialize the adapter with the given TIR function or module. Args: @@ -113,15 +117,13 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): if param in buffer_map: buffer = buffer_map[param] for j, shape in enumerate(buffer.shape): - if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and - (shape not in params)): + if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params): dynamic_symbolic_map[shape] = (0, i, j) for i, param in enumerate(params): if param in buffer_map: buffer = buffer_map[param] for j, stride in enumerate(buffer.strides): - if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and - (stride not in params)): + if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params): dynamic_symbolic_map[stride] = (1, i, j) return dynamic_symbolic_map @@ -197,8 +199,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): # Validate input count strictly expected_inputs = len(self.params) - len(self.result_idx) if len(inputs) != expected_inputs: - raise ValueError( - f"Kernel expected {expected_inputs} inputs, but {len(inputs)} are provided.") + raise ValueError(f"Kernel expected {expected_inputs} inputs, but {len(inputs)} are provided.") # Resolve the device used for outputs. Prefer the first tensor input's device # if available, otherwise use PyTorch's current device. 
@@ -217,17 +218,14 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): for s in param_shapes[i]: if isinstance(s, tir.Var): for key in dynamic_symbolic_map: - if (str(s) == str(key)): - ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[ - key] + if str(s) == str(key): + ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[key] if ref_id == 2: shape.append(inputs[ref_tensor_idx]) elif ref_id == 0: - shape.append( - tensor_list[ref_tensor_idx].shape[ref_shape_idx]) + shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx]) elif ref_id == 1: - shape.append( - tensor_list[ref_tensor_idx].stride()[ref_shape_idx]) + shape.append(tensor_list[ref_tensor_idx].stride()[ref_shape_idx]) else: # Already converted to Python int during initialization shape.append(s) @@ -235,11 +233,11 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): out_device = current_device_functor() if len(shape) == 0: - param_name = self.params[i].name if hasattr(self.params[i], - 'name') else f'parameter_{i}' + param_name = self.params[i].name if hasattr(self.params[i], "name") else f"parameter_{i}" raise ValueError( f"Cannot create output tensor (name={param_name}) - 0-dimensional tensors are not supported. " - f"Expected shape: {shape}") + f"Expected shape: {shape}" + ) tensor = torch.empty(*shape, dtype=dtype, device=out_device) else: tensor = inputs[ins_idx] @@ -256,17 +254,19 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): return func @classmethod - def from_database(cls, - params: list[TensorType], - result_idx: list[int], - target: str, - func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str, - device_kernel_source: str, - kernel_lib_path: str, - verbose: bool = False, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None): + def from_database( + cls, + params: list[TensorType], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) diff --git a/tilelang/jit/adapter/utils.py b/tilelang/jit/adapter/utils.py index 94e590d3fc2e1b71ff310410eee7fb7cbb0fb9c3..15801ffa751bce51a6c10591c8e358c87e2ebc53 100644 --- a/tilelang/jit/adapter/utils.py +++ b/tilelang/jit/adapter/utils.py @@ -70,7 +70,6 @@ def get_annotated_mod( target_host: str | Target | None = None, model_type: Literal["device", "host", "all"] = "all", ) -> IRModule | tuple[IRModule, IRModule]: - # Validate model_type early if model_type not in {"device", "host", "all"}: raise ValueError(f"Invalid model type: {model_type}") @@ -95,21 +94,15 @@ def get_annotated_mod( # Define dispatch dictionary for different model types dispatch = { - "device": - lambda m: tir.transform.Filter(_is_device_call)(m), - "host": - lambda m: tir.transform.Filter(_is_host_call)(m), - "all": - lambda m: (tir.transform.Filter(_is_device_call)(m), tir.transform.Filter(_is_host_call) - (m)), + "device": lambda m: tir.transform.Filter(_is_device_call)(m), + "host": lambda m: tir.transform.Filter(_is_host_call)(m), + "all": lambda m: (tir.transform.Filter(_is_device_call)(m), tir.transform.Filter(_is_host_call)(m)), } return dispatch[model_type](mod) -def pythonic_expr(expr: tvm.tir.PrimExpr, - dtype_map: dict[str, str] | None = None, - ignore_cast: bool = False) -> str: +def pythonic_expr(expr: 
tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None, ignore_cast: bool = False) -> str: """ Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence. @@ -168,9 +161,23 @@ def pythonic_expr(expr: tvm.tir.PrimExpr, s = f"({type_str}){value_str}" p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE) elif isinstance( - node, - (tvm.tir.Mul, tvm.tir.FloorDiv, tvm.tir.Add, tvm.tir.Sub, tvm.tir.FloorMod, tvm.tir.LT, - tvm.tir.LE, tvm.tir.GT, tvm.tir.GE, tvm.tir.EQ, tvm.tir.NE, tvm.tir.And, tvm.tir.Or)): + node, + ( + tvm.tir.Mul, + tvm.tir.FloorDiv, + tvm.tir.Add, + tvm.tir.Sub, + tvm.tir.FloorMod, + tvm.tir.LT, + tvm.tir.LE, + tvm.tir.GT, + tvm.tir.GE, + tvm.tir.EQ, + tvm.tir.NE, + tvm.tir.And, + tvm.tir.Or, + ), + ): op_map = { tvm.tir.Mul: "*", tvm.tir.FloorDiv: "/", @@ -222,10 +229,7 @@ def pythonic_expr(expr: tvm.tir.PrimExpr, return next(iter(node_to_result_map[expr]), "") -def maybe_desc_name(name: str, - matches: list[str], - i: int, - desc_name_map: dict[str, str] | None = None) -> bool: +def maybe_desc_name(name: str, matches: list[str], i: int, desc_name_map: dict[str, str] | None = None) -> bool: """ Check if a parameter name corresponds to a TMA descriptor. @@ -290,8 +294,7 @@ def parse_function_call_args( else: call_args.append(match) if desc_name_var_map is not None and function_params is not None: - assert len(call_args) <= len(function_params), \ - f"Too many arguments: {len(call_args)} > {len(function_params)}" + assert len(call_args) <= len(function_params), f"Too many arguments: {len(call_args)} > {len(function_params)}" desc_name_var_map[match] = function_params[len(call_args) - 1] return call_args @@ -300,12 +303,7 @@ def parse_function_call_args( class TMADescriptorParams: """Parsed TMA descriptor parameters.""" - def __init__(self, - handle_name: str, - dtype: str, - tensor_rank: int, - global_address: Any, - is_img2col: bool = False): + def __init__(self, handle_name: str, dtype: str, tensor_rank: int, global_address: Any, is_img2col: bool = False): self.handle_name = handle_name self.dtype = dtype self.tensor_rank = tensor_rank @@ -355,22 +353,19 @@ def parse_tma_descriptor_args( results = [] for handle_name, _ in desc_name_map.items(): - assert handle_name in desc_name_var_map, \ - f"Handle name {handle_name} not found in desc_name_var_map" + assert handle_name in desc_name_var_map, f"Handle name {handle_name} not found in desc_name_var_map" desc_var = desc_name_var_map[handle_name] - assert desc_var in tma_descriptor_args, \ - f"TMA descriptor {desc_var} not found in {tma_descriptor_args}" + assert desc_var in tma_descriptor_args, f"TMA descriptor {desc_var} not found in {tma_descriptor_args}" args = tma_descriptor_args[desc_var] # Skip __tvm_tensormap_create_tiled and second element (like CUDA version) if len(args) < 3: - raise ValueError( - f"TMA descriptor args too short: {len(args)} elements, expected at least 3") + raise ValueError(f"TMA descriptor args too short: {len(args)} elements, expected at least 3") tma_create_str, _, dtype, tensor_rank, global_address, *remaining_args = args - is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col") + is_img2col = tma_create_str.value == "__tvm_tensormap_create_im2col" # Convert basic fields dtype = pythonic_expr_func(dtype) @@ -386,60 +381,45 @@ def parse_tma_descriptor_args( # Tiled mode expected_args_len = 4 * tensor_rank + 4 if len(remaining_args) < expected_args_len: - raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, " - f"expected 
{expected_args_len} for tensor_rank {tensor_rank}") + raise ValueError( + f"Insufficient remaining args: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}" + ) # Extract dimensions and strides params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]] - params.global_stride = [ - pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank] - ] - params.box_dim = [ - pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank] - ] - params.element_strides = [ - pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank:4 * tensor_rank] - ] + params.global_stride = [pythonic_expr_func(i) for i in remaining_args[tensor_rank : 2 * tensor_rank]] + params.box_dim = [pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank : 3 * tensor_rank]] + params.element_strides = [pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank : 4 * tensor_rank]] # Extract remaining parameters try: - interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * tensor_rank:4 * - tensor_rank + 4] + interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * tensor_rank : 4 * tensor_rank + 4] params.interleave = pythonic_expr_func(interleave) params.swizzle = pythonic_expr_func(swizzle) params.l2_promotion = pythonic_expr_func(l2_promotion) params.oob_fill = pythonic_expr_func(oob_fill) except ValueError as e: - raise ValueError( - "Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)" - ) from e + raise ValueError("Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)") from e else: # Im2col mode expected_args_len = 5 * tensor_rank + 2 if len(remaining_args) < expected_args_len: - raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, " - f"expected {expected_args_len} for tensor_rank {tensor_rank}") + raise ValueError( + f"Insufficient remaining args: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}" + ) # Extract dimensions and strides params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]] - params.global_stride = [ - pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank] - ] - params.element_strides = [ - pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank] - ] - params.lower_corner = [ - pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank:4 * tensor_rank - 2] - ] - params.upper_corner = [ - pythonic_expr_func(i) - for i in remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4] - ] + params.global_stride = [pythonic_expr_func(i) for i in remaining_args[tensor_rank : 2 * tensor_rank]] + params.element_strides = [pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank : 3 * tensor_rank]] + params.lower_corner = [pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank : 4 * tensor_rank - 2]] + params.upper_corner = [pythonic_expr_func(i) for i in remaining_args[4 * tensor_rank - 2 : 5 * tensor_rank - 4]] # Extract remaining parameters try: - smem_box_pixel, smem_box_channel, interleave, swizzle, l2_promotion, oob_fill = \ - remaining_args[5 * tensor_rank - 4:5 * tensor_rank + 2] + smem_box_pixel, smem_box_channel, interleave, swizzle, l2_promotion, oob_fill = remaining_args[ + 5 * tensor_rank - 4 : 5 * tensor_rank + 2 + ] params.smem_box_pixel = pythonic_expr_func(smem_box_pixel) params.smem_box_channel = pythonic_expr_func(smem_box_channel) params.interleave = 
pythonic_expr_func(interleave) diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py index 756079763ff2278c34dbcc96c49f028dade6383e..c028a58efcd34b26954760f64015d6032b542d38 100644 --- a/tilelang/jit/adapter/wrapper.py +++ b/tilelang/jit/adapter/wrapper.py @@ -4,9 +4,18 @@ from tilelang import tvm as tvm from typing import Any from tvm import IRModule from tvm.target import Target -from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, - is_hip_target, is_cpu_target, get_annotated_mod, pythonic_expr, - parse_function_call_args, parse_tma_descriptor_args) +from .utils import ( + is_metal_target, + match_declare_kernel, + match_declare_kernel_cpu, + is_cuda_target, + is_hip_target, + is_cpu_target, + get_annotated_mod, + pythonic_expr, + parse_function_call_args, + parse_tma_descriptor_args, +) import re import logging import textwrap @@ -129,7 +138,6 @@ TMA_IM2COL_DESC_INIT_FUNC = """ class BaseWrapper(ABC): - @abstractmethod def wrap(self, *args, **kwargs): raise NotImplementedError @@ -163,13 +171,15 @@ class TLCUDASourceWrapper: host_mod: IRModule | None = None pass_configs: dict[str, Any] | None = None - def __init__(self, - scheduled_ir_module: IRModule, - source: str, - target: Target, - device_mod: IRModule | None = None, - host_mod: IRModule | None = None, - pass_configs: dict[str, Any] | None = None): + def __init__( + self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None, + ): self.mod = scheduled_ir_module self.target = target self.source = source @@ -211,15 +221,16 @@ class TLCUDASourceWrapper: for param in self.prim_func.params: if param in self.prim_func.buffer_map: buffer = self.prim_func.buffer_map[param] - function_args.append({ - "name": buffer.data.name, - "type": self._lookup_type(buffer.dtype) + "* __restrict__", - }) + function_args.append( + { + "name": buffer.data.name, + "type": self._lookup_type(buffer.dtype) + "* __restrict__", + } + ) elif isinstance(param, tvm.tir.Var): function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)}) else: - raise ValueError( - f"Parameter {param} is not in the buffer map of the primary function.") + raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.") # Add dynamic symbols as integer arguments for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set: if dyn_sym not in [arg["name"] for arg in function_args]: @@ -256,38 +267,40 @@ class TLCUDASourceWrapper: # Identify the start of the function body to insert arguments index = code.index("{", index) - block_str = f"dim3({self._pythonic_expr(block_info[0])}, {self._pythonic_expr(block_info[1])}, {self._pythonic_expr(block_info[2])})" - grid_str = f"dim3({self._pythonic_expr(grid_info[0])}, {self._pythonic_expr(grid_info[1])}, {self._pythonic_expr(grid_info[2])})" + block_str = ( + f"dim3({self._pythonic_expr(block_info[0])}, {self._pythonic_expr(block_info[1])}, {self._pythonic_expr(block_info[2])})" + ) + grid_str = ( + f"dim3({self._pythonic_expr(grid_info[0])}, {self._pythonic_expr(grid_info[1])}, {self._pythonic_expr(grid_info[2])})" + ) smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf init_l2_persistent_map = self.generate_l2_persistent_map(function_name) kernel_launch_code += init_l2_persistent_map if self.use_cooperative_groups[function_name]: - args_list = parse_function_call_args(declaration, function_args, 
function_params, - desc_name_map, desc_name_var_map) - assert len(function_params) == len( - args_list - ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments" + args_list = parse_function_call_args(declaration, function_args, function_params, desc_name_map, desc_name_var_map) + assert len(function_params) == len(args_list), ( + f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments" + ) args_array = [f"(void*)&{arg}" for arg in args_list] call_args = f"\tvoid* {function_name}_args[] = {{{', '.join(args_array)}}};\n" kernel_launch_code += call_args # Using cudaLaunchCooperativeKernel to launch the kernel kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format( - function_name, grid_str, block_str, function_name + "_args", smem_str) + function_name, grid_str, block_str, function_name + "_args", smem_str + ) else: - args_list = parse_function_call_args(declaration, function_args, function_params, - desc_name_map, desc_name_var_map) - assert len(function_params) == len( - args_list - ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments" + args_list = parse_function_call_args(declaration, function_args, function_params, desc_name_map, desc_name_var_map) + assert len(function_params) == len(args_list), ( + f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments" + ) call_args = ", ".join(args_list) kernel_launch_code += f"\t{function_name}<<<{grid_str}, {block_str}, {smem_str}, stream>>>({call_args});\n" - kernel_launch_code += f"\tTILELANG_CHECK_LAST_ERROR(\"{function_name}\");\n" + kernel_launch_code += f'\tTILELANG_CHECK_LAST_ERROR("{function_name}");\n' if has_l2_persistent_map: kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE - init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map, - desc_name_var_map) + init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map) kernel_launch_code = init_tma_descriptor_args + kernel_launch_code # Wrap the kernel dispatch logic in an external C function @@ -298,46 +311,63 @@ class TLCUDASourceWrapper: if function_name not in self.l2_persistent_map: return "" init_l2_persistent_map = "" - for buffer_name, (hit_ratio, - size_in_bytes) in self.l2_persistent_map[function_name].items(): + for buffer_name, (hit_ratio, size_in_bytes) in self.l2_persistent_map[function_name].items(): # get persisting_l2_cache_max_size from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size + persisting_l2_cache_max_size = get_persisting_l2_cache_max_size() try: num_bytes = min(size_in_bytes, persisting_l2_cache_max_size) except Exception: # as size_in_bytes maybe a symbolic expression num_bytes = persisting_l2_cache_max_size - init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format( - buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes)) + init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format(buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes)) return init_l2_persistent_map - def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], - desc_name_var_map: dict[str, tvm.tir.Var]) -> str: + def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], desc_name_var_map: dict[str, tvm.tir.Var]) -> str: tma_descripter_init = "" if self.tma_descriptor_args is None: return tma_descripter_init # Parse TMA descriptor arguments using 
the common utility - parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, - desc_name_var_map, self._pythonic_expr) + parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, desc_name_var_map, self._pythonic_expr) # Generate C++ code from parsed parameters for params in parsed_params: if not params.is_img2col: tma_descripter_init += TMA_DESC_INIT_FUNC.format( - params.handle_name, params.dtype, params.tensor_rank, params.global_address, - ",".join(params.global_dim), ",".join(params.global_stride), - ",".join(params.box_dim), ",".join(params.element_strides), params.interleave, - params.swizzle, params.l2_promotion, params.oob_fill) + params.handle_name, + params.dtype, + params.tensor_rank, + params.global_address, + ",".join(params.global_dim), + ",".join(params.global_stride), + ",".join(params.box_dim), + ",".join(params.element_strides), + params.interleave, + params.swizzle, + params.l2_promotion, + params.oob_fill, + ) else: tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC.format( - params.handle_name, params.dtype, params.tensor_rank, params.global_address, - ",".join(params.global_dim), ",".join(params.global_stride), - ",".join(params.element_strides), ",".join(params.lower_corner), - ",".join(params.upper_corner), params.smem_box_channel, params.smem_box_pixel, - params.interleave, params.swizzle, params.l2_promotion, params.oob_fill) + params.handle_name, + params.dtype, + params.tensor_rank, + params.global_address, + ",".join(params.global_dim), + ",".join(params.global_stride), + ",".join(params.element_strides), + ",".join(params.lower_corner), + ",".join(params.upper_corner), + params.smem_box_channel, + params.smem_box_pixel, + params.interleave, + params.swizzle, + params.l2_promotion, + params.oob_fill, + ) return tma_descripter_init @@ -347,9 +377,8 @@ class TLCUDASourceWrapper: device_mod, host_mod = get_annotated_mod(self.mod, self.target) self.device_mod = device_mod self.host_mod = host_mod - assert (len(self.device_mod.functions) - >= 1), "Device module should have at least one function." - assert (len(self.host_mod.functions) == 1), "Only support one function in host module." + assert len(self.device_mod.functions) >= 1, "Device module should have at least one function." + assert len(self.host_mod.functions) == 1, "Only support one function in host module." 
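# --- Editorial illustration (not part of the patch) ---------------------------
# A minimal sketch, assuming a standard TVM build, of how `pythonic_expr`
# (imported from .utils at the top of this file) renders TIR expressions into
# the dim3(...) and TMA-descriptor strings assembled above. The variable name
# `n` and the exact rendered text are illustrative, not taken from the patch.
def _example_pythonic_expr():
    from tilelang import tvm
    from tilelang.jit.adapter.utils import pythonic_expr

    n = tvm.tir.Var("n", "int32")
    # FloorDiv is printed with "/" per the operator map in utils.pythonic_expr.
    return pythonic_expr(tvm.tir.FloorDiv(n + 255, 256))  # e.g. "(n + 255) / 256"
# -------------------------------------------------------------------------------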
block_info_map = {} grid_info_map = {} @@ -438,8 +467,7 @@ class TLCUDASourceWrapper: for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items(): if dynamic_smem_buf is not None: # Format the cudaFuncSetAttribute call for dynamic shared memory - call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format( - function_name, dynamic_smem_buf) + call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format(function_name, dynamic_smem_buf) # Format the initialization function using the call_str init_funcs = PREDEF_INIT_FUNC.format(call_str) return init_funcs @@ -466,17 +494,14 @@ class TLCUDASourceWrapper: def visitor(node, fn=function_name, param_cnt=kernel_params_cnt): nonlocal function_params if isinstance(node, tvm.tir.Call): - if not (hasattr(node, "op") and - node.op == tvm.ir.Op.get("tir.tvm_call_packed")): + if not (hasattr(node, "op") and node.op == tvm.ir.Op.get("tir.tvm_call_packed")): return args = node.args if not args or args[0] != fn: return if len(args) < 1 + param_cnt: - raise AssertionError( - "tvm_call_packed should have at least 1 argument and match device function parameters" - ) - function_params = args[1:1 + param_cnt] + raise AssertionError("tvm_call_packed should have at least 1 argument and match device function parameters") + function_params = args[1 : 1 + param_cnt] post_order_visit(self.host_func.body, visitor) assert function_params is not None, "function_params should not be None" @@ -564,13 +589,15 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper): "uchar": "uint8_t", } - def __init__(self, - scheduled_ir_module: IRModule, - source: str, - target: Target, - device_mod: IRModule | None = None, - host_mod: IRModule | None = None, - pass_configs: dict[str, Any] | None = None): + def __init__( + self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None, + ): super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs) def get_init_func(self): @@ -580,8 +607,7 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper): for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items(): if dynamic_smem_buf is not None: # Format the cudaFuncSetAttribute call for dynamic shared memory - call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP.format( - function_name, dynamic_smem_buf) + call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP.format(function_name, dynamic_smem_buf) # Format the initialization function using the call_str init_funcs = PREDEF_INIT_FUNC.format(call_str) return init_funcs @@ -623,13 +649,15 @@ class TLCPUSourceWrapper: host_mod: IRModule | None = None pass_configs: dict[str, Any] | None = None - def __init__(self, - scheduled_ir_module: IRModule, - source: str, - target: Target, - device_mod: IRModule | None = None, - host_mod: IRModule | None = None, - pass_configs: dict[str, Any] | None = None): + def __init__( + self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None, + ): self.mod = scheduled_ir_module self.target = target self.source = source @@ -658,15 +686,16 @@ class TLCPUSourceWrapper: for param in self.prim_func.params: if param in self.prim_func.buffer_map: buffer = self.prim_func.buffer_map[param] - function_args.append({ - "name": buffer.name, - "type": self._lookup_type(buffer.dtype) + "*", - }) + function_args.append( + { + "name": buffer.name, + "type": 
self._lookup_type(buffer.dtype) + "*", + } + ) elif isinstance(param, tvm.tir.Var): function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)}) else: - raise ValueError( - f"Parameter {param} is not in the buffer map of the primary function.") + raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.") # Add dynamic symbols as integer arguments for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set: function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)}) @@ -686,7 +715,6 @@ class TLCPUSourceWrapper: _call_str = """""" for function_name, _ in function_informations.items(): - # Find the location of the global kernel function in the code index = match_declare_kernel_cpu(code, function_name + "(") @@ -706,8 +734,8 @@ class TLCPUSourceWrapper: def parse_source_information(self): with tvm.transform.PassContext(opt_level=3, config=self.pass_configs): device_mod, host_mod = get_annotated_mod(self.mod, self.target) - assert (len(device_mod.functions) >= 1), "Device module should have at least one function." - assert (len(host_mod.functions) == 1), "Only support one function in host module." + assert len(device_mod.functions) >= 1, "Device module should have at least one function." + assert len(host_mod.functions) == 1, "Only support one function in host module." function_names = [] for g_var, _ in device_mod.functions.items(): @@ -767,14 +795,15 @@ class TLCPUSourceWrapper: class TLMetalSourceWrapper: - - def __init__(self, - scheduled_ir_module: IRModule, - source: str, - target: Target, - device_mod: IRModule | None = None, - host_mod: IRModule | None = None, - pass_configs: dict[str, Any] | None = None): + def __init__( + self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None, + ): self.mod = scheduled_ir_module self.target = target self.source = source @@ -792,6 +821,7 @@ class TLWrapper(BaseWrapper): """ A wrapper class for the TileLang backend. """ + device_mod: IRModule | None = None host_mod: IRModule | None = None pass_configs: dict[str, Any] | None = None @@ -836,12 +866,12 @@ class TLWrapper(BaseWrapper): target=self.target, device_mod=self.device_mod, host_mod=self.host_mod, - pass_configs=self.pass_configs) + pass_configs=self.pass_configs, + ) return wrapper.lib_code class TLPyWrapper(TLWrapper): - def __init__(self, target: Target): super().__init__(target) @@ -849,6 +879,7 @@ class TLPyWrapper(TLWrapper): # assert self.scheduled_ir_module is not None, "Please assign optimized module first." 
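# --- Editorial illustration (not part of the patch) ---------------------------
# A small sketch of the target predicates from .utils that gate the wrapper
# selection below; Target("cuda") is just an example target, not from the patch.
def _example_target_predicates():
    from tvm.target import Target
    from tilelang.jit.adapter.utils import is_cuda_target, is_hip_target

    cuda = Target("cuda")
    return is_cuda_target(cuda), is_hip_target(cuda)  # expected: (True, False)
# -------------------------------------------------------------------------------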
if is_cuda_target(self.target): from tilelang.jit.adapter.nvrtc import TLNVRTCSourceWrapper + wrapper_class = TLNVRTCSourceWrapper else: raise ValueError(f"Unsupported target for NVRTC backend: {self.target}") @@ -858,5 +889,6 @@ class TLPyWrapper(TLWrapper): target=self.target, device_mod=self.device_mod, host_mod=self.host_mod, - pass_configs=self.pass_configs) + pass_configs=self.pass_configs, + ) return wrapper.host_func, wrapper.function_names diff --git a/tilelang/jit/execution_backend.py b/tilelang/jit/execution_backend.py index fe6000028357463027966664b044d1f4b59c053d..492e8cb0f645f3411b958345d6c2f75c8c3fdfc4 100644 --- a/tilelang/jit/execution_backend.py +++ b/tilelang/jit/execution_backend.py @@ -46,6 +46,7 @@ def allowed_backends_for_target(target: Target, *, include_unavailable: bool = T # Drop NVRTC if not importable try: from tilelang.jit.adapter.nvrtc import is_nvrtc_available # lazy + if not is_nvrtc_available and "nvrtc" in allowed: allowed = [b for b in allowed if b != "nvrtc"] except Exception: @@ -89,12 +90,14 @@ def resolve_execution_backend(requested: str | None, target: Target) -> str: if req not in allowed_all: raise ValueError( f"Invalid execution backend '{requested}' for target '{_target_kind(target)}'. " - f"Allowed: {_format_options(allowed_all)}. Tip: use execution_backend='auto'.") + f"Allowed: {_format_options(allowed_all)}. Tip: use execution_backend='auto'." + ) # Promote to availability-aware set for nicer errors (e.g., nvrtc not installed) if req not in allowed_avail: raise ValueError( f"Execution backend '{requested}' requires extra dependencies and is not available now. " - f"Try one of: {_format_options(allowed_avail)}.") + f"Try one of: {_format_options(allowed_avail)}." + ) return req diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index 22cecf990337500b26a3388ed6ae95f8eb3ce447..c05ef9e5ae9131b20097c25b216e1dab71a41777 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -1,5 +1,6 @@ from __future__ import annotations from typing import Any, Callable, Generic, Literal, TypeVar + # Python 3.9 compatibility for ParamSpec try: from typing import ParamSpec @@ -14,8 +15,7 @@ import tilelang from tilelang import tvm from tilelang import env from tilelang.engine.param import CompiledArtifact, KernelParam -from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, - TVMFFIKernelAdapter, MetalKernelAdapter) +from tilelang.jit.adapter import BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, TVMFFIKernelAdapter, MetalKernelAdapter from tilelang.profiler import Profiler, TensorSupplyType from tilelang.utils.target import determine_target from tilelang.contrib import nvcc as tl_nvcc @@ -24,8 +24,8 @@ import os logger = logging.getLogger(__name__) -_P = ParamSpec('_P') -_T = TypeVar('_T') +_P = ParamSpec("_P") +_T = TypeVar("_T") class JITKernel(Generic[_P, _T]): @@ -41,6 +41,7 @@ class JITKernel(Generic[_P, _T]): torch_function : Callable The compiled function that can be invoked as a PyTorch-compatible function. """ + prim_func: PrimFunc = None artifact: CompiledArtifact = None adapter: BaseKernelAdapter = None @@ -111,9 +112,7 @@ class JITKernel(Generic[_P, _T]): if execution_backend == "cython": from tilelang.contrib.cc import get_cplus_compiler - assert ( - get_cplus_compiler() is not None - ), "Cython backend requires a C++ compiler, please install or use other backends." 
+ assert get_cplus_compiler() is not None, "Cython backend requires a C++ compiler, please install or use other backends." if from_database: return @@ -200,8 +199,7 @@ class JITKernel(Generic[_P, _T]): """ return self.torch_function(*args, **kwds) - def _compile_and_create_adapter(self, tilelang_func: PrimFunc, - out_idx: list[int]) -> BaseKernelAdapter: + def _compile_and_create_adapter(self, tilelang_func: PrimFunc, out_idx: list[int]) -> BaseKernelAdapter: """ Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter. @@ -233,7 +231,8 @@ class JITKernel(Generic[_P, _T]): target=target, target_host=target_host, enable_host_codegen=enable_host_codegen, - enable_device_compile=enable_device_compile) + enable_device_compile=enable_device_compile, + ) self.artifact = artifact @@ -241,7 +240,7 @@ class JITKernel(Generic[_P, _T]): if execution_backend == "tvm_ffi": # Use TVMFFIKernelAdapter for interoperability with PyTorch via DLPack. # But we need to ensure that the runtime is enabled and the runtime module is not None. - assert (artifact.rt_mod is not None), "tvm_ffi backend requires a runtime module." + assert artifact.rt_mod is not None, "tvm_ffi backend requires a runtime module." adapter = TVMFFIKernelAdapter( params=artifact.params, result_idx=out_idx, @@ -283,6 +282,7 @@ class JITKernel(Generic[_P, _T]): ) elif execution_backend == "nvrtc": from tilelang.jit.adapter import NVRTCKernelAdapter + adapter = NVRTCKernelAdapter( params=artifact.params, result_idx=out_idx, @@ -315,16 +315,18 @@ class JITKernel(Generic[_P, _T]): return adapter - def _create_adapter_from_database(self, - params: list[KernelParam], - result_idx: list[int] | int, - target: str | Target, - func_or_mod: PrimFunc | tvm.runtime.Module, - host_kernel_source: str, - device_kernel_source: str, - kernel_lib_path: str, - pass_configs: dict[str, Any] | None = None, - compile_flags: list[str] | None = None) -> BaseKernelAdapter: + def _create_adapter_from_database( + self, + params: list[KernelParam], + result_idx: list[int] | int, + target: str | Target, + func_or_mod: PrimFunc | tvm.runtime.Module, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ) -> BaseKernelAdapter: target = self.target execution_backend = self.execution_backend @@ -366,6 +368,7 @@ class JITKernel(Generic[_P, _T]): ) elif execution_backend == "nvrtc": from tilelang.jit.adapter import NVRTCKernelAdapter + adapter = NVRTCKernelAdapter.from_database( params=params, result_idx=result_idx, @@ -402,8 +405,7 @@ class JITKernel(Generic[_P, _T]): """ return cls(func=tilelang_func, **kwargs) - def get_profiler(self, - tensor_supply_type: TensorSupplyType = TensorSupplyType.Auto) -> Profiler: + def get_profiler(self, tensor_supply_type: TensorSupplyType = TensorSupplyType.Auto) -> Profiler: """ Creates a profiler to benchmark the compiled runtime module. @@ -417,8 +419,7 @@ class JITKernel(Generic[_P, _T]): Profiler A Profiler instance for benchmarking the runtime module. 
""" - return Profiler(self.params, self.out_idx, - tensor_supply_type).with_default_adapter(self.adapter) + return Profiler(self.params, self.out_idx, tensor_supply_type).with_default_adapter(self.adapter) def get_kernel_source(self, kernel_only: bool = True) -> str: """ @@ -507,21 +508,19 @@ class JITKernel(Generic[_P, _T]): dir_path = os.path.dirname(kernel_path) if dir_path: os.makedirs(dir_path, exist_ok=True) - with open(kernel_path, 'w') as f: + with open(kernel_path, "w") as f: f.write(self.get_kernel_source()) if host_path is not None: dir_path = os.path.dirname(host_path) if dir_path: os.makedirs(dir_path, exist_ok=True) - with open(host_path, 'w') as f: + with open(host_path, "w") as f: f.write(self.get_host_source()) except Exception as e: logger.error(f"Failed to export sources: {e}") # Backward compatibility alias (deprecated) - def print_source_code(self, - which: Literal["kernel", "host", "both"] = "kernel", - file: str | None = None) -> None: + def print_source_code(self, which: Literal["kernel", "host", "both"] = "kernel", file: str | None = None) -> None: """ Deprecated: use show_source() or export_sources() instead. @@ -541,16 +540,14 @@ class JITKernel(Generic[_P, _T]): >>> # Old API (still works but deprecated) >>> jit_kernel.print_source_code(file="/tmp/kernel.cu") """ - logger.warning( - "print_source_code is deprecated; use show_source() or export_sources() instead.") + logger.warning("print_source_code is deprecated; use show_source() or export_sources() instead.") if file is not None: # Historical behavior wrote only kernel source when file provided self.export_sources(kernel_path=file) else: self.show_source(which=which) - def update_tuner_result(self, latency: float, config: dict[str, Any], - ref_latency: float) -> JITKernel: + def update_tuner_result(self, latency: float, config: dict[str, Any], ref_latency: float) -> JITKernel: """ Updates the tuning results for this kernel. 
@@ -651,8 +648,7 @@ class JITKernel(Generic[_P, _T]): verbose = self.verbose # Ensure target is set so nvcc picks correct arch via Target.current() with self.target: - return tl_nvcc.get_ptx_from_source( - code, compile_flags=self.compile_flags, verbose=verbose) + return tl_nvcc.get_ptx_from_source(code, compile_flags=self.compile_flags, verbose=verbose) def show_ptx(self) -> None: """ @@ -714,8 +710,7 @@ class JITKernel(Generic[_P, _T]): if verbose is None: verbose = self.verbose with self.target: - return tl_nvcc.get_sass_from_source( - code, compile_flags=self.compile_flags, verbose=verbose) + return tl_nvcc.get_sass_from_source(code, compile_flags=self.compile_flags, verbose=verbose) def show_sass(self) -> None: """ diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index c91ac3cb2fa169e537c52f97419bcf1d88d8ce2c..0f3d5fb13868d1b18997d2fe4c1a97b9a1f1de83 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations # from .parser import * @@ -102,7 +103,10 @@ from .utils import index_to_coordinates # noqa: F401 from .symbolics import dynamic, symbolic # noqa: F401 from .annotations import ( # noqa: F401 - use_swizzle, annotate_layout, annotate_safe_value, annotate_l2_hit_ratio, + use_swizzle, + annotate_layout, + annotate_safe_value, + annotate_l2_hit_ratio, ) diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index 73377822bfa4f098b1a789650dcd4300c103ab5a..b26f0b8fef55ef5e875068547c099af45356518e 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -13,8 +13,10 @@ Available allocation functions: Each function takes shape and dtype parameters and returns a TVM buffer object with the appropriate memory scope. """ + from __future__ import annotations from typing import TypeVar, overload, Literal, Callable + # Python 3.9 compatibility for advanced typing features (PEP 646) try: from typing import TypeVarTuple, Unpack # type: ignore[attr-defined] @@ -30,13 +32,11 @@ from .v2.dtypes import dtype as tl_dtype from .v2.builder import OutTensor from .v2.annot import Tensor, SharedBuffer, LocalBuffer, FragmentBuffer -_Shapes = TypeVarTuple('_Shapes') -_DType = TypeVar('_DType') +_Shapes = TypeVarTuple("_Shapes") +_DType = TypeVar("_DType") -def alloc_shared(shape: tuple[Unpack[_Shapes]], - dtype: _DType, - scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]: +def alloc_shared(shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]: """Allocate a shared memory buffer for inter-thread communication. Args: @@ -54,9 +54,7 @@ def alloc_shared(shape: tuple[Unpack[_Shapes]], return T.alloc_buffer(shape, dtype, scope=scope) -def alloc_local(shape: tuple[Unpack[_Shapes]], - dtype: _DType, - scope="local") -> LocalBuffer[Callable[[Unpack[_Shapes]]], _DType]: +def alloc_local(shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="local") -> LocalBuffer[Callable[[Unpack[_Shapes]]], _DType]: """Allocate a local memory buffer for thread-private storage. 
Args: @@ -70,9 +68,9 @@ def alloc_local(shape: tuple[Unpack[_Shapes]], return T.alloc_buffer(shape, dtype, scope=scope) -def alloc_fragment(shape: tuple[Unpack[_Shapes]], - dtype: _DType, - scope="local.fragment") -> FragmentBuffer[Callable[[Unpack[_Shapes]]], _DType]: +def alloc_fragment( + shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="local.fragment" +) -> FragmentBuffer[Callable[[Unpack[_Shapes]]], _DType]: """Allocate a fragment memory buffer for specialized operations. Args: @@ -87,16 +85,11 @@ def alloc_fragment(shape: tuple[Unpack[_Shapes]], @overload -def alloc_var(dtype: str, init: PrimExpr | int | float, scope: str = 'local.var') -> Buffer: - ... +def alloc_var(dtype: str, init: PrimExpr | int | float, scope: str = "local.var") -> Buffer: ... @overload -def alloc_var(dtype: str, - scope: str = 'local.var', - *, - init: PrimExpr | int | float | None = None) -> Buffer: - ... +def alloc_var(dtype: str, scope: str = "local.var", *, init: PrimExpr | int | float | None = None) -> Buffer: ... def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None): @@ -142,8 +135,7 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None): raise TypeError("Scope must be provided as a string in alloc_var.") parsed_scope = parsed_scope_arg elif len(args) > 2: - raise TypeError( - f"alloc_var expected at most 3 positional arguments but got {len(args) + 1}.") + raise TypeError(f"alloc_var expected at most 3 positional arguments but got {len(args) + 1}.") if not isinstance(parsed_scope, str): raise TypeError("Scope must be a string in alloc_var.") @@ -274,13 +266,10 @@ def alloc_tcgen05_instr_desc(dtype: str = "uint32"): @overload -def empty(shape: tuple[Unpack[_Shapes]], - dtype: str = 'float32') -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: - ... +def empty(shape: tuple[Unpack[_Shapes]], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ... -def empty(*shape: Unpack[_Shapes], - dtype: str = 'float32') -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: +def empty(*shape: Unpack[_Shapes], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: if len(shape) == 1 and isinstance(shape[0], (tuple, list)): return OutTensor(shape[0], dtype) elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str): @@ -288,4 +277,4 @@ def empty(*shape: Unpack[_Shapes], elif all([isinstance(x, (int, PrimExpr)) for x in shape]): return OutTensor(shape, dtype) else: - raise RuntimeError(f'Invalid shape {shape}') + raise RuntimeError(f"Invalid shape {shape}") diff --git a/tilelang/language/annotations.py b/tilelang/language/annotations.py index 2ce71cb96d7faff34e2de808a59c61ce89a0583f..09cfa58b0e5f42be5a57481d7405b49c9a6b9b6d 100644 --- a/tilelang/language/annotations.py +++ b/tilelang/language/annotations.py @@ -1,4 +1,5 @@ """Annotation helpers exposed on the TileLang language surface.""" + from typing import Callable from tilelang.layout import Layout diff --git a/tilelang/language/ast/__init__.py b/tilelang/language/ast/__init__.py index 9d77454429dc76ba6cc843864c20e2a11b2f9132..6ab6249b14484ca534cdc3ee3cc5bdb6ce6c2767 100644 --- a/tilelang/language/ast/__init__.py +++ b/tilelang/language/ast/__init__.py @@ -17,6 +17,7 @@ # This file is modified from the original version, # which is part of the TVM project (https://tvm.apache.org/). 
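# --- Editorial illustration (not part of the patch) ---------------------------
# Hedged sketch of the allocation helpers from tilelang/language/allocate.py
# (diffed above), as they are typically called inside a T.prim_func. The shapes,
# dtypes, and variable names are illustrative assumptions.
#
#     A_shared = T.alloc_shared((128, 64), "float16")    # scope "shared.dyn"
#     acc      = T.alloc_fragment((128, 64), "float32")  # scope "local.fragment"
#     counter  = T.alloc_var("int32", init=0)            # scope "local.var"
# -------------------------------------------------------------------------------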
"""Package tvm.script.ir_builder.tir""" + from .ir import * # noqa: F401 from .ir import boolean as bool # noqa: F401 from .ir import buffer as Buffer # noqa: F401 diff --git a/tilelang/language/ast/_ffi_api.py b/tilelang/language/ast/_ffi_api.py index 518d57ea8b535123db78e14b61d280d83bf4368d..5cc74762a7089946a4c055d4623e506dce7985e7 100644 --- a/tilelang/language/ast/_ffi_api.py +++ b/tilelang/language/ast/_ffi_api.py @@ -17,6 +17,7 @@ # This file is modified from the original version, # which is part of the TVM project (https://tvm.apache.org/). """FFI APIs""" + import tvm.ffi tvm.ffi._init_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access diff --git a/tilelang/language/ast/ir.py b/tilelang/language/ast/ir.py index 41b658d7cdc967e5c8554cac90f4fb41fc9b2d22..0352514341843d3473eea597d4af00200fc3e3af 100644 --- a/tilelang/language/ast/ir.py +++ b/tilelang/language/ast/ir.py @@ -558,7 +558,8 @@ class axis: # pylint: disable=invalid-name The iteration variable. """ return _ffi_api.AxisSpatial( # type: ignore[attr-defined] # pylint: disable=no-member - _as_range(dom), binding, dtype) + _as_range(dom), binding, dtype + ) @staticmethod def reduce( @@ -585,7 +586,8 @@ class axis: # pylint: disable=invalid-name The iteration variable. """ return _ffi_api.AxisReduce( # type: ignore[attr-defined] # pylint: disable=no-member - _as_range(dom), binding, dtype) + _as_range(dom), binding, dtype + ) @staticmethod def scan( @@ -612,7 +614,8 @@ class axis: # pylint: disable=invalid-name The iteration variable. """ return _ffi_api.AxisScan( # type: ignore[attr-defined] # pylint: disable=no-member - _as_range(dom), binding, dtype) + _as_range(dom), binding, dtype + ) @staticmethod def opaque( @@ -639,7 +642,8 @@ class axis: # pylint: disable=invalid-name The iteration variable. """ return _ffi_api.AxisOpaque( # type: ignore[attr-defined] # pylint: disable=no-member - _as_range(dom), binding, dtype) + _as_range(dom), binding, dtype + ) @staticmethod def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[List[Var], Var]: @@ -662,17 +666,15 @@ class axis: # pylint: disable=invalid-name The iteration variables. """ iter_vars = _ffi_api.AxisRemap( # type: ignore[attr-defined] # pylint: disable=no-member - kinds, bindings, dtype) + kinds, bindings, dtype + ) return iter_vars[0] if len(iter_vars) == 1 else iter_vars S = spatial # pylint: disable=invalid-name R = reduce # pylint: disable=invalid-name -def serial(start: PrimExpr, - stop: PrimExpr = None, - *, - annotations: Dict[str, Any] = None) -> frame.ForFrame: +def serial(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame: """The serial For statement. Parameters @@ -700,10 +702,7 @@ def serial(start: PrimExpr, return _ffi_api.Serial(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member -def parallel(start: PrimExpr, - stop: PrimExpr = None, - *, - annotations: Dict[str, Any] = None) -> frame.ForFrame: +def parallel(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame: """The parallel For statement. 
Parameters @@ -731,10 +730,7 @@ def parallel(start: PrimExpr, return _ffi_api.Parallel(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member -def vectorized(start: PrimExpr, - stop: PrimExpr = None, - *, - annotations: Dict[str, Any] = None) -> frame.ForFrame: +def vectorized(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame: """The vectorized For statement. Parameters @@ -762,10 +758,7 @@ def vectorized(start: PrimExpr, return _ffi_api.Vectorized(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member -def unroll(start: PrimExpr, - stop: PrimExpr = None, - *, - annotations: Dict[str, Any] = None) -> frame.ForFrame: +def unroll(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame: """The unrolled For statement. Parameters @@ -837,7 +830,8 @@ def thread_binding( else: start = 0 return _ffi_api.ThreadBinding( # type: ignore[attr-defined] # pylint: disable=no-member - start, stop, thread, annotations) + start, stop, thread, annotations + ) def grid(*extents: PrimExpr) -> frame.ForFrame: @@ -878,10 +872,10 @@ def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: d def LetStmt( # pylint: disable=invalid-name - value: PrimExpr, - type_annotation: Optional[Type] = None, # pylint: disable=redefined-outer-name - *, - var: Optional[Var] = None, # pylint: disable=redefined-outer-name + value: PrimExpr, + type_annotation: Optional[Type] = None, # pylint: disable=redefined-outer-name + *, + var: Optional[Var] = None, # pylint: disable=redefined-outer-name ) -> frame.LetFrame: """Create a LetStmt binding @@ -909,8 +903,8 @@ def LetStmt( # pylint: disable=invalid-name def Let( # pylint: disable=invalid-name - expr: PrimExpr, - where: Dict[Var, PrimExpr], # pylint: disable=redefined-outer-name + expr: PrimExpr, + where: Dict[Var, PrimExpr], # pylint: disable=redefined-outer-name ) -> PrimExpr: """Create a Let expression binding""" assert len(where) == 1, "T.Let only allows `where` to have exactly one element" @@ -980,7 +974,8 @@ def realize( The result RealizeFrame. 
""" return _ffi_api.Realize( # type: ignore[attr-defined] # pylint: disable=no-member - buffer_slice, storage_scope, condition) + buffer_slice, storage_scope, condition + ) def allocate( @@ -1012,7 +1007,8 @@ def allocate( if isinstance(condition, bool): condition = IntImm("bool", condition) return _ffi_api.Allocate( # type: ignore[attr-defined] # pylint: disable=no-member - extents, dtype, scope, condition, annotations) + extents, dtype, scope, condition, annotations + ) def allocate_const( @@ -1048,7 +1044,8 @@ def allocate_const( np_data = np_data.reshape(extents) return _ffi_api.AllocateConst( # type: ignore[attr-defined] # pylint: disable=no-member - ndarray.array(np_data), dtype, extents, annotations) + ndarray.array(np_data), dtype, extents, annotations + ) def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame: @@ -1297,7 +1294,8 @@ def buffer_store( if isinstance(value, bool) and buffer.dtype == "bool": value = IntImm("bool", value) return _ffi_api.BufferStore( # type: ignore[attr-defined] # pylint: disable=no-member - buffer, value, expr_indices) + buffer, value, expr_indices + ) def prefetch( @@ -1464,10 +1462,7 @@ def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimE return _ffi_api.Boolean(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member -def handle(dtype: Optional[str] = None, - storage_scope: str = "global", - *, - is_size_var: bool = False) -> Var: +def handle(dtype: Optional[str] = None, storage_scope: str = "global", *, is_size_var: bool = False) -> Var: """Create a TIR var that represents a pointer. Parameters @@ -1667,7 +1662,7 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer: res = combiner(*args) if not isinstance(res, tuple): res = (res,) - return CommReducer(args[:num_args // 2], args[num_args // 2:], res, identity) + return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity) def index_map( @@ -1700,16 +1695,15 @@ def target( The target. """ if not isinstance(target_config, (str, dict)): - raise ValueError( - f"T.target expected a config dict or string, but got {type(target_config)}") + raise ValueError(f"T.target expected a config dict or string, but got {type(target_config)}") if host is not None and not isinstance(host, (str, dict, Target)): - raise ValueError("T.target expected the host to be " - "a config dict, string, or T.target, " - f"but got {type(host)}") + raise ValueError(f"T.target expected the host to be a config dict, string, or T.target, but got {type(host)}") if isinstance(target_config, dict) and "host" in target_config and host is not None: - raise ValueError("T.target expects to either receive the host " - "as part of the target's config dictionary, " - "or as a separate argument, but not both.") + raise ValueError( + "T.target expects to either receive the host " + "as part of the target's config dictionary, " + "or as a separate argument, but not both." 
+ ) return Target(target_config, host) @@ -1742,7 +1736,6 @@ class meta_var: # pylint: disable=invalid-name self.value = value def __iter__(self): - def f(): for i in self.value: yield meta_var(i) @@ -1754,7 +1747,6 @@ class meta_var: # pylint: disable=invalid-name def _op_wrapper(func): - @functools.wraps(func) def wrapped(*args, **kwargs): if "dtype" in kwargs: @@ -1874,7 +1866,6 @@ vscale = _op_wrapper(_tir_op.vscale) def _dtype_forward(func): - @functools.wraps(func) def wrapped(*args, **kwargs): if "dtype" in kwargs: diff --git a/tilelang/language/atomic.py b/tilelang/language/atomic.py index 07e45bbc83c7c06036e81a717be0c647dd904519..89a3af25fc871e21b2ca5bbbd1b9719c6551dfe2 100644 --- a/tilelang/language/atomic.py +++ b/tilelang/language/atomic.py @@ -1,6 +1,7 @@ # Copyright (c) Tile-AI Corporation. # Licensed under the MIT License. """Atomic operations for tilelang.""" + from __future__ import annotations import tilelang.language as T @@ -18,10 +19,7 @@ _MEMORY_ORDER_ID_MAP = { } -def atomic_max(dst: Buffer, - value: PrimExpr, - memory_order: str | None = None, - return_prev: bool = False) -> PrimExpr: +def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False) -> PrimExpr: """ Perform an atomic maximum on the value stored at dst with an optional memory-order. @@ -64,10 +62,7 @@ def atomic_max(dst: Buffer, return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order]) -def atomic_min(dst: Buffer, - value: PrimExpr, - memory_order: str | None = None, - return_prev: bool = False) -> PrimExpr: +def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False) -> PrimExpr: """ Atomically update the value at dst to the minimum of its current value and value. @@ -112,11 +107,7 @@ def atomic_min(dst: Buffer, return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order]) -def atomic_add(dst: Buffer, - value: PrimExpr, - memory_order: str | None = None, - return_prev: bool = False, - use_tma: bool = False) -> PrimExpr: +def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False, use_tma: bool = False) -> PrimExpr: """ Atomically add `value` into `dst`, returning a handle to the operation. 
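# --- Editorial illustration (not part of the patch) ---------------------------
# Hedged sketch of how the atomics in this file are called from TileLang device
# code (inside a T.prim_func). Buffer and index names are illustrative only.
#
#     T.atomic_add(Out[i], val)                            # discard previous value
#     prev = T.atomic_add(Out[i], val, return_prev=True)   # fetch previous value
#     T.atomic_max(Out[i], val)
# -------------------------------------------------------------------------------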
@@ -191,8 +182,7 @@ def atomic_add(dst: Buffer, if memory_order is None: return T.call_extern(return_type, func_name, dst, value) else: - return T.call_extern(return_type, func_name, dst, value, - _MEMORY_ORDER_ID_MAP[memory_order]) + return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order]) if isinstance(dst, Buffer) and isinstance(value, Buffer): ir.assert_structural_equal(dst.shape, value.shape) @@ -208,14 +198,12 @@ def atomic_add(dst: Buffer, # Note: tile-region-based atomic operations don't support return_prev yet # This would need to be implemented in the tile runtime if return_prev: - raise NotImplementedError( - "return_prev is not supported for tile-region-based atomic operations") + raise NotImplementedError("return_prev is not supported for tile-region-based atomic operations") if memory_order is None: return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, 0) else: - return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, - _MEMORY_ORDER_ID_MAP[memory_order]) + return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, _MEMORY_ORDER_ID_MAP[memory_order]) def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr: diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 0bc12fcd8da605a46528d457c293f913f25fd0ab..60739e6110aebe3f15bcdb86ac0f41a0404f22d4 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from tilelang import tvm as tvm @@ -179,38 +180,32 @@ def set_max_nreg(reg_count: int, is_inc: int): def inc_max_nreg(reg_count: int): - """Increment the maximum number of registers to use. - """ + """Increment the maximum number of registers to use.""" return set_max_nreg(reg_count, 1) def dec_max_nreg(reg_count: int): - """Decrement the maximum number of registers to use. - """ + """Decrement the maximum number of registers to use.""" return set_max_nreg(reg_count, 0) def annotate_producer_reg_dealloc(reg_count: int = 24): - """Annotate the producer reg dealloc. - """ + """Annotate the producer reg dealloc.""" return dec_max_nreg(reg_count) def annotate_consumer_reg_alloc(reg_count: int = 240): - """Annotate the consumer reg alloc. - """ + """Annotate the consumer reg alloc.""" return inc_max_nreg(reg_count) def no_set_max_nreg(): - """Disable the maximum register limit setting. - """ + """Disable the maximum register limit setting.""" return tir.call_intrin("handle", tir.op.Op.get("tl.no_set_max_nreg")) def disable_warp_group_reg_alloc(): - """Disable the warp group reg alloc. - """ + """Disable the warp group reg alloc.""" return no_set_max_nreg() @@ -325,7 +320,9 @@ def warpgroup_wait(num_mma: int): return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma) -def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr: +def get_lane_idx( + warp_size: int | PrimExpr | None = None, +) -> PrimExpr: """Return the logical lane index of the calling thread within a warp. 
Parameters @@ -350,7 +347,9 @@ def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr: return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"), warp_size_expr) -def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr: +def get_warp_idx_sync( + warp_size: int | PrimExpr | None = None, +) -> PrimExpr: """Return the canonical warp index, assuming the warp's threads are converged. Parameters @@ -374,7 +373,9 @@ def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr: return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"), warp_size_expr) -def get_warp_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr: +def get_warp_idx( + warp_size: int | PrimExpr | None = None, +) -> PrimExpr: """Return the canonical warp index without synchronizing the warp. Parameters @@ -429,8 +430,7 @@ def get_warp_group_idx( args.append(warp_size_expr) if warps_per_group_expr is not None: if warp_size_expr is None: - raise ValueError("get_warp_group_idx expects `warp_size` when specifying " - "`warps_per_group`.") + raise ValueError("get_warp_group_idx expects `warp_size` when specifying `warps_per_group`.") args.append(warps_per_group_expr) return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_group_idx"), *args) @@ -459,10 +459,9 @@ def shuffle_elect(thread_extent: int) -> PrimExpr: return tir.call_intrin("bool", tir.op.Op.get("tl.tl_shuffle_elect"), thread_extent) -def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, - offset: int | PrimExpr = 0, - num_regs: int | PrimExpr | None = None, - dtype: str | None = None): +def warpgroup_fence_operand( + buffer_or_ptr: tir.Buffer | PrimExpr, offset: int | PrimExpr = 0, num_regs: int | PrimExpr | None = None, dtype: str | None = None +): """Insert a warpgroup fence for the destination accumulator registers. This prevents NVCC from sinking uses of accumulator fragments past the corresponding @@ -517,7 +516,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, data_ptr, convert(offset), convert(num_regs), - )) + ) + ) if isinstance(buffer_or_ptr, tir.Buffer): data_ptr = buffer_or_ptr.data @@ -531,8 +531,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, if isinstance(dim, tir.IntImm): total_elems *= int(dim) else: - raise ValueError( - "warpgroup_fence_operand requires num_regs when buffer shape is symbolic.") + raise ValueError("warpgroup_fence_operand requires num_regs when buffer shape is symbolic.") bits_per_elem = DataType(dtype).bits num_regs = (total_elems * bits_per_elem + 31) // 32 elif isinstance(buffer_or_ptr, BufferRegion): @@ -569,9 +568,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, bits_per_elem = DataType(dtype).bits num_regs = (total_elems * bits_per_elem + 31) // 32 else: - raise ValueError( - "warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic." 
- ) + raise ValueError("warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic.") return evaluate( tir.call_intrin( "handle", @@ -580,7 +577,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, data_ptr, convert(offset), convert(num_regs), - )) + ) + ) else: data_ptr = buffer_or_ptr # Try to infer dtype from common pointer expressions when not provided @@ -603,9 +601,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, except Exception: inferred = None if inferred is None: - raise ValueError( - "dtype must be provided when passing a pointer expression and cannot be inferred." - ) + raise ValueError("dtype must be provided when passing a pointer expression and cannot be inferred.") dtype = inferred if num_regs is None: raise ValueError("num_regs must be provided when passing a pointer expression.") @@ -618,7 +614,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, data_ptr, convert(offset), convert(num_regs), - )) + ) + ) def wait_wgmma(id: int): @@ -673,7 +670,7 @@ def shfl_xor(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call if _IS_HIP_AVAILABLE: return tir.call_extern(value.dtype, "__shfl_xor", value, offset) else: - return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset) + return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xFFFFFFFF, value, offset) def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call): @@ -686,7 +683,7 @@ def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Cal if _IS_HIP_AVAILABLE: return tir.call_extern(value.dtype, "__shfl_down", value, offset) else: - return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset) + return tir.call_extern(value.dtype, "__shfl_down_sync", 0xFFFFFFFF, value, offset) def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call): @@ -699,12 +696,11 @@ def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call) if _IS_HIP_AVAILABLE: return tir.call_extern(value.dtype, "__shfl_up", value, offset) else: - return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset) + return tir.call_extern(value.dtype, "__shfl_up_sync", 0xFFFFFFFF, value, offset) def sync_threads(barrier_id: int = None, arrive_count: int = None): - """Synchronize all threads in a block. - """ + """Synchronize all threads in a block.""" args = [] if barrier_id is not None: args.append(barrier_id) @@ -714,8 +710,7 @@ def sync_threads(barrier_id: int = None, arrive_count: int = None): def sync_global(): - """Synchronize all threads in the entire grid. - """ + """Synchronize all threads in the entire grid.""" tx, ty, tz = get_thread_bindings() ex, ey, ez = get_block_extents() print(tx, ty, tz, ex, ey, ez) @@ -724,8 +719,7 @@ def sync_global(): def sync_grid(): - """Synchronize all threads in a grid. 
- """ + """Synchronize all threads in a grid.""" return tir.call_intrin("handle", tir.op.Op.get("tl.sync_grid")) @@ -741,12 +735,10 @@ def initialize_wgmma_descriptor( if not isinstance(descriptor, (BufferLoad, tir.Buffer)): raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") - if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or - descriptor.shape[0] != 1): + if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1): raise ValueError("Descriptor must be a 1D buffer of size 1.") - descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( - descriptor, [0]) + descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(descriptor, [0]) return evaluate( tir.call_intrin( @@ -757,7 +749,8 @@ def initialize_wgmma_descriptor( layout_type_, int(leading_byte_offset), int(stride_byte_offset), - )) + ) + ) def initialize_tcgen05_descriptor( @@ -774,12 +767,10 @@ def initialize_tcgen05_descriptor( if not isinstance(descriptor, (BufferLoad, tir.Buffer)): raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") - if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or - descriptor.shape[0] != 1): + if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1): raise ValueError("Descriptor must be a 1D buffer of size 1.") - descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( - descriptor, [0]) + descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(descriptor, [0]) return evaluate( tir.call_intrin( @@ -792,7 +783,8 @@ def initialize_tcgen05_descriptor( int(base_offset), tir.IntImm("int32", 1 if leading_is_absolute else 0), int(swizzle_mode), - )) + ) + ) def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimExpr: @@ -809,27 +801,21 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx if not isinstance(descriptor, (BufferLoad, tir.Buffer)): raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") - if isinstance(descriptor, tir.Buffer) and len( - descriptor.shape) != 1 or descriptor.shape[0] != 1: + if isinstance(descriptor, tir.Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1: raise ValueError("Descriptor must be a 1D buffer of size 1.") - descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( - descriptor, [0]) + descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(descriptor, [0]) - return evaluate( - tir.call_intrin("handle", tir.op.Op.get("tl.increase_descriptor_offset"), descriptor, - offset)) + return evaluate(tir.call_intrin("handle", tir.op.Op.get("tl.increase_descriptor_offset"), descriptor, offset)) def loop_break(): - """Break out of the innermost loop. - """ + """Break out of the innermost loop.""" return tir.call_intrin("handle", tir.op.Op.get("tl.loop_break")) def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call): - """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc. 
- """ + """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.""" return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id) diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index cabc4a3e4106c2bd32a616da06466a31d06ae3c1..b80a24e7c4eabc446b3083fbeca51d0c5167330f 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from typing import Literal from tilelang import language as T @@ -10,11 +11,13 @@ from tilelang.utils.language import ( from tvm import ir, tir -def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, - dst: tir.Buffer | tir.BufferLoad, - coalesced_width: int | None = None, - disable_tma: bool = False, - eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None): +def copy( + src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, + dst: tir.Buffer | tir.BufferLoad, + coalesced_width: int | None = None, + disable_tma: bool = False, + eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, +): """Copy data between memory regions. Args: @@ -65,8 +68,7 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, src_extent = get_extent(src) dst_extent = get_extent(dst) # Combine the nested if statements into a single if statement as suggested by SIM102 - if (src_extent is None and dst_extent is None and isinstance(src, tir.BufferLoad) and - isinstance(dst, tir.BufferLoad)): + if src_extent is None and dst_extent is None and isinstance(src, tir.BufferLoad) and isinstance(dst, tir.BufferLoad): # check if the case is like this: # copy(buffer_a[i], buffer_b[i]) where both are BufferLoad nodes # In this case, lower it to a simple BufferStore: buffer_b[i] = buffer_a[i] @@ -90,19 +92,20 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, eviction_policy = 0 else: eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] - return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, coalesced_width, - disable_tma, eviction_policy) - - -def c2d_im2col(img: tir.Buffer, - col: tir.Buffer, - nhw_step: tir.PrimExpr, - c_step: tir.PrimExpr, - kernel: int, - stride: int, - dilation: int, - pad: int, - eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None): + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, coalesced_width, disable_tma, eviction_policy) + + +def c2d_im2col( + img: tir.Buffer, + col: tir.Buffer, + nhw_step: tir.PrimExpr, + c_step: tir.PrimExpr, + kernel: int, + stride: int, + dilation: int, + pad: int, + eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, +): """Perform im2col transformation for 2D convolution. 
Args: @@ -124,5 +127,16 @@ def c2d_im2col(img: tir.Buffer, eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] img_region = to_buffer_region(img, access_type="r") col_region = to_buffer_region(col, access_type="w") - return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.c2d_im2col"), img_region, col_region, - nhw_step, c_step, kernel, stride, dilation, pad, eviction_policy) + return tir.call_intrin( + "handle", + tir.op.Op.get("tl.tileop.c2d_im2col"), + img_region, + col_region, + nhw_step, + c_step, + kernel, + stride, + dilation, + pad, + eviction_policy, + ) diff --git a/tilelang/language/customize.py b/tilelang/language/customize.py index 720c9e99134af49783ec49e9b8d907e840e5f120..e2f4b1c8a75c3839d2a8f8fd49c6bcbbc13c3f85 100644 --- a/tilelang/language/customize.py +++ b/tilelang/language/customize.py @@ -1,8 +1,9 @@ """The language interface for tl programs.""" + from __future__ import annotations import tilelang.language as T from tvm.tir import PrimExpr, Buffer, op -from tilelang.utils.language import (bits_product, prim_expr_equal) +from tilelang.utils.language import bits_product, prim_expr_equal from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401 @@ -46,9 +47,9 @@ def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer: Returns: Buffer: A new buffer view with the specified shape """ - assert prim_expr_equal( - bits_product(shape, src.dtype), bits_product(src.shape, src.dtype) - ), f"T.reshape/view shape check failed. src {src} src.shape: {src.shape}, src.dtype: {src.dtype}, target shape: {shape}, target dtype: {src.dtype}" + assert prim_expr_equal(bits_product(shape, src.dtype), bits_product(src.shape, src.dtype)), ( + f"T.reshape/view shape check failed. src {src} src.shape: {src.shape}, src.dtype: {src.dtype}, target shape: {shape}, target dtype: {src.dtype}" + ) return T.Tensor(shape, src.dtype, src.data) @@ -61,8 +62,7 @@ def view(src: Buffer, shape: list[PrimExpr] | None = None, dtype: str | None = N shape = src.shape if dtype is None: dtype = src.dtype - assert prim_expr_equal(bits_product(shape, dtype), - bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed." + assert prim_expr_equal(bits_product(shape, dtype), bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed." 
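# --- Editorial illustration (not part of the patch) ---------------------------
# Hedged sketch of the reshape/view helpers defined in this file, assuming `buf`
# is a float32 buffer of shape [128, 128]; names and shapes are illustrative.
#
#     flat = T.reshape(buf, [128 * 128])                     # same bit count, new shape
#     half = T.view(buf, shape=[128, 256], dtype="float16")  # reinterpret, bits preserved
# -------------------------------------------------------------------------------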
return T.Tensor(shape, dtype, src.data) diff --git a/tilelang/language/experimental/gemm_sp.py b/tilelang/language/experimental/gemm_sp.py index 4a20f3fb6d77fb218b91368b75d2fa83d8cf0850..5adac9265d1671b488007eff132cffee952d577b 100644 --- a/tilelang/language/experimental/gemm_sp.py +++ b/tilelang/language/experimental/gemm_sp.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from tilelang.primitives.gemm.base import GemmWarpPolicy import tilelang.language as T @@ -11,7 +12,8 @@ from tilelang.utils.language import ( prim_expr_equal, ) from tilelang.language.utils import ( - buffer_region_to_tile_region,) + buffer_region_to_tile_region, +) def gemm_sp( @@ -169,18 +171,19 @@ def gemm_sp_v2( assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor" if len(A_shape) > 2: for i in range(len(A_shape) - 2): - assert A_shape[i] == 1, \ + assert A_shape[i] == 1, ( "current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + ) if len(B_shape) > 2: for i in range(len(B_shape) - 2): - assert B_shape[i] == 1, \ + assert B_shape[i] == 1, ( "current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + ) M, N = C_shape K = 2 * (A_shape[-2] if transpose_A else A_shape[-1]) K_B = B_shape[-1] if transpose_B else B_shape[-2] - assert prim_expr_equal( - K, K_B), f"T.gemm_sp K shape check failed: K_A (wo sparse) = {K}, K_B = {K_B}" + assert prim_expr_equal(K, K_B), f"T.gemm_sp K shape check failed: K_A (wo sparse) = {K}, K_B = {K_B}" stride_a = A_stride[-2] stride_b = B_stride[-2] diff --git a/tilelang/language/fill.py b/tilelang/language/fill.py index b23733377a5a8cf893780ec490dec6b3a44ef3a5..af301c264edb2bb053a3e250404d14daf3062bc4 100644 --- a/tilelang/language/fill.py +++ b/tilelang/language/fill.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from tvm import tir from tilelang.language import has_let_value, get_let_value @@ -32,8 +33,7 @@ def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.Prim extents = [tir.IntImm("int32", 1) for _ in buffer.indices] else: extents = [] - return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.fill"), - to_buffer_region(buffer, access_type="w", extents=extents), value) + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.fill"), to_buffer_region(buffer, access_type="w", extents=extents), value) def clear(buffer: tir.Buffer | tir.Var): @@ -55,8 +55,7 @@ def clear(buffer: tir.Buffer | tir.Var): elif isinstance(buffer_region, tir.BufferLoad): region = get_buffer_region_from_load(buffer_region) if region is None: - raise ValueError( - f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}") + raise ValueError(f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}") return fill(region, 0) else: raise ValueError(f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}") diff --git a/tilelang/language/frame.py b/tilelang/language/frame.py index db649952ac870578e08ee6499fd726866572502a..7e60f46ee98da3e3283744b39a575068e7dfed09 100644 --- a/tilelang/language/frame.py +++ b/tilelang/language/frame.py @@ -1,4 +1,5 @@ """Override the LetFrame to print a message when entering the frame.""" + from __future__ import annotations from tvm.ffi import register_object as _register_object from tvm.tir import Var, PrimExpr, BufferLoad, BufferRegion @@ -29,7 +30,7 @@ class FrameStack: item: 
The frame object to push onto the stack """ self._stack.append(item) - if hasattr(item, 'var') and hasattr(item, 'value'): + if hasattr(item, "var") and hasattr(item, "value"): self._var_value_map[item.var] = item.value def pop(self): @@ -43,7 +44,7 @@ class FrameStack: """ if self._stack: item = self._stack.pop() - if hasattr(item, 'var'): + if hasattr(item, "var"): self._var_value_map.pop(item.var, None) return item raise IndexError(f"{self.__class__.__name__} is empty") @@ -129,8 +130,7 @@ class LetFrame(TIRFrame): is_block_load = True break if is_block_load: - self.value = BufferRegion(self.value.buffer, - [Range(x.base, x.lanes) for x in indices]) + self.value = BufferRegion(self.value.buffer, [Range(x.base, x.lanes) for x in indices]) _get_let_stack().push(self) return self.var diff --git a/tilelang/language/gemm.py b/tilelang/language/gemm.py index db8e04abaa2c6ba3323cee518a9567074c8bd147..56f6805f0094b69ca688cec3fc85dbf054577678 100644 --- a/tilelang/language/gemm.py +++ b/tilelang/language/gemm.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from tilelang.primitives.gemm.base import GemmWarpPolicy import tilelang.language as T @@ -11,7 +12,8 @@ from tilelang.utils.language import ( prim_expr_equal, ) from tilelang.language.utils import ( - buffer_region_to_tile_region,) + buffer_region_to_tile_region, +) from tilelang.env import env as _env @@ -68,12 +70,14 @@ def _gemm_impl( assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor" if len(A_shape) > 2: for i in range(len(A_shape) - 2): - assert A_shape[i] == 1, \ + assert A_shape[i] == 1, ( "current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + ) if len(B_shape) > 2: for i in range(len(B_shape) - 2): - assert B_shape[i] == 1, \ + assert B_shape[i] == 1, ( "current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + ) M, N = C_shape K = A_shape[-2] if transpose_A else A_shape[-1] @@ -96,9 +100,29 @@ def _gemm_impl( A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape]) B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape]) C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape]) - return tir.call_intrin("handle", tir.op.Op.get(op_key), A_arg, B_arg, C_arg, transpose_A, - transpose_B, M, N, K, policy, clear_accum, stride_a, stride_b, offset_a, - offset_b, k_pack, wg_wait, mbar, C_coords[0], C_coords[1]) + return tir.call_intrin( + "handle", + tir.op.Op.get(op_key), + A_arg, + B_arg, + C_arg, + transpose_A, + transpose_B, + M, + N, + K, + policy, + clear_accum, + stride_a, + stride_b, + offset_a, + offset_b, + k_pack, + wg_wait, + mbar, + C_coords[0], + C_coords[1], + ) # Public wrappers diff --git a/tilelang/language/kernel.py b/tilelang/language/kernel.py index 5e819da70eadc2580704110d352fd6a87d25e2c3..625531b38cb3d6b107f07031aeeae32aa5992ffe 100644 --- a/tilelang/language/kernel.py +++ b/tilelang/language/kernel.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from collections import deque from tvm import tir @@ -107,8 +108,7 @@ class KernelLaunchFrame(TIRFrame): _get_current_stack().push(self) last_block_frame = self.frames[-1] - assert isinstance(last_block_frame, - BlockFrame), f"Last frame must be a block frame, got {last_block_frame}" + assert isinstance(last_block_frame, BlockFrame), f"Last frame must be a block frame, got 
{last_block_frame}" maybe_cpu = last_block_frame.annotations.get("tilelang.is_cpu_kernel_frame", False) @@ -303,56 +303,48 @@ def Kernel( def get_thread_binding(dim: int = 0) -> Var: - """Returns the thread binding for the given dimension. - """ + """Returns the thread binding for the given dimension.""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_thread_binding(dim) def get_thread_bindings() -> list[Var]: - """Returns all three thread bindings. - """ + """Returns all three thread bindings.""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_thread_bindings() def get_block_binding(dim: int = 0) -> Var: - """Returns the block binding for the given dimension. - """ + """Returns the block binding for the given dimension.""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_block_binding(dim) def get_block_bindings() -> list[Var]: - """Returns all three block bindings. - """ + """Returns all three block bindings.""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_block_bindings() def get_thread_extent(dim: int = 0) -> int: - """Returns the thread extent for the given dimension. - """ + """Returns the thread extent for the given dimension.""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_thread_extent(dim) def get_thread_extents() -> list[int]: - """Returns all three thread extents. - """ + """Returns all three thread extents.""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_thread_extents() def get_block_extent(dim: int = 0) -> int: - """Returns the block extent for the given dimension. - """ + """Returns the block extent for the given dimension.""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_block_extent(dim) def get_block_extents() -> list[int]: - """Returns all three block extents. 
- """ + """Returns all three block extents.""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" return KernelLaunchFrame.Current().get_block_extents() diff --git a/tilelang/language/logical.py b/tilelang/language/logical.py index a09088e68ce2d125ee226f55f52c8f7b8089df53..fb4b88a6ee2e10614ddb48359db55a92dd10233f 100644 --- a/tilelang/language/logical.py +++ b/tilelang/language/logical.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from tilelang import language as T @@ -36,8 +37,7 @@ def any_of(buffer: T.Tensor | BufferRegion): ) new_region.append(r.min) buffer_load = BufferLoad(buffer, new_region) - return T.call_intrin(return_type, tir.op.Op.get("tl.any_of"), T.address_of(buffer_load), - extent) + return T.call_intrin(return_type, tir.op.Op.get("tl.any_of"), T.address_of(buffer_load), extent) else: raise ValueError(f"Invalid buffer type: {type(buffer)}") @@ -71,7 +71,6 @@ def all_of(buffer: T.Tensor | BufferRegion): ) new_region.append(r.min) buffer_load = BufferLoad(buffer, new_region) - return T.call_intrin(return_type, tir.op.Op.get("tl.all_of"), T.address_of(buffer_load), - extent) + return T.call_intrin(return_type, tir.op.Op.get("tl.all_of"), T.address_of(buffer_load), extent) else: raise ValueError(f"Invalid buffer type: {type(buffer)}") diff --git a/tilelang/language/loop.py b/tilelang/language/loop.py index 3478b6cc17ffa0e4e83d39fb48042bc95969fe92..f28f097cb90c4ba509c2e9aa808dc48846983df2 100644 --- a/tilelang/language/loop.py +++ b/tilelang/language/loop.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from typing import Any from tvm import tir @@ -94,11 +95,9 @@ def Pipelined( return _ffi_api.Pipelined(start, stop, num_stages, order, stage, sync, group) -def serial(start: tir.PrimExpr, - stop: tir.PrimExpr | None = None, - step: tir.PrimExpr | None = None, - *, - annotations: dict[str, Any] | None = None) -> frame.ForFrame: +def serial( + start: tir.PrimExpr, stop: tir.PrimExpr | None = None, step: tir.PrimExpr | None = None, *, annotations: dict[str, Any] | None = None +) -> frame.ForFrame: step_is_one = False step_is_one |= isinstance(step, int) and step == 1 step_is_one |= isinstance(step, IntImm) and step.value == 1 @@ -111,13 +110,15 @@ def serial(start: tir.PrimExpr, return SerialForWithStep(start, stop, step, annotations=annotations) -def unroll(start: tir.PrimExpr, - stop: tir.PrimExpr | None = None, - step: tir.PrimExpr | None = None, - *, - explicit: bool = False, - unroll_factor: int | None = None, - annotations: dict[str, Any] | None = None) -> frame.ForFrame: +def unroll( + start: tir.PrimExpr, + stop: tir.PrimExpr | None = None, + step: tir.PrimExpr | None = None, + *, + explicit: bool = False, + unroll_factor: int | None = None, + annotations: dict[str, Any] | None = None, +) -> frame.ForFrame: """The unrolled For statement. 
Parameters diff --git a/tilelang/language/math_intrinsics.py b/tilelang/language/math_intrinsics.py index 39cab27adc5967322ddb6270caf2e95bfae7cc0f..7a6104c74be5cff9ecd70becb3a68f4655055f0e 100644 --- a/tilelang/language/math_intrinsics.py +++ b/tilelang/language/math_intrinsics.py @@ -3,7 +3,7 @@ from tvm import tir def _validate_rounding_mode(rounding_mode): """Validate that the rounding mode is one of the supported IEEE modes""" - valid_modes = {'rn', 'rz', 'ru', 'rd'} + valid_modes = {"rn", "rz", "ru", "rd"} if isinstance(rounding_mode, str) and rounding_mode in valid_modes: return raise ValueError(f"Invalid rounding mode '{rounding_mode}'. Must be one of: {valid_modes}") diff --git a/tilelang/language/overrides/parser.py b/tilelang/language/overrides/parser.py index af42098a2b03ca8ac24e4d1ac1af64b15ce47f4c..0b2fcc44f965c08207283e35966607ef69405c25 100644 --- a/tilelang/language/overrides/parser.py +++ b/tilelang/language/overrides/parser.py @@ -1,4 +1,5 @@ """TVMScript parser overrides tailored for TileLang.""" + from functools import partial from tvm.script.ir_builder import tir as T @@ -58,8 +59,12 @@ def tilelang_visit_assign(self, node: doc.Assign) -> None: # pylint: disable=un lhs.ctx = load_ctx lhs_value = self.eval_expr(lhs) lhs.ctx = store_ctx - if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and - len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0): + if ( + isinstance(lhs_value, BufferLoad) + and lhs_value.buffer.scope() == "local.var" + and len(lhs_value.indices) == 1 + and lhs_value.indices[0] == 0 + ): T.buffer_store(lhs_value.buffer, rhs, indices=[0]) continue @@ -106,8 +111,12 @@ def tilelang_visit_aug_assign(self, node: doc.AugAssign) -> None: # pylint: dis lhs.ctx = load_ctx lhs_value = self.eval_expr(lhs) lhs.ctx = store_ctx - if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and - len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0): + if ( + isinstance(lhs_value, BufferLoad) + and lhs_value.buffer.scope() == "local.var" + and len(lhs_value.indices) == 1 + and lhs_value.indices[0] == 0 + ): T.buffer_store(lhs_value.buffer, rhs, indices=[0]) return @@ -131,8 +140,12 @@ def tilelang_visit_ann_assign(self, node: doc.AnnAssign) -> None: # pylint: dis lhs.ctx = load_ctx lhs_value = self.eval_expr(lhs) lhs.ctx = store_ctx - if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and - len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0): + if ( + isinstance(lhs_value, BufferLoad) + and lhs_value.buffer.scope() == "local.var" + and len(lhs_value.indices) == 1 + and lhs_value.indices[0] == 0 + ): T.buffer_store(lhs_value.buffer, rhs, indices=[0]) return diff --git a/tilelang/language/parser/entry.py b/tilelang/language/parser/entry.py index aa98cf569932d811bf4a5434a8cfe28983b7f73e..5f2aaab7b01da6f863157cbf68f2523ce73231b5 100644 --- a/tilelang/language/parser/entry.py +++ b/tilelang/language/parser/entry.py @@ -18,6 +18,7 @@ # which is part of the TVM project (https://tvm.apache.org/). 
# ruff: noqa """The entry point of TVM parser for tir.""" + import inspect from typing import Callable, Optional, Union @@ -29,9 +30,7 @@ from tvm.script.parser._core import parse, scan_macro, utils from tvm.script.parser.core.parser import Parser, ScriptMacro -def prim_func(func: Optional[Callable] = None, - private: bool = False, - check_well_formed=True) -> Union[PrimFunc, Callable]: +def prim_func(func: Optional[Callable] = None, private: bool = False, check_well_formed=True) -> Union[PrimFunc, Callable]: """The parsing method for tir prim func, by using `@prim_func` as decorator. Parameters @@ -149,8 +148,7 @@ def macro(*args, hygienic: bool = True) -> Callable: if len(args) == 1 and inspect.isfunction(args[0]): return _decorator(args[0]) - raise ValueError( - "Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])") + raise ValueError("Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])") class BufferProxy: diff --git a/tilelang/language/parser/operation.py b/tilelang/language/parser/operation.py index b2138acf3acd70df345f5e5f27b82728eeb41db6..473da43275a6252588288f3bafda1572d314b978 100644 --- a/tilelang/language/parser/operation.py +++ b/tilelang/language/parser/operation.py @@ -17,6 +17,7 @@ # This file is modified from the original version, # which is part of the TVM project (https://tvm.apache.org/). """The tir expression operation registration""" + from tvm import tir from tvm.ffi.runtime_ctypes import DataType, DataTypeCode from tvm.tir import IntImm @@ -55,11 +56,9 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name return dtype[0:index] def _auto_broadcast(a, b, op): - if isinstance(a, int): if hasattr(b, "dtype"): - if (DataType(b.dtype).type_code == DataTypeCode.INT or - DataType(b.dtype).type_code == DataTypeCode.UINT): + if DataType(b.dtype).type_code == DataTypeCode.INT or DataType(b.dtype).type_code == DataTypeCode.UINT: a = IntImm(_get_type_str(b.dtype), a) elif DataType(b.dtype).type_code == DataTypeCode.FLOAT: a = FloatImm(_get_type_str(b.dtype), a) @@ -75,8 +74,7 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name assert isinstance(a, tir.PrimExpr), "Operand should be a PrimExpr." 
if isinstance(b, int): - if (DataType(a.dtype).type_code == DataTypeCode.INT or - DataType(a.dtype).type_code == DataTypeCode.UINT): + if DataType(a.dtype).type_code == DataTypeCode.INT or DataType(a.dtype).type_code == DataTypeCode.UINT: b = IntImm(_get_type_str(a.dtype), b) elif DataType(a.dtype).type_code == DataTypeCode.FLOAT: b = FloatImm(_get_type_str(a.dtype), b) @@ -85,10 +83,10 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name if DataType(a.dtype).lanes == DataType(b.dtype).lanes: return op(a, b) - elif (DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes): + elif DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes: broadcast_a = tir.Broadcast(a, DataType(b.dtype).lanes) return op(broadcast_a, b) - elif (DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes): + elif DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes: broadcast_b = tir.Broadcast(b, DataType(a.dtype).lanes) return op(a, broadcast_b) else: diff --git a/tilelang/language/parser/parser.py b/tilelang/language/parser/parser.py index 3aa720d4e6efa785ce6aad45759e4bad4de980e2..4cac0ad74f56de23c3edd2145fae444bef281423 100644 --- a/tilelang/language/parser/parser.py +++ b/tilelang/language/parser/parser.py @@ -146,8 +146,7 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - res = value.__enter__() IRBuilder.name(var_name, res) return res - elif isinstance(value, (Buffer, IterVar)) or (isinstance(value, Var) and - not self.var_table.exist(value)): + elif isinstance(value, (Buffer, IterVar)) or (isinstance(value, Var) and not self.var_table.exist(value)): IRBuilder.name(var_name, value) return value else: @@ -191,8 +190,7 @@ def visit_for(self: Parser, node: doc.For) -> None: if not isinstance(for_frame, T.frame.ForFrame): self.report_error( node.iter, - "Expect the for loop to be one of the following: " - "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding", + "Expect the for loop to be one of the following: range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding", ) with self.var_table.with_frame(): with for_frame as iters: @@ -361,8 +359,7 @@ def visit_with(self: Parser, node: doc.With) -> None: for item in node.items: frame = self.eval_expr(item.context_expr) if not isinstance(frame, Frame): - self.report_error(item.context_expr, - "Invalid context expression in the with-statement.") + self.report_error(item.context_expr, "Invalid context expression in the with-statement.") rhs = stack.enter_context(frame) if item.optional_vars is not None: self.eval_assign(target=item.optional_vars, source=rhs, bind_value=bind_with_value) @@ -505,8 +502,7 @@ def visit_if(self: Parser, node: doc.If) -> None: with self.var_table.with_frame(): self.visit_body(node.orelse) else: - self.report_error(node.test, - f"If condition must be a boolean expression, but got {predicate}") + self.report_error(node.test, f"If condition must be a boolean expression, but got {predicate}") @dispatch.register(token="tir", type_name="Assert") diff --git a/tilelang/language/print.py b/tilelang/language/print.py index 08e18f426019ebd691fd3bc18b2aed358880ebbb..bbaa119ed55d7adbe0637ffff56d617e4d616454 100644 --- a/tilelang/language/print.py +++ b/tilelang/language/print.py @@ -26,9 +26,7 @@ def print_var(var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr: @macro -def print_var_with_condition(condition: tir.PrimExpr, - var: tir.PrimExpr, - msg: 
str = "") -> tir.PrimExpr: +def print_var_with_condition(condition: tir.PrimExpr, var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr: """ Conditionally prints a TIR primitive expression (PrimExpr) if a given condition is True. @@ -44,10 +42,7 @@ def print_var_with_condition(condition: tir.PrimExpr, @macro -def print_global_buffer_with_condition(condition: tir.PrimExpr, - buffer: tir.Buffer, - elems: int, - msg: str = "") -> tir.PrimExpr: +def print_global_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr: """ Conditionally prints the values of a flattened TIR buffer if the condition is True. """ @@ -55,17 +50,13 @@ def print_global_buffer_with_condition(condition: tir.PrimExpr, # Iterate through the buffer elements and print each one. for i in serial(elems): coords = index_to_coordinates(i, buffer.shape) - tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, - buffer[coords]) + tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) else: tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) @macro -def print_shared_buffer_with_condition(condition: tir.PrimExpr, - buffer: tir.Buffer, - elems: int, - msg: str = "") -> tir.PrimExpr: +def print_shared_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr: """ Conditionally prints the values of a flattened TIR buffer if the condition is True. @@ -81,15 +72,11 @@ def print_shared_buffer_with_condition(condition: tir.PrimExpr, # Iterate through the buffer elements and print each one. for i in serial(elems): coords = index_to_coordinates(i, buffer.shape) - tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, - buffer[coords]) + tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) @macro -def print_fragment_buffer_with_condition(condition: tir.PrimExpr, - buffer: tir.Buffer, - elems: int, - msg: str = "") -> tir.PrimExpr: +def print_fragment_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr: """ Conditionally prints the values of a flattened TIR buffer if the condition is True. @@ -111,10 +98,7 @@ def print_fragment_buffer_with_condition(condition: tir.PrimExpr, @macro -def print_local_buffer_with_condition(condition: tir.PrimExpr, - buffer: tir.Buffer, - elems: int, - msg: str = "") -> tir.PrimExpr: +def print_local_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr: """ Conditionally prints the values of a flattened TIR buffer if the condition is True. @@ -130,8 +114,7 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr, # Iterate through the buffer elements and print each one. for i in serial(elems): coords = index_to_coordinates(i, buffer.shape) - tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, - buffer[coords]) + tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) from tilelang.utils.target import check_cuda_availability @@ -201,7 +184,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> elems *= dim # Ensure only the first thread (tx=0, ty=0, tz=0) executes the print. 
- condition = (tx == main_lane and ty == 0 and tz == 0) + condition = tx == main_lane and ty == 0 and tz == 0 if not msg: msg = f"buffer<{buffer.name}, {buffer.dtype}>" return print_fragment_buffer_with_condition(condition, buffer, elems, msg) @@ -212,7 +195,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> elems *= dim # Ensure only the first thread (tx=0, ty=0, tz=0) executes the print. - condition = (tx == main_lane and ty == 0 and tz == 0) + condition = tx == main_lane and ty == 0 and tz == 0 if not msg: msg = f"buffer<{buffer.name}, {buffer.dtype}>" return print_shared_buffer_with_condition(condition, buffer, elems, msg) @@ -234,5 +217,4 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> else: # Unsupported object type. - raise ValueError( - f"Unexpected type: {type(obj)}. Supported types are tir.Buffer and tir.PrimExpr.") + raise ValueError(f"Unexpected type: {type(obj)}. Supported types are tir.Buffer and tir.PrimExpr.") diff --git a/tilelang/language/proxy.py b/tilelang/language/proxy.py index 9e209a1b27d5973471b934233041236c209e42fd..7807a46697b58b785d7e314bbf2397fbe9b24c90 100644 --- a/tilelang/language/proxy.py +++ b/tilelang/language/proxy.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from typing import Any, SupportsIndex, TYPE_CHECKING, Generic, TypeVar @@ -51,11 +52,9 @@ class BufferProxy: return self(keys) return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member - def from_ptr(self, - pointer_var: Var, - shape: tuple[PrimExpr, ...], - dtype: str = "float32", - strides: tuple[PrimExpr, ...] = None) -> Buffer: + def from_ptr( + self, pointer_var: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None + ) -> Buffer: """Create a buffer from a pointer, shape, and data type. Args: @@ -76,6 +75,7 @@ class BaseTensorProxy: customizable default values for scope, alignment, and offset factors. It implements the core functionality for creating TIR buffers with specific memory configurations. """ + default_scope = "global" default_align = 0 default_offset_factor = 0 @@ -118,11 +118,9 @@ class BaseTensorProxy: keys = (keys,) return self(*keys) - def from_ptr(self, - pointer_var: Var, - shape: tuple[PrimExpr, ...], - dtype: str = "float32", - strides: tuple[PrimExpr, ...] = None) -> tir.Buffer: + def from_ptr( + self, pointer_var: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None + ) -> tir.Buffer: """Create a buffer from a pointer, shape, and data type. Args: @@ -151,19 +149,10 @@ class TensorProxy(BaseTensorProxy): strides.append(s) return tuple(reversed(strides)) - def __call__(self, - shape: tuple[Any] | PrimExpr | int, - dtype: str = "float32", - data=None, - scope=None) -> tir.Buffer: + def __call__(self, shape: tuple[Any] | PrimExpr | int, dtype: str = "float32", data=None, scope=None) -> tir.Buffer: if isinstance(shape, (int, PrimExpr)): shape = (shape,) - return super().__call__( - shape, - dtype=dtype, - strides=TensorProxy._construct_strides(shape), - data=data, - scope=scope) + return super().__call__(shape, dtype=dtype, strides=TensorProxy._construct_strides(shape), data=data, scope=scope) class StridedTensorProxy(BaseTensorProxy): @@ -172,11 +161,7 @@ class StridedTensorProxy(BaseTensorProxy): This class implements the default tensor proxy with global memory scope, with the stride information required. 
""" - def __call__(self, - shape: tuple[Any], - strides: tuple[Any], - dtype: str = "float32", - scope=None) -> tir.Buffer: + def __call__(self, shape: tuple[Any], strides: tuple[Any], dtype: str = "float32", scope=None) -> tir.Buffer: if len(shape) != len(strides): raise ValueError("Invalid shape/strides' dimensions") return super().__call__(shape, dtype=dtype, strides=strides, scope=scope) @@ -188,6 +173,7 @@ class FragmentBufferProxy(BaseTensorProxy): This class represents tensor proxies specifically for local fragment memory, typically used in GPU tensor core operations. """ + default_scope = "local.fragment" @@ -197,6 +183,7 @@ class SharedBufferProxy(BaseTensorProxy): This class represents tensor proxies for dynamic shared memory, commonly used in GPU shared memory operations. """ + default_scope = "shared.dyn" @@ -206,6 +193,7 @@ class LocalBufferProxy(BaseTensorProxy): This class represents tensor proxies for local memory scope, typically used for temporary computations in GPU kernels. """ + default_scope = "local" @@ -216,15 +204,12 @@ Buffer = BufferProxy() # pylint: disable=invalid-name if TYPE_CHECKING: class BaseTensor: - def __class_getitem__(cls, key): return cls - def __getitem__(self, key) -> Any: - ... + def __getitem__(self, key) -> Any: ... - def __setitem__(self, key, value) -> None: - ... + def __setitem__(self, key, value) -> None: ... def __init__( self, @@ -238,36 +223,26 @@ if TYPE_CHECKING: offset_factor=None, buffer_type="", axis_separators=None, - ): - ... + ): ... @classmethod - def from_ptr(cls, - pointer_var: Var, - shape: Sequence[PrimExpr, ...], - dtype: str = "float32", - strides: tuple[PrimExpr, ...] = None) -> Self: - ... + def from_ptr( + cls, pointer_var: Var, shape: Sequence[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None + ) -> Self: ... - class Tensor(BaseTensor): - ... + class Tensor(BaseTensor): ... - class StridedTensor(BaseTensor): - ... + class StridedTensor(BaseTensor): ... - class FragmentBuffer(BaseTensor): - ... + class FragmentBuffer(BaseTensor): ... - class SharedBuffer(BaseTensor): - ... + class SharedBuffer(BaseTensor): ... - class LocalBuffer(BaseTensor): - ... + class LocalBuffer(BaseTensor): ... - _T = TypeVar('_T') + _T = TypeVar("_T") - class Ref(Generic[_T], tir.Var): - ... + class Ref(Generic[_T], tir.Var): ... else: Tensor = TensorProxy() # pylint: disable=invalid-name StridedTensor = StridedTensorProxy() # pylint: disable=invalid-name @@ -275,14 +250,10 @@ else: SharedBuffer = SharedBufferProxy() # pylint: disable=invalid-name LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name - class Ref: - ... + class Ref: ... -def ptr(dtype: str | None = None, - storage_scope: str = "global", - *, - is_size_var: bool = False) -> Var: +def ptr(dtype: str | None = None, storage_scope: str = "global", *, is_size_var: bool = False) -> Var: """Create a TIR var that represents a pointer. Parameters @@ -304,8 +275,5 @@ def ptr(dtype: str | None = None, return handle(dtype=dtype, storage_scope=storage_scope, is_size_var=is_size_var) -def make_tensor(ptr: Var, - shape: tuple[PrimExpr, ...], - dtype: str = "float32", - strides: tuple[PrimExpr, ...] = None) -> tir.Buffer: +def make_tensor(ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] 
= None) -> tir.Buffer: return Tensor.from_ptr(ptr, shape, dtype, strides) diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py index fb84b6d784f96a688ae730b84d8ae8103bd1eef5..9bb3b179b4f708dfb87c22923f40a0bcbcebb49b 100644 --- a/tilelang/language/reduce.py +++ b/tilelang/language/reduce.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from __future__ import annotations from tvm import tir from tilelang.language import copy, macro, alloc_shared, alloc_fragment @@ -30,15 +31,13 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea tir.Call: Handle to the reduction operation """ # input shape: [X, d, Y], expected output shape: [X, Y] or [X, 1, Y] - expected_shapes = [ - buffer.shape[:dim] + buffer.shape[dim + 1:], - buffer.shape[:dim] + [1] + buffer.shape[dim + 1:] - ] + expected_shapes = [buffer.shape[:dim] + buffer.shape[dim + 1 :], buffer.shape[:dim] + [1] + buffer.shape[dim + 1 :]] if list(out.shape) not in expected_shapes: - expected_shapes_str = ' or '.join(map(str, expected_shapes)) + expected_shapes_str = " or ".join(map(str, expected_shapes)) raise ValueError( f"Invalid reduce output shape, buffer shape is {buffer.shape}, dim is {dim}, " - f"output shape is {out.shape}, expected shapes are {expected_shapes_str}") + f"output shape is {out.shape}, expected shapes are {expected_shapes_str}" + ) @macro def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool): diff --git a/tilelang/language/tir/entry.py b/tilelang/language/tir/entry.py index 22702ae432cc6e3c05963796c3647e6021b9183b..82ae7d70f59d569acb97a00695232a455cdccedd 100644 --- a/tilelang/language/tir/entry.py +++ b/tilelang/language/tir/entry.py @@ -7,9 +7,7 @@ from tvm.tir.function import PrimFunc from tvm.script.parser._core import parse, scan_macro, utils -def prim_func(func: Callable | None = None, - private: bool = False, - check_well_formed: bool = False) -> PrimFunc | Callable: +def prim_func(func: Callable | None = None, private: bool = False, check_well_formed: bool = False) -> PrimFunc | Callable: """The parsing method for tir prim func, by using `@prim_func` as decorator. Parameters @@ -113,8 +111,7 @@ def macro(*args, hygienic: bool = True) -> Callable: if len(args) == 1 and inspect.isfunction(args[0]): return _decorator(args[0]) - raise ValueError( - "Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])") + raise ValueError("Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])") setattr(macro, "dispatch_token", "tir") # noqa: B010 diff --git a/tilelang/language/tir/ir.py b/tilelang/language/tir/ir.py index 74cb32f7aa4f33bc1eb8d5751031af66dbad4b3f..a8367933fcce3a186a887876d85ea200c29f7e68 100644 --- a/tilelang/language/tir/ir.py +++ b/tilelang/language/tir/ir.py @@ -6,10 +6,7 @@ import tilelang.language.tir.op as _tir_op import functools -def serial(start: PrimExpr, - stop: PrimExpr = None, - *, - annotations: dict[str, Any] = None) -> frame.ForFrame: +def serial(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame: """The serial For statement. 
Parameters @@ -31,10 +28,7 @@ def serial(start: PrimExpr, return _ir.serial(start=start, stop=stop, annotations=annotations) -def parallel(start: PrimExpr, - stop: PrimExpr = None, - *, - annotations: dict[str, Any] = None) -> frame.ForFrame: +def parallel(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame: """The parallel For statement. Parameters @@ -56,10 +50,7 @@ def parallel(start: PrimExpr, return _ir.parallel(start=start, stop=stop, annotations=annotations) -def vectorized(start: PrimExpr, - stop: PrimExpr = None, - *, - annotations: dict[str, Any] = None) -> frame.ForFrame: +def vectorized(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame: """The vectorized For statement. Parameters @@ -81,10 +72,7 @@ def vectorized(start: PrimExpr, return _ir.vectorized(start=start, stop=stop, annotations=annotations) -def unroll(start: PrimExpr, - stop: PrimExpr = None, - *, - annotations: dict[str, Any] = None) -> frame.ForFrame: +def unroll(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame: """The unrolled For statement. Parameters @@ -161,7 +149,6 @@ def grid(*extents: PrimExpr) -> frame.ForFrame: def _dtype_forward(func): - @functools.wraps(func) def wrapped(*args, **kwargs): if "dtype" in kwargs: @@ -172,7 +159,6 @@ def _dtype_forward(func): def _op_wrapper(func): - @functools.wraps(func) def wrapped(*args, **kwargs): if "dtype" in kwargs: diff --git a/tilelang/language/tir/ir.pyi b/tilelang/language/tir/ir.pyi index fe25b58f85573f2cc2941470eba7acdcfaf3dd07..7723f13782bb3e40f5ee4ba3b1242ff6360eb0e6 100644 --- a/tilelang/language/tir/ir.pyi +++ b/tilelang/language/tir/ir.pyi @@ -1,22 +1,22 @@ from typing import TypeVar, Literal from tvm.tir.expr import Span, PrimExpr, BufferLoad, Var, IntImm -_T = TypeVar('_T') +_T = TypeVar("_T") -def abs(x: _T, span: Span | None=None) -> _T: ... +def abs(x: _T, span: Span | None = None) -> _T: ... def acos(x: _T) -> _T: ... def acosh(x: _T) -> _T: ... -def address_of(buffer_load: BufferLoad, span: Span | None=None) -> PrimExpr: ... +def address_of(buffer_load: BufferLoad, span: Span | None = None) -> PrimExpr: ... def asin(x: _T) -> _T: ... def asinh(x: _T) -> _T: ... def atan(x: _T) -> _T: ... def atan2(x1: _T, x2: _T) -> _T: ... def atanh(x: _T) -> _T: ... -def bitwise_and(x: _T, y: _T, span: Span | None=None) -> _T: ... -def bitwise_not(x: _T, span: Span | None=None) -> _T: ... -def bitwise_or(x: _T, y: _T, span: Span | None=None) -> _T: ... -def bitwise_xor(x: _T, y: _T, span: Span | None=None) -> _T: ... -def ceil(x: _T, span: Span | None=None) -> _T: ... +def bitwise_and(x: _T, y: _T, span: Span | None = None) -> _T: ... +def bitwise_not(x: _T, span: Span | None = None) -> _T: ... +def bitwise_or(x: _T, y: _T, span: Span | None = None) -> _T: ... +def bitwise_xor(x: _T, y: _T, span: Span | None = None) -> _T: ... +def ceil(x: _T, span: Span | None = None) -> _T: ... def clz(x: _T) -> _T: ... def copysign(x1: _T, x2: _T) -> _T: ... def cos(x: _T) -> _T: ... @@ -25,35 +25,37 @@ def erf(x: _T) -> _T: ... def exp(x: _T) -> _T: ... def exp2(x: _T) -> _T: ... def exp10(x: _T) -> _T: ... -def floor(x: _T, span: Span | None=None) -> _T: ... -def ceildiv(lhs: _T, rhs: _T, span: Span | None=None) -> _T: ... -def floordiv(a: _T, b: _T, span: Span | None=None) -> _T: ... -def floormod(a: _T, b: _T, span: Span | None=None) -> _T: ... +def floor(x: _T, span: Span | None = None) -> _T: ... 
+def ceildiv(lhs: _T, rhs: _T, span: Span | None = None) -> _T: ... +def floordiv(a: _T, b: _T, span: Span | None = None) -> _T: ... +def floormod(a: _T, b: _T, span: Span | None = None) -> _T: ... def fmod(x: _T, y: _T) -> _T: ... def hypot(x1: _T, x2: _T) -> _T: ... -def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None=None) -> _T: ... -def infinity(dtype: _T, span: Span | None=None) -> _T: ... -def isfinite(x: _T, span: Span | None=None) -> _T: ... -def isinf(x: _T, span: Span | None=None) -> _T: ... -def isnan(x: _T, span: Span | None=None) -> _T: ... -def isnullptr(x: _T, span: Span | None=None) -> _T: ... +def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None = None) -> _T: ... +def infinity(dtype: _T, span: Span | None = None) -> _T: ... +def isfinite(x: _T, span: Span | None = None) -> _T: ... +def isinf(x: _T, span: Span | None = None) -> _T: ... +def isnan(x: _T, span: Span | None = None) -> _T: ... +def isnullptr(x: _T, span: Span | None = None) -> _T: ... def ldexp(x1: _T, x2: _T) -> _T: ... -def likely(cond: _T, span: Span | None=None) -> _T: ... +def likely(cond: _T, span: Span | None = None) -> _T: ... def log(x: _T) -> _T: ... def log1p(x: _T) -> _T: ... def log2(x: _T) -> _T: ... def log10(x: _T) -> _T: ... -def lookup_param(param_name: str, span: Span | None=None) -> PrimExpr: ... -def max_value(dtype: str, span: Span | None=None) -> PrimExpr: ... -def min_value(dtype: str, span: Span | None=None) -> PrimExpr: ... -def nearbyint(x: _T, span: Span | None=None) -> _T: ... +def lookup_param(param_name: str, span: Span | None = None) -> PrimExpr: ... +def max_value(dtype: str, span: Span | None = None) -> PrimExpr: ... +def min_value(dtype: str, span: Span | None = None) -> PrimExpr: ... +def nearbyint(x: _T, span: Span | None = None) -> _T: ... def nextafter(x1: _T, x2: _T) -> _T: ... def popcount(x: _T) -> _T: ... -def pow(x: _T, y: _T, span: Span | None=None) -> _T: ... +def pow(x: _T, y: _T, span: Span | None = None) -> _T: ... def q_multiply_shift(x: _T, y: _T, q: _T, s: _T) -> _T: ... -def q_multiply_shift_per_axis(x: _T, y: _T, ls: _T, rs: _T, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm) -> PrimExpr: ... +def q_multiply_shift_per_axis( + x: _T, y: _T, ls: _T, rs: _T, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm +) -> PrimExpr: ... def ret(val: _T) -> _T: ... -def round(x: _T, span: Span | None=None) -> _T: ... +def round(x: _T, span: Span | None = None) -> _T: ... def rsqrt(x: _T) -> _T: ... def shift_left(x: _T, y: _T, span=None) -> _T: ... def shift_right(x: _T, y: _T, span=None) -> _T: ... @@ -63,14 +65,16 @@ def sinh(x: _T) -> _T: ... def sqrt(x: _T) -> _T: ... def tan(x: _T) -> _T: ... def tanh(x: _T) -> _T: ... -def trunc(x: _T, span: Span | None=None) -> _T: ... -def truncdiv(a: _T, b: _T, span: Span | None=None) -> _T: ... -def truncmod(a: _T, b: _T, span: Span | None=None) -> _T: ... +def trunc(x: _T, span: Span | None = None) -> _T: ... +def truncdiv(a: _T, b: _T, span: Span | None = None) -> _T: ... +def truncmod(a: _T, b: _T, span: Span | None = None) -> _T: ... def tvm_access_ptr(ptype: PrimExpr, data, offset: int, extent: int, rw_mask: int) -> PrimExpr: ... def tvm_throw_last_error() -> _T: ... def tvm_stack_alloca(dtype_str: str, num: int) -> PrimExpr: ... def tvm_stack_make_shape(*args) -> _T: ... -def tvm_stack_make_array(data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset) -> PrimExpr: ... 
+def tvm_stack_make_array( + data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset +) -> PrimExpr: ... def tvm_check_return(expected: int, return_unexpected: int, nested_call: PrimExpr) -> PrimExpr: ... def call_packed(*args, span=None) -> _T: ... def call_cpacked(*args, span=None) -> _T: ... @@ -80,11 +84,47 @@ def tvm_tuple(*value) -> _T: ... def tvm_struct_set(arr, index: int, field: int, value: PrimExpr) -> PrimExpr: ... def tvm_thread_invariant(cond: _T) -> _T: ... def tvm_thread_allreduce(*freduce_args) -> _T: ... -def tvm_load_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ... -def tvm_mma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ... -def tvm_bmma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ... +def tvm_load_matrix_sync( + fragment: Var, + m: IntImm, + n: IntImm, + k: IntImm, + index: PrimExpr, + buffer_ptr: PrimExpr, + stride: PrimExpr, + layout: Literal["row_major", "column_major"], +) -> PrimExpr: ... +def tvm_mma_sync( + fragment_d: Var, + index_d: PrimExpr, + fragment_a: Var, + index_a: PrimExpr, + fragment_b: Var, + index_b: PrimExpr, + fragment_c: Var, + index_c: PrimExpr, +) -> PrimExpr: ... +def tvm_bmma_sync( + fragment_d: Var, + index_d: PrimExpr, + fragment_a: Var, + index_a: PrimExpr, + fragment_b: Var, + index_b: PrimExpr, + fragment_c: Var, + index_c: PrimExpr, +) -> PrimExpr: ... def tvm_fill_fragment(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, value: PrimExpr) -> PrimExpr: ... -def tvm_store_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ... +def tvm_store_matrix_sync( + fragment: Var, + m: IntImm, + n: IntImm, + k: IntImm, + index: PrimExpr, + buffer_ptr: PrimExpr, + stride: PrimExpr, + layout: Literal["row_major", "column_major"], +) -> PrimExpr: ... def ptx_wait_group(num: int) -> PrimExpr: ... def ptx_commit_group() -> _T: ... def ptx_cp_async_barrier(barrier_id: int) -> PrimExpr: ... @@ -93,7 +133,7 @@ def ptx_arrive_barrier(barrier_id: int) -> PrimExpr: ... def ptx_arrive_barrier_expect_tx(barrier_id: int, byte_count: int) -> PrimExpr: ... def ptx_wait_barrier(barrier_id: int) -> PrimExpr: ... def create_barriers(barrier_count: int) -> PrimExpr: ... -def assume(cond: _T=None) -> _T: ... +def assume(cond: _T = None) -> _T: ... def undef() -> _T: ... def TVMBackendAllocWorkspace(device_type: int, device_id: int, nbytes: int, dtype_code_hint: int, dtype_bits_hint: int) -> PrimExpr: ... def TVMBackendFreeWorkspace(device_type: int, device_id: int, ptr: Var) -> PrimExpr: ... 
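# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the modules reformatted above
# (copy.py, kernel.py, proxy.py, and the tir stubs in ir.pyi) are all pieces of
# the tilelang.language ("T") API. The snippet below shows, under assumed
# shapes, dtype, and a 128x128 tile size, how those pieces combine in a
# minimal tile-copy kernel; it is a sketch, not code from this repository.
import tilelang.language as T

M, N, BLK = 1024, 1024, 128


@T.prim_func
def tile_copy(A: T.Tensor((M, N), "float16"), B: T.Tensor((M, N), "float16")):
    # One thread block per BLK x BLK tile of A.
    with T.Kernel(T.ceildiv(M, BLK), T.ceildiv(N, BLK), threads=128) as (bx, by):
        A_shared = T.alloc_shared((BLK, BLK), "float16")
        # T.copy (tilelang/language/copy.py above) moves a tile between memory
        # regions: here global -> shared, then shared -> global.
        T.copy(A[bx * BLK : (bx + 1) * BLK, by * BLK : (by + 1) * BLK], A_shared)
        T.copy(A_shared, B[bx * BLK : (bx + 1) * BLK, by * BLK : (by + 1) * BLK])
# ---------------------------------------------------------------------------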
diff --git a/tilelang/language/tir/op.py b/tilelang/language/tir/op.py index a9ce6a536424b60cd0b741a9f7b68427fd470e40..6cf78418479a897c67c3211294624d35e9b26695 100644 --- a/tilelang/language/tir/op.py +++ b/tilelang/language/tir/op.py @@ -724,8 +724,7 @@ def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout): return _tvm_op.tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout) -def tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, - index_c): +def tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c): """TVM intrinsic for tensor core mma_sync operators Parameters @@ -759,12 +758,10 @@ def tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, call : PrimExpr The call expression. """ - return _tvm_op.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, - fragment_c, index_c) + return _tvm_op.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c) -def tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, - index_c): +def tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c): """TVM intrinsic for tensor core bmma_sync operators Parameters @@ -798,8 +795,7 @@ def tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, call : PrimExpr The call expression. """ - return _tvm_op.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, - fragment_c, index_c) + return _tvm_op.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c) def tvm_fill_fragment(fragment, m, n, k, index, value): @@ -1121,7 +1117,6 @@ def ptx_wgmma_rs( scale_in_a, scale_in_b, ): - return call_intrin( dtype, _tvm_op.Op.get("tl.ptx_wgmma_rs"), @@ -1345,8 +1340,7 @@ def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, sme call : PrimExpr The call expression. """ - return _tvm_op.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, - smem_offset) + return _tvm_op.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset) def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes): @@ -1381,8 +1375,7 @@ def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, by return _tvm_op.ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes) -def ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, - barrier_id): +def ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id): """TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk @@ -1414,8 +1407,7 @@ def ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offse call : PrimExpr The call expression. """ - return _tvm_op.ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, - bytes, barrier_id) + return _tvm_op.ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id) def ptx_commit_group(): @@ -2951,8 +2943,7 @@ def q_multiply_shift_per_axis( z : PrimExpr The result. 
""" - return _tvm_op.q_multiply_shift_per_axis(x, y, ls, rs, q, is_lshift_required, - is_rshift_required) + return _tvm_op.q_multiply_shift_per_axis(x, y, ls, rs, q, is_lshift_required, is_rshift_required) def shift_left(x, y, span=None): @@ -3302,8 +3293,7 @@ def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dt call : PrimExpr The call expression. """ - return _tvm_op.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, - dtype_bits_hint) + return _tvm_op.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint) def TVMBackendFreeWorkspace(device_type, device_id, ptr): diff --git a/tilelang/language/utils.py b/tilelang/language/utils.py index 136bc0bac42489ed3fb42c1caaed4e223fa49f08..7d6829419b5a21f0bbc3cde1759066fbd66f59fc 100644 --- a/tilelang/language/utils.py +++ b/tilelang/language/utils.py @@ -14,23 +14,18 @@ def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: list """Convert a BufferLoad to a tl.region call with explicit extents.""" indices = list(load.indices) if len(indices) > len(extents): - extents = [tir.IntImm("int32", 1) for _ in range(len(indices) - len(extents)) - ] + list(extents) + extents = [tir.IntImm("int32", 1) for _ in range(len(indices) - len(extents))] + list(extents) assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}" return region(load, access_type, *extents) -def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str, - extents: list[tir.PrimExpr]): +def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str, extents: list[tir.PrimExpr]): """Clamp extents and return a tl.region call.""" mins = [r.min for r in buffer_region.region] region_extents = [r.extent for r in buffer_region.region] - assert len(region_extents) >= len(extents), ( - f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}" - ) + assert len(region_extents) >= len(extents), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}" clamped_extents = [ - tir.min(region_extents[i], extents[i]) if i < len(extents) else region_extents[i] - for i in range(len(region_extents)) + tir.min(region_extents[i], extents[i]) if i < len(extents) else region_extents[i] for i in range(len(region_extents)) ] return region(tir.BufferLoad(buffer_region.buffer, mins), access_type, *clamped_extents) diff --git a/tilelang/language/v2/annot.py b/tilelang/language/v2/annot.py index b61d9d11ccea40b676ab8525565c104e47f4eb63..bac92142ce1c21cd17a2f59a1ffc26c5d91b3d8f 100644 --- a/tilelang/language/v2/annot.py +++ b/tilelang/language/v2/annot.py @@ -5,6 +5,7 @@ from tvm import tir from tvm.ir.expr import PrimExpr from tvm.script.ir_builder.tir import buffer from typing import Any, Callable, Literal, TypeVar, Generic, TYPE_CHECKING + # Python 3.9 compatibility for advanced typing features try: from typing import ParamSpec, TypeVarTuple, Unpack, Self # type: ignore[attr-defined] @@ -37,16 +38,16 @@ from tvm.script.ir_builder import IRBuilder import torch import inspect -_Shapes = TypeVarTuple('_Shapes') -_Shape = ParamSpec('_Shape') -_Stride = ParamSpec('_Stride') -_DType = TypeVar('_DType') +_Shapes = TypeVarTuple("_Shapes") +_Shape = ParamSpec("_Shape") +_Stride = ParamSpec("_Stride") +_DType = TypeVar("_DType") -Scope = Literal['global', 'shared.dyn', 'local', 'local.fragment'] +Scope = Literal["global", "shared.dyn", "local", "local.fragment"] class Annot(ABC): - ''' + """ 
Base class for tilelang kernel annotations Tilelang kernel annotations are used to specify how to interpret each argument of the jit kernel @@ -54,12 +55,12 @@ class Annot(ABC): 1. determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time) 2. parse the argument value into a hash key for jit caching 3. convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation - ''' + """ def is_kernel_arg(self) -> bool: - ''' + """ Determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time) - ''' + """ return False @abstractmethod @@ -68,29 +69,29 @@ class Annot(ABC): @abstractmethod def get_key_parser(self) -> Callable[[str, Any], tuple[Any, ...]]: - ''' + """ Return a parser function that converts the argument value into a hash key for jit caching - ''' + """ @abstractmethod def create_prim_func_arg(self, name: str, value: Any, vt: ArgVarTable) -> tir.Var | tir.Buffer: - ''' + """ Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation - ''' + """ def promote(self) -> TIRAnnot | None: - ''' + """ Try to promote the annotation into a FixedAnnot if possible Return None if not promotable - ''' + """ return None @dataclass class ArgVarTable: - ''' + """ ArgVarTable is used to manage the mapping from argument names to tir.Var objects - ''' + """ var_tab: dict[str, tir.Var] = field(default_factory=dict) tmp_name_idx: int = 0 @@ -103,50 +104,49 @@ class ArgVarTable: return self.var_tab[name] def create_tmp_name(self) -> str: - name = f'varg_{self.tmp_name_idx}' + name = f"varg_{self.tmp_name_idx}" self.tmp_name_idx += 1 return name @dataclass class Value(Annot): - kind: Literal['static', 'dynamic'] = 'dynamic' + kind: Literal["static", "dynamic"] = "dynamic" name: str | None = None dtype: dt.dtype | None = dt.int32 value: int | tir.Var | None = None creator: Callable[[], Any] | None = None def is_kernel_arg(self) -> bool: - return self.kind == 'dynamic' + return self.kind == "dynamic" @classmethod def from_value(cls, value: Any, prefer_name: str = None) -> Value: if isinstance(value, int): # handle A: T.Tensor[[1024, 1024], ...] - return Value(kind='static', name=prefer_name, dtype=dt.int32, value=value) + return Value(kind="static", name=prefer_name, dtype=dt.int32, value=value) elif isinstance(value, float): - return Value(kind='static', name=prefer_name, dtype=dt.float32, value=value) + return Value(kind="static", name=prefer_name, dtype=dt.float32, value=value) elif isinstance(value, dt.dtype): # handle A: T.float32 - return Value(kind='dynamic', name=prefer_name, dtype=value, value=None) + return Value(kind="dynamic", name=prefer_name, dtype=value, value=None) elif isinstance(value, Value): # handle A: T.dyn return value elif isinstance(value, TypeVar): - return Value(kind='static', name=value.__name__, value=None) + return Value(kind="static", name=value.__name__, value=None) elif isinstance(value, (tir.Var, PrimExpr)): # handle A: T.Tensor[[M, N, K], ...] 
# or primexpr annotation like A: T.Tensor[[M, N * 4 +1]] name = value.name if isinstance(value, tir.Var) else prefer_name - return Value(kind='dynamic', name=name, dtype=value.dtype, value=value) - elif value is Any or value is None or value is dt.dtype or isinstance( - value, (type,) + _GenericAliasTypes): + return Value(kind="dynamic", name=name, dtype=value.dtype, value=value) + elif value is Any or value is None or value is dt.dtype or isinstance(value, (type,) + _GenericAliasTypes): # A # no annotation # A: Any # A: _T # A: dt.dtype # A: tuple[...] - return Value(kind='static', name=prefer_name, value=None) + return Value(kind="static", name=prefer_name, value=None) else: raise TypeError(f"Unsupported Value annotation: {value!r}, type: {type(value)}") @@ -154,7 +154,7 @@ class Value(Annot): return Value(kind=self.kind, name=self.name or name, dtype=self.dtype, value=self.value) def get_key_parser(self): - if self.kind == 'static': + if self.kind == "static": if self.value is not None: expected_value = self.value @@ -172,7 +172,7 @@ class Value(Annot): return self.get_key_parser()(target) def create_prim_func_arg(self, name: str, value: Any, vt: ArgVarTable, create_arg: bool = True): - if self.kind == 'static': + if self.kind == "static": if self.value: assert self.value == value, f"static value mismatch for {name}: expected {self.value}, got {value}" return value @@ -187,18 +187,18 @@ class Value(Annot): return tb_tir.arg(name, arg) if create_arg else arg def __repr__(self): - if self.kind == 'static': + if self.kind == "static": if self.value is not None: return repr(self.value) else: - return (str(self.name) or '$unnamed') + '$' + return (str(self.name) or "$unnamed") + "$" else: if self.value is not None: return repr(self.value) elif self.creator is not None: return repr(self.creator()) else: - return (str(self.name) or '$unnamed') + '$dyn' + return (str(self.name) or "$unnamed") + "$dyn" def _canonicalize_dtype(val: Any) -> dt.dtype | None: @@ -226,7 +226,7 @@ def _shape_with_name(shape: Sequence[Value], base_name: str) -> list[Value]: return None res = [] for i, dim in enumerate(shape): - dim = dim.with_name(f'{base_name}_{i}') + dim = dim.with_name(f"{base_name}_{i}") res.append(dim) return res @@ -236,7 +236,7 @@ def _try_convert_static_shape(shape: Sequence[Value]): return None res = [] for s in shape: - if s.kind == 'static' and s.value is not None or s.kind == 'dynamic' and s.value is not None: + if s.kind == "static" and s.value is not None or s.kind == "dynamic" and s.value is not None: res.append(s.value) if len(res) == len(shape): return res @@ -253,7 +253,7 @@ class BufferAnnot(Annot): @property def scope(self): - return 'global' + return "global" def __call__( self, @@ -290,8 +290,8 @@ class BufferAnnot(Annot): return self.__class__(shape, strides=self.strides, dtype=dtype) def with_name(self, name: str): - shape = _shape_with_name(self.shape, base_name=f'{name}_shape') - strides = _shape_with_name(self.strides, base_name=f'{name}_stride') + shape = _shape_with_name(self.shape, base_name=f"{name}_shape") + strides = _shape_with_name(self.strides, base_name=f"{name}_stride") return self.__class__(shape, strides, self.dtype) def get_key_parser(self): @@ -299,14 +299,14 @@ class BufferAnnot(Annot): if self.shape is not None: raw_shapes = False shape_len = len(self.shape) - static_shape_idx = [i for i, dim in enumerate(self.shape) if dim.kind == 'static'] + static_shape_idx = [i for i, dim in enumerate(self.shape) if dim.kind == "static"] # static_fixed_shape_idx = [i for 
i, dim in enumerate(self.shape) if dim.kind == 'static' and dim.value is not None] # static_fixed_shape_values = [dim.value for dim in self.shape if dim.kind == 'static' and dim.value is not None] raw_strides = True if self.strides is not None: raw_strides = False strides_len = len(self.strides) - strides_shape_idx = [i for i, dim in enumerate(self.strides) if dim.kind == 'static'] + strides_shape_idx = [i for i, dim in enumerate(self.strides) if dim.kind == "static"] # static_fixed_strides_idx = [i for i, dim in enumerate(self.strides) if dim.kind == 'static' and dim.value is not None] # static_fixed_strides_values = [dim.value for dim in self.strides if dim.kind == 'static' and dim.value is not None] raw_dtype = True @@ -340,9 +340,7 @@ class BufferAnnot(Annot): if not raw_dtype: dtype = dt.dtype(dtype) if dtype != expected_dtype: - raise TypeError( - f"Tensor dtype mismatch for argument `{name}`, expected {expected_dtype}, got {dtype}" - ) + raise TypeError(f"Tensor dtype mismatch for argument `{name}`, expected {expected_dtype}, got {dtype}") return shape, strides, dtype return key_parser @@ -384,7 +382,6 @@ class BufferAnnot(Annot): class TensorAnnot(BufferAnnot): - @staticmethod def _construct_strides(shape: tuple[Any]): s, strides = 1, [1] @@ -419,7 +416,8 @@ class TensorAnnot(BufferAnnot): align=align, offset_factor=offset_factor, buffer_type=buffer_type, - axis_separators=axis_separators) + axis_separators=axis_separators, + ) def promote(self): shape = _try_convert_static_shape(self.shape) @@ -430,7 +428,6 @@ class TensorAnnot(BufferAnnot): class StridedTensorAnnot(BufferAnnot): - def __call__( self, shape, @@ -466,30 +463,27 @@ class StridedTensorAnnot(BufferAnnot): class FragmentBufferAnnot(BufferAnnot): - @property def scope(self): - return 'local.fragment' + return "local.fragment" class SharedBufferAnnot(BufferAnnot): - @property def scope(self): - return 'shared.dyn' + return "shared.dyn" class LocalBufferAnnot(BufferAnnot): - @property def scope(self): - return 'local' + return "local" class DynAnnot(Value): - ''' + """ Dynamic variable annotation represents a tvm tir.Var argument - ''' + """ def __call__(self, dtype: AnyDType = dt.float32, name: str | None = None) -> DynAnnot: return tir.Var(name, dtype) @@ -499,16 +493,16 @@ class DynAnnot(Value): params = (params,) dtype = None if len(params) == 1: - name, = params + (name,) = params if len(params) == 2: dtype, name = params dtype = _canonicalize_dtype(dtype) or dt.int32 - return DynAnnot(kind='dynamic', dtype=dtype, name=name) + return DynAnnot(kind="dynamic", dtype=dtype, name=name) @dataclass class DTypeAnnot(Annot): - ''' + """ Data type annotation ensures automatically conversion from AnyDType to dtype >>> def foo(A: T.dtype): print(A) >>> foo(torch.float32) @@ -517,7 +511,8 @@ class DTypeAnnot(Annot): dtype('float32') >>> foo('float32') dtype('float32') - ''' + """ + name: str | None = None def is_kernel_arg(self) -> bool: @@ -533,15 +528,16 @@ class DTypeAnnot(Annot): return dt.dtype(value) def __repr__(self): - return self.name + '$dtype' + return self.name + "$dtype" @dataclass class TIRAnnot(Annot): - ''' + """ TIR annotation is used to directly pass tir.Buffer or tir.Var as kernel arguments >>> def foo(A: T.Buffer((128,), T.float32)): ... 
- ''' + """ + data: tir.Buffer | tir.Var def is_kernel_arg(self) -> bool: @@ -564,7 +560,6 @@ class TIRAnnot(Annot): if TYPE_CHECKING: class Buffer(Generic[_Shape, _DType]): - def __init__( shape: tuple[Unpack[_Shapes]], dtype: _DType = "float32", @@ -576,26 +571,20 @@ if TYPE_CHECKING: offset_factor=0, buffer_type="", axis_separators=None, - ) -> Buffer[Callable[[Unpack[_Shapes]]], _DType]: - ... + ) -> Buffer[Callable[[Unpack[_Shapes]]], _DType]: ... @property - def shape(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> tuple[Unpack[_Shapes]]: - ... + def shape(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> tuple[Unpack[_Shapes]]: ... @property - def dtype(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> dt.dtype[_DType]: - ... + def dtype(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> dt.dtype[_DType]: ... @property - def strides(self) -> tuple[tir.PrimExpr]: - ... + def strides(self) -> tuple[tir.PrimExpr]: ... - def scope(self) -> Scope: - ... + def scope(self) -> Scope: ... class Tensor(Generic[_Shape, _DType], Buffer[_Shape, _DType]): - def __new__( shape: tuple[Unpack[_Shapes]], dtype: _DType = "float32", @@ -607,11 +596,9 @@ if TYPE_CHECKING: offset_factor=0, buffer_type="", axis_separators=None, - ) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: - ... + ) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ... class StridedTensor(Generic[_Shape, _Stride, _DType], Buffer[_Shape, _DType]): - def __new__( shape: tuple[Unpack[_Shapes]], strides=None, @@ -623,8 +610,7 @@ if TYPE_CHECKING: offset_factor=0, buffer_type="", axis_separators=None, - ) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: - ... + ) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ... class FragmentBuffer(Generic[_Shape, _DType], Buffer[_Shape, _DType]): pass @@ -636,16 +622,12 @@ if TYPE_CHECKING: pass class dyn(tir.Var): - - def __new__(cls, dtype: _DType = "float32", name: str | None = None) -> dyn[_DType]: - ... + def __new__(cls, dtype: _DType = "float32", name: str | None = None) -> dyn[_DType]: ... @property - def dtype(self: dyn[_DType]) -> dt.dtype[_DType]: - ... + def dtype(self: dyn[_DType]) -> dt.dtype[_DType]: ... 
else: - Buffer = BufferAnnot() Tensor = TensorAnnot() StridedTensor = StridedTensorAnnot() @@ -670,7 +652,7 @@ class FuncAnnot: ker_arg_names = [] for param in sig.parameters.values(): name = param.name - annot = func_annots.get(name, Value('static', name)) + annot = func_annots.get(name, Value("static", name)) if not isinstance(annot, Annot): if not isinstance(annot, type) and callable(annot): annot = annot() @@ -679,7 +661,7 @@ class FuncAnnot: elif isinstance(annot, (tir.Buffer, tir.Var)): annot = TIRAnnot(data=annot) else: - annot = Value(kind='static', name=name) + annot = Value(kind="static", name=name) annot = annot.promote() or annot annots[name] = annot.with_name(name) if annot.is_kernel_arg(): @@ -689,9 +671,9 @@ class FuncAnnot: return FuncAnnot(sig, arg_names, annots, arg_parser, ker_arg_names) def parse_key(self, *args, **kws): - ''' + """ Parse arguments and generates the cache key for jit caching - ''' + """ args = {name: arg for name, arg in zip(self.arg_names, args)} arg_dict = dict(**args, **kws) parsed = [] @@ -706,15 +688,15 @@ class FuncAnnot: return [arg_dict[name] for name in self.ker_arg_names] def create_argument(self, name: str, value: Any, vt: ArgVarTable): - ''' + """ Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation - ''' + """ return self.annots[name].create_prim_func_arg(name, value, vt) def is_all_static(self): - ''' + """ Check if all arguments are static (i.e., can be fully determined at compile time) - ''' + """ return all(isinstance(annot, TIRAnnot) for annot in self.annots.values()) def get_all_static_args(self): diff --git a/tilelang/language/v2/ast.py b/tilelang/language/v2/ast.py index c6dfecf1ebc6b1e9d872d77aa1652e922878c6b9..26c1851ebcee147c8a4ad841e004964f9511b9b3 100644 --- a/tilelang/language/v2/ast.py +++ b/tilelang/language/v2/ast.py @@ -4,16 +4,18 @@ from dataclasses import dataclass from typing import Callable, Generic, Any, Literal, TypeVar from contextlib import AbstractContextManager from collections.abc import Iterable + # Python 3.9 compatibility for ParamSpec try: from typing import ParamSpec except ImportError: # Python < 3.10 from typing_extensions import ParamSpec import inspect + # from .utils import get_ast, get_compiled_object from . 
import utils -_span_attrs = ['lineno', 'col_offset', 'end_lineno', 'end_col_offset'] +_span_attrs = ["lineno", "col_offset", "end_lineno", "end_col_offset"] def ast_has_span(ast: ast.AST) -> bool: @@ -34,7 +36,6 @@ def ast_set_span(ast: ast.AST, span: tuple[int, int, int, int]): class QuoteVisitor(ast.NodeTransformer): - def __init__(self, names: dict[str, ast.AST], passes: list[Any] | None = None, span=None): self.names = names self.passes = passes or [] @@ -76,9 +77,8 @@ def quote_expr(expr: str, **kws) -> ast.expr: return res.value -Operator = Literal['Add', 'Sub', 'Mult', 'MatMult', 'Div', 'Mod', 'Pow', 'LShift', 'RShift', - 'BitOr', 'BitXor', 'BitAnd', 'FloorDiv'] -BoolOp = Literal['And', 'Or', 'Not'] +Operator = Literal["Add", "Sub", "Mult", "MatMult", "Div", "Mod", "Pow", "LShift", "RShift", "BitOr", "BitXor", "BitAnd", "FloorDiv"] +BoolOp = Literal["And", "Or", "Not"] def get_operator_name(operator: ast.operator) -> Operator: @@ -89,84 +89,83 @@ def get_boolop_name(boolop: ast.boolop) -> BoolOp: return boolop.__class__.__name__ -_T = TypeVar('_T') +_T = TypeVar("_T") def eval_op(op: Operator, left: Any, right: Any) -> Any: - if op == 'Add': + if op == "Add": return left + right - if op == 'Sub': + if op == "Sub": return left - right - if op == 'Mult': + if op == "Mult": return left * right - if op == 'MatMult': + if op == "MatMult": return left @ right - if op == 'Div': + if op == "Div": return left / right - if op == 'Mod': + if op == "Mod": return left % right - if op == 'Pow': + if op == "Pow": return left**right - if op == 'LShift': + if op == "LShift": return left << right - if op == 'RShift': + if op == "RShift": return left >> right - if op == 'BitOr': + if op == "BitOr": return left | right - if op == 'BitXor': + if op == "BitXor": return left ^ right - if op == 'BitAnd': + if op == "BitAnd": return left & right - if op == 'FloorDiv': + if op == "FloorDiv": return left // right - raise ValueError(f'Unknown operator: {op}') + raise ValueError(f"Unknown operator: {op}") def eval_aug_assign(op: Operator, left: Any, sl: slice, right: Any) -> Any: - if op == 'Add': + if op == "Add": left[sl] += right return left - if op == 'Sub': + if op == "Sub": left[sl] -= right return left - if op == 'Mult': + if op == "Mult": left[sl] *= right return left - if op == 'MatMult': + if op == "MatMult": left[sl] @= right return left - if op == 'Div': + if op == "Div": left[sl] /= right return left - if op == 'Mod': + if op == "Mod": left[sl] %= right return left - if op == 'Pow': + if op == "Pow": left[sl] **= right return left - if op == 'LShift': + if op == "LShift": left[sl] <<= right return left - if op == 'RShift': + if op == "RShift": left[sl] >>= right return left - if op == 'BitOr': + if op == "BitOr": left[sl] |= right return left - if op == 'BitXor': + if op == "BitXor": left[sl] ^= right return left - if op == 'BitAnd': + if op == "BitAnd": left[sl] &= right return left - if op == 'FloorDiv': + if op == "FloorDiv": left[sl] //= right return left - raise ValueError(f'Unknown operator: {op}') + raise ValueError(f"Unknown operator: {op}") -class _empty: - ... +class _empty: ... 
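A minimal sketch (not part of this patch) of how the two helpers above behave: `eval_op` dispatches on the operator name produced by `get_operator_name`, and `eval_aug_assign` performs the matching in-place slice update. The numpy buffer below is a hypothetical stand-in for whatever sliceable object the builder hands in.

# illustrative only; assumes eval_op / eval_aug_assign from this module are in scope
import numpy as np

assert eval_op("Add", 1, 2) == 3
assert eval_op("FloorDiv", 7, 2) == 3

buf = np.zeros(4)
eval_aug_assign("Add", buf, slice(0, 2), 1.0)  # equivalent to buf[0:2] += 1.0
assert buf.tolist() == [1.0, 1.0, 0.0, 0.0]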
class BaseBuilder: @@ -218,13 +217,13 @@ class BaseBuilder: eval_aug_assign(op, target, sl, aug_value) def boolop(self, op: BoolOp, left: Any, right: Callable[[], Any] | None = None) -> Any: - if op == 'And': + if op == "And": return left and right() - if op == 'Or': + if op == "Or": return left or right() - if op == 'Not': + if op == "Not": return not left - raise ValueError(f'Unknown boolop: {op}') + raise ValueError(f"Unknown boolop: {op}") def ifexp(self, cond: Any, then: Callable[[], Any], otherwise: Callable[[], Any]) -> Any: return then() if cond else otherwise() @@ -249,7 +248,6 @@ class BaseBuilder: class DSLMutator(ast.NodeTransformer): - def __init__(self, closure_names: list[str]): self.tmp_counter = 0 self.closure_names = closure_names @@ -264,19 +262,13 @@ class DSLMutator(ast.NodeTransformer): br = self.get_tmp() if len(node.orelse) == 0: return quote( - f"for {br} in __tb.ctx_if(cond):\n" - f" for _ in __tb.ctx_then({br}):\n" - " pass\n", + f"for {br} in __tb.ctx_if(cond):\n for _ in __tb.ctx_then({br}):\n pass\n", cond=node.test, passes=[node.body], span=node, ) return quote( - f"for {br} in __tb.ctx_if(cond):\n" - f" for _ in __tb.ctx_then({br}):\n" - f" pass\n" - f" for _ in __tb.ctx_else({br}):\n" - f" pass\n", + f"for {br} in __tb.ctx_if(cond):\n for _ in __tb.ctx_then({br}):\n pass\n for _ in __tb.ctx_else({br}):\n pass\n", cond=node.test, passes=[node.body, node.orelse], span=node, @@ -290,7 +282,7 @@ class DSLMutator(ast.NodeTransformer): if isinstance(target, ast.Name): return f"'{target.id}'" elif isinstance(target, ast.Tuple): - return ("(" + ",".join([self._parse_names(elt) for elt in target.elts]) + ",)") + return "(" + ",".join([self._parse_names(elt) for elt in target.elts]) + ",)" else: s = ast.unparse(target) raise NotImplementedError(f"Unsupported for target `{s}`") @@ -303,8 +295,7 @@ class DSLMutator(ast.NodeTransformer): ast_set_span(var, ast_get_span(node.target)) stmts = self._emit_assign_target(node.target, var) return quote( - f"for {tmp} in __tb.ctx_for(range):\n" - " pass\n", + f"for {tmp} in __tb.ctx_for(range):\n pass\n", target=node.target, range=node.iter, passes=[stmts + node.body], @@ -319,24 +310,15 @@ class DSLMutator(ast.NodeTransformer): node = self.generic_visit(node) return quote("if __tb.ctx_break(): break", span=node) - def _emit_assign_target(self, - target: ast.expr, - rval: ast.expr, - annot: ast.expr = None) -> list[ast.AST]: + def _emit_assign_target(self, target: ast.expr, rval: ast.expr, annot: ast.expr = None) -> list[ast.AST]: if isinstance(target, ast.Name): if annot is None: - return quote( - f"name = __tb.bind('{target.id}', value)", name=target, value=rval, span=target) + return quote(f"name = __tb.bind('{target.id}', value)", name=target, value=rval, span=target) else: - return quote( - f'name = __tb.bind("{target.id}", value, annot)', - name=target, - value=rval, - annot=annot, - span=target) + return quote(f'name = __tb.bind("{target.id}", value, annot)', name=target, value=rval, annot=annot, span=target) elif isinstance(target, ast.Attribute): s = ast.unparse(target) - raise NotImplementedError(f'Attribute assignment not supported yet, `{s}`') + raise NotImplementedError(f"Attribute assignment not supported yet, `{s}`") elif isinstance(target, ast.Subscript): if annot is None: return quote( @@ -356,7 +338,6 @@ class DSLMutator(ast.NodeTransformer): span=target, ) else: - # flatten nested tuple into a list of (tmp_name, target) unpacked = [] @@ -374,11 +355,9 @@ class DSLMutator(ast.NodeTransformer): return res else: 
s = ast.unparse(target) - raise NotImplementedError(f'Attribute assignment not supported yet, `{s}`') + raise NotImplementedError(f"Attribute assignment not supported yet, `{s}`") - unpack_stmt = ast.Assign( - targets=[_visit_target(target)], - value=quote_expr('__tb.unwrap_value(rval)', rval=rval, span=rval)) + unpack_stmt = ast.Assign(targets=[_visit_target(target)], value=quote_expr("__tb.unwrap_value(rval)", rval=rval, span=rval)) ast_set_span(unpack_stmt, ast_get_span(target)) stmts = [unpack_stmt] bind_lvals = [] @@ -386,8 +365,7 @@ class DSLMutator(ast.NodeTransformer): def flush_binds(): if bind_lvals: - stmts.append( - quote1(f'{", ".join(bind_lvals)}, = {", ".join(bind_rvals)},', span=target)) + stmts.append(quote1(f"{', '.join(bind_lvals)}, = {', '.join(bind_rvals)},", span=target)) bind_lvals.clear() bind_rvals.clear() @@ -417,15 +395,10 @@ class DSLMutator(ast.NodeTransformer): bind_rvals.append(f'__tb.bind("{target.id}", {tmp})') elif isinstance(target, ast.Subscript): flush_binds() - stmts.append( - quote1( - f'__tb.assign_slice(lval, slice, {tmp})', - lval=target.value, - slice=target.slice, - span=target)) + stmts.append(quote1(f"__tb.assign_slice(lval, slice, {tmp})", lval=target.value, slice=target.slice, span=target)) else: s = ast.unparse(target) - raise NotImplementedError(f'Unsupported target: {s}') + raise NotImplementedError(f"Unsupported target: {s}") flush_binds() return stmts @@ -450,11 +423,7 @@ class DSLMutator(ast.NodeTransformer): target, rval = node.target, node.value op = get_operator_name(node.op) if isinstance(target, ast.Name): - return quote( - f"name = __tb.aug_assign('{op}', {target.id}, value)", - name=target, - value=rval, - span=node) + return quote(f"name = __tb.aug_assign('{op}', {target.id}, value)", name=target, value=rval, span=node) elif isinstance(target, ast.Subscript): return quote( f"__tb.aug_assign_slice('{op}', lval, slice, value)", @@ -468,16 +437,12 @@ class DSLMutator(ast.NodeTransformer): def visit_AnnAssign(self, node: ast.AnnAssign): node = self.generic_visit(node) - rval = node.value or quote_expr('__tb.empty', span=node, annot=node) + rval = node.value or quote_expr("__tb.empty", span=node, annot=node) return self._emit_assign_target(node.target, rval, annot=node.annotation) def visit_While(self, node): node = self.generic_visit(node) - return quote1( - "for _ in __tb.ctx_while(lambda: cond):\n pass", - cond=node.test, - passes=[node.body], - span=node) + return quote1("for _ in __tb.ctx_while(lambda: cond):\n pass", cond=node.test, passes=[node.body], span=node) def visit_FunctionDef(self, node: ast.FunctionDef): node = self.generic_visit(node) @@ -536,18 +501,14 @@ class DSLMutator(ast.NodeTransformer): left = comp last = split[-1] for i in reversed(range(len(split) - 1)): - last = quote_expr( - "__tb.boolop('And', left, lambda: right)", left=split[i], right=last, span=node) + last = quote_expr("__tb.boolop('And', left, lambda: right)", left=split[i], right=last, span=node) return last def visit_IfExp(self, node: ast.IfExp) -> ast.Expr: node = self.generic_visit(node) return quote_expr( - '__tb.ifexp(cond, lambda: then, lambda: otherwise)', - cond=node.test, - then=node.body, - otherwise=node.orelse, - span=node) + "__tb.ifexp(cond, lambda: then, lambda: otherwise)", cond=node.test, then=node.body, otherwise=node.orelse, span=node + ) def visit_Return(self, node: ast.Return): node = self.generic_visit(node) @@ -569,7 +530,7 @@ class DSLMutator(ast.NodeTransformer): return node -_P = ParamSpec('_P') +_P = ParamSpec("_P") 
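A rough before/after sketch (not part of this patch) of the source-to-source rewriting driven by the quote templates above; the temporary name `_t0` is hypothetical, standing in for whatever `get_tmp()` generates, and the loop bodies replace the `pass` placeholders via the `passes` mechanism.

original = """
if cond:
    body()
while cond:
    body()
"""
rewritten_roughly_as = """
for _t0 in __tb.ctx_if(cond):
    for _ in __tb.ctx_then(_t0):
        body()
for _ in __tb.ctx_while(lambda: cond):
    body()
"""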
@dataclass @@ -626,7 +587,7 @@ def mutate(func: Callable[_P, _T]) -> IRGenerator[_P, _T]: make_closure = utils.get_compiled_object( tree, - 'make_closure', + "make_closure", filename, func.__globals__, # use the original globalns ) diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 436756df83827e694467f3999de84fa1beed6bac..645a1ad920d8e1230896e3fda3e8a6cb83144e9f 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Callable, Any, Generic, TypeVar, ForwardRef, U from collections.abc import Sequence from .annot import FuncAnnot, ArgVarTable, Annot import pprint + # Python 3.9 compatibility for ParamSpec and Self try: from typing import ParamSpec, Self @@ -32,9 +33,9 @@ logger = logging.getLogger(__name__) def unwrap_expr(expr) -> PrimExpr | int | float: - ''' + """ unwrap expr and convert it into PrimExpr like - ''' + """ if isinstance(expr, tir.meta_var): expr = expr.value elif isinstance(expr, Ref): @@ -47,9 +48,9 @@ def unwrap_expr(expr) -> PrimExpr | int | float: def unwrap_cond(expr): - ''' + """ unwrap expr and convert to bool condition - ''' + """ expr = unwrap_expr(expr) if isinstance(expr, (IntImm, FloatImm, StringImm)): return bool(expr.value) @@ -61,10 +62,10 @@ def unwrap_cond(expr): return bool(expr) else: logger.warning( - f"Python expression `{expr}` is used as condition in TileLang, \n" - "this is treated as a constant expression. ", + f"Python expression `{expr}` is used as condition in TileLang, \nthis is treated as a constant expression. ", stack_info=True, - stacklevel=3) + stacklevel=3, + ) return bool(expr) @@ -72,44 +73,35 @@ thread_local_storage = threading.local() class Frame: - ''' + """ Frame are virtual context managers used in frontend only They do not have any runtime representation in the generated TIR. - ''' + """ - def __enter__(self): - ... + def __enter__(self): ... - def __exit__(self, exc_type, exc_value, traceback): - ... + def __exit__(self, exc_type, exc_value, traceback): ... -class MacroFrame(Frame): - ... +class MacroFrame(Frame): ... -class ExitedMacroFrame(Frame): - ... +class ExitedMacroFrame(Frame): ... -class BoolOpFrame(Frame): - ... +class BoolOpFrame(Frame): ... -class ConstIfFrame(Frame): - ... +class ConstIfFrame(Frame): ... -class BlockFrame(Frame): - ... +class BlockFrame(Frame): ... -class ContinueFrame(Frame): - ... +class ContinueFrame(Frame): ... -class BreakFrame(Frame): - ... +class BreakFrame(Frame): ... @dataclass @@ -145,8 +137,7 @@ class Ref: return self.bufload -class UnrollForWithStep(SerialForWithStep): - ... +class UnrollForWithStep(SerialForWithStep): ... 
# Python 3.9 compatibility: avoid PEP 604 unions at runtime @@ -172,11 +163,10 @@ TIR_VAR_SCOPE_FRAME = ( def is_var(v: Any) -> bool: - return isinstance(v, Buffer) and v.scope() == 'local.var' + return isinstance(v, Buffer) and v.scope() == "local.var" class Builder(BaseBuilder): - def __init__(self, func_annot: FuncAnnot = None): self.frames: list[AnyFrame] = [] self.ir_builder = IRBuilder() @@ -189,7 +179,7 @@ class Builder(BaseBuilder): @classmethod def current(cls) -> Self: - builder = getattr(thread_local_storage, 'builder', None) + builder = getattr(thread_local_storage, "builder", None) return builder @contextmanager @@ -199,14 +189,15 @@ class Builder(BaseBuilder): tir.func_name(name) yield if len(self.out_idx) != self.out_tensor_cnt: - raise RuntimeError('Not all tensor allocated from `T.empty` are returned') + raise RuntimeError("Not all tensor allocated from `T.empty` are returned") @contextmanager def macro(self, name=None, annotations=None): if self.find_frame_idx(BoolOpFrame) is not None: raise RuntimeError( f"Macro `{name}` is used inside boolean expressions, " - "please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs") + "please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs" + ) save = self.name_inside_frame, self.macro_arg_annot self.name_inside_frame = {} self.macro_arg_annot = annotations or {} @@ -244,10 +235,7 @@ class Builder(BaseBuilder): def check_continue_break(self): idx = self.find_frame_idx(ContinueOrBreak) if idx is not None: - logger.warning( - 'Writing code after continue/break may cause undefined behavior in tilelang.', - stack_info=True, - stacklevel=3) + logger.warning("Writing code after continue/break may cause undefined behavior in tilelang.", stack_info=True, stacklevel=3) @contextmanager def with_frame(self, frame: AbstractContextManager[Any] | None): @@ -256,8 +244,7 @@ class Builder(BaseBuilder): while len(self.frames) > pop_idx: self.frames.pop().__exit__(None, None, None) - class _has_if_frame: - ... + class _has_if_frame: ... 
def ctx_if(self, cond): self.check_continue_break() @@ -294,7 +281,7 @@ class Builder(BaseBuilder): elif isinstance(val, tir.frame.IRBuilderFrame): if isinstance(val, tir.frame.ForFrame): logger.warning( - 'Evaluating a for frame may cause undefined behavior in tilelang.', + "Evaluating a for frame may cause undefined behavior in tilelang.", stack_info=True, stacklevel=1, ) @@ -310,8 +297,7 @@ class Builder(BaseBuilder): elif isinstance(val, (Buffer, Var)): pass else: - logger.warning( - f"Unused return value: {val}({type(val)})", stack_info=True, stacklevel=2) + logger.warning(f"Unused return value: {val}({type(val)})", stack_info=True, stacklevel=2) def ctx_for(self, it): self.check_continue_break() @@ -321,15 +307,13 @@ class Builder(BaseBuilder): if isinstance(it.step, (int, IntImm)): step_value = it.step if isinstance(it.step, int) else it.step.value if step_value == 0: - raise ValueError('Invalid stepped serial: step must be non-zero') + raise ValueError("Invalid stepped serial: step must be non-zero") if step_value > 0: real_stop = tir.ceildiv(it.stop - it.start, step_value) else: real_stop = tir.ceildiv(it.start - it.stop, -step_value) else: - logger.warning( - f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang' - ) + logger.warning(f"Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang") real_stop = tir.ceildiv(it.stop - it.start, it.step) if isinstance(it, UnrollForWithStep): real_frame = tir.unroll(real_stop, annotations=it.annotations) @@ -338,15 +322,17 @@ class Builder(BaseBuilder): else: raise TypeError( f"Invalid for loop, got {it}({type(it)}), expect one of the following: " - "range, T.serial, T.unroll, T.grid, T.parallel, T.vectorized, T.thread_binding") + "range, T.serial, T.unroll, T.grid, T.parallel, T.vectorized, T.thread_binding" + ) with self.with_frame(real_frame) as v: - IRBuilder.name('_tmp', v) + IRBuilder.name("_tmp", v) yield it.start + v * it.step else: if not isinstance(it, tir.frame.ForFrame): raise TypeError( f"Invalid for loop, got {it}({type(it)}), expect one of the following: " - "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding") + "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding" + ) with self.with_frame(it) as v: yield v @@ -369,15 +355,16 @@ class Builder(BaseBuilder): if not isinstance(cond_v_unwrap, PrimExpr): if cond_v_unwrap: raise RuntimeError( - f'Infinite while loop detected in TileLang\n' - f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n' + f"Infinite while loop detected in TileLang\n" + f"Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n" ) else: logger.warning( - 'While loop with constant false condition detected in Tilelang, the loop body will never be executed.\n', - f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n', + "While loop with constant false condition detected in Tilelang, the loop body will never be executed.\n", + f"Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n", stack_info=True, - stacklevel=2) + stacklevel=2, + ) with self.with_frame(tir.While(cond_v_unwrap)): yield None @@ -406,14 +393,14 @@ class Builder(BaseBuilder): # 2. 
Quick return for trivil types if isinstance(value, (tuple, list, tvm.ffi.Array, int, float, str)): return value - if isinstance(value, tir.IntImm) and value.dtype == 'int32': + if isinstance(value, tir.IntImm) and value.dtype == "int32": return value.value if isinstance(value, (Var, Buffer)): # Bind TVM Var/Buffer names and also record scope so reusing the same # Python name (e.g., loop vars like `i`) across different for-frames # works without triggering out-of-scope errors. IRBuilder.name(name, value) - if name != '_': + if name != "_": frame = self.find_frame_idx(TIR_VAR_SCOPE_FRAME) assert frame is not None, f"Variable `{name}` is not defined inside any control flow." self.name_inside_frame[name] = self.frames[frame] @@ -423,12 +410,12 @@ class Builder(BaseBuilder): res = self.bind_immutable(name, value) # 4. Check variable scope and shadowing - if name != '_': + if name != "_": frame = self.find_frame_idx(TIR_VAR_SCOPE_FRAME) assert frame is not None, f"Variable `{name}` is not defined inside any control flow." if name in self.name_inside_frame and self.name_inside_frame[name] in self.frames: logger.warning( - f'Variable `{name}` is declared twice, are you looking for a T.alloc_var?', + f"Variable `{name}` is declared twice, are you looking for a T.alloc_var?", stack_info=True, stacklevel=2, ) @@ -436,9 +423,9 @@ class Builder(BaseBuilder): return res def unwrap_value(self, value): - ''' + """ Unwrap some tilelang objects to get their inner value - ''' + """ value = unwrap_expr(value) # handle bx, by = tl.Kernel(128, 128), rval is frame if isinstance(value, tir.frame.IRBuilderFrame): @@ -447,11 +434,11 @@ class Builder(BaseBuilder): return value def bind_immutable(self, name, value): - ''' + """ Bind an immutable tilelang objects. The immutability means the result is usually not changed or re-assigned in a python block. 
- ''' - if name == '_': + """ + if name == "_": # use _tmp to make the generated tir more readable name = "_tmp" if isinstance(value, tir.meta_var): @@ -459,18 +446,20 @@ class Builder(BaseBuilder): elif isinstance(value, tir.frame.IRBuilderFrame): if isinstance(value, tir.frame.ForFrame): logger.warning( - 'Binding a for frame to variable may cause undefined behavior in tilelang.', + "Binding a for frame to variable may cause undefined behavior in tilelang.", stack_info=True, stacklevel=2, ) return self.enter_frame(value) elif isinstance(value, OutTensor): - arg = tir.arg(name, - tir.buffer( - shape=value.shape, - dtype=value.dtype, - strides=value.strides, - )) + arg = tir.arg( + name, + tir.buffer( + shape=value.shape, + dtype=value.dtype, + strides=value.strides, + ), + ) arg._out_idx = self.out_tensor_cnt self.out_tensor_cnt += 1 return arg @@ -490,8 +479,7 @@ class Builder(BaseBuilder): def assign_slice(self, lval: Any, sl: slice, value: Any, annot=BaseBuilder.empty): self.check_continue_break() if annot is not self.empty: - logger.warning( - "Type annotation in slice assignment has no effect", stack_info=True, stacklevel=2) + logger.warning("Type annotation in slice assignment has no effect", stack_info=True, stacklevel=2) if isinstance(lval, Buffer): tir.buffer_store(lval, value, sl) else: @@ -521,11 +509,11 @@ class Builder(BaseBuilder): left = unwrap_cond(left) if isinstance(left, PrimExpr): with self.with_frame(BoolOpFrame()): - if op == 'And': + if op == "And": return tir.And(left, right()) - if op == 'Or': + if op == "Or": return tir.Or(left, right()) - if op == 'Not': + if op == "Not": return tir.Not(left) raise RuntimeError(f"Unsupported boolean operator: {op}") else: @@ -557,7 +545,7 @@ class Builder(BaseBuilder): "You should allocate a var before the control flow, assign value inside the blocks, \n" "and return the var after the control flow. 
i.e.\n" "```\n" - "@T.macro\n" \ + "@T.macro\n" "def my_macro(cond):\n" " a = T.alloc_var(T.float16)\n" " if cond:\n" @@ -570,14 +558,12 @@ class Builder(BaseBuilder): if not isinstance(value, tuple): value = (value,) for v in value: - if not isinstance(v, Buffer) or not hasattr(v, '_out_idx'): - raise RuntimeError( - f'Only tensor allocated from `T.empty` can be returned in a prim_func, got {v}({type(v)})' - ) + if not isinstance(v, Buffer) or not hasattr(v, "_out_idx"): + raise RuntimeError(f"Only tensor allocated from `T.empty` can be returned in a prim_func, got {v}({type(v)})") # convert 0, 1, 2 => -3, -2, -1 as the out tensor index self.out_idx.append(v._out_idx - self.out_tensor_cnt) if len(self.out_idx) != self.out_tensor_cnt: - raise RuntimeError(f'Not all tensor from `T.empty` are returned, only got {value}') + raise RuntimeError(f"Not all tensor from `T.empty` are returned, only got {value}") return NotImplemented def ctx_with(self, ctx): @@ -591,7 +577,7 @@ class Builder(BaseBuilder): self.check_continue_break() cond = unwrap_cond(cond) if msg is None: - msg = 'Assertion failed' + msg = "Assertion failed" if isinstance(cond, PrimExpr): self.enter_frame(tir.Assert(cond, msg)) elif not cond: @@ -611,23 +597,18 @@ class Builder(BaseBuilder): annot_value = self.macro_arg_annot.get(name, None) if annot_value is Var or annot_value is Ref: if annot_value is Var: - logger.warning('Use `T.Var` as macro annotations is deprecated, please use `T.Ref`') + logger.warning("Use `T.Var` as macro annotations is deprecated, please use `T.Ref`") if isinstance(value, BufferLoad): if is_var(value.buffer): return value.buffer - idx = [self.bind('_', idx) for idx in value.indices] + idx = [self.bind("_", idx) for idx in value.indices] # indices = self.bind(f'_', value.indices) return Ref(BufferLoad(value.buffer, indices=idx)) if isinstance(value, BufferRegion): - region = [ - Range( - self.bind('_', x.begin), - end=self.bind('_', x.end) if x.end is not None else None) - for x in value.region - ] + region = [Range(self.bind("_", x.begin), end=self.bind("_", x.end) if x.end is not None else None) for x in value.region] return BufferRegion(value.buffer, region=region) raise ValueError( - f'To pass as reference, argument `{name}` is expected to be a variable or a buffer region, but got {value}({type(value)})' + f"To pass as reference, argument `{name}` is expected to be a variable or a buffer region, but got {value}({type(value)})" ) elif isinstance(value, (PrimExpr, int, float)): return self.bind(name, value) @@ -652,13 +633,14 @@ class Builder(BaseBuilder): def override(self, name: str): from tilelang.language import serial - if name == 'range': + + if name == "range": return serial - raise ValueError(f'Unknown override: {name}') + raise ValueError(f"Unknown override: {name}") -_P = ParamSpec('_P') -_T = TypeVar('_T') +_P = ParamSpec("_P") +_T = TypeVar("_T") @dataclass @@ -683,14 +665,8 @@ class PrimFuncCreater(Generic[_P, _T]): return res def __repr__(self): - fmt = pprint.pformat( - { - 'annot': self.func_annot.annots, - 'ir_gen': self.ir_gen, - 'orig_func': self.orig_func - }, - indent=2) - return f'{self.__class__.__name__}(\n{fmt}\n)' + fmt = pprint.pformat({"annot": self.func_annot.annots, "ir_gen": self.ir_gen, "orig_func": self.orig_func}, indent=2) + return f"{self.__class__.__name__}(\n{fmt}\n)" if TYPE_CHECKING: @@ -769,8 +745,7 @@ def macro(func: Callable[_P, _T] = None) -> Macro[_P, _T]: def impl(func: Callable[_P, _T]) -> Macro[_P, _T]: annotations = get_type_hints(func) - return Macro( - 
name=func.__name__, orig_func=func, ir_gen=mutate(func), annotations=annotations) + return Macro(name=func.__name__, orig_func=func, ir_gen=mutate(func), annotations=annotations) return impl(func) if func is not None else impl @@ -779,9 +754,9 @@ from typing import _eval_type def get_type_hints(func): - annot = getattr(func, '__annotations__', None) + annot = getattr(func, "__annotations__", None) if annot is None: - raise TypeError(f'Failed to get function type hints, {func} is not a function') + raise TypeError(f"Failed to get function type hints, {func} is not a function") hints = {} # Build eval namespaces from function globals plus captured closure variables # This lets annotations reference symbols like `n`, `h`, or dtype vars @@ -808,7 +783,7 @@ def get_type_hints(func): # ... # empty function, do not use `n` localns = utils.get_func_nonlocals(func) for name, value in annot.items(): - if name == 'return': + if name == "return": continue if isinstance(value, tvm.DataType): hints[name] = value @@ -821,7 +796,7 @@ def get_type_hints(func): # typing see: T.float32 is str('float32'), and there is no object named `flaot32` and give a NameError # here we manually interpret it to return T.float32 object try: - _, v = value.split('.', maxsplit=1) + _, v = value.split(".", maxsplit=1) except ValueError: v = value if v in dt._all_dtypes: @@ -837,9 +812,7 @@ def get_type_hints(func): return hints -def prim_func(func: Callable[_P, _T] = None, - *, - generator: bool = False) -> PrimFunc[_P, _T] | PrimFuncCreater[_P, _T]: +def prim_func(func: Callable[_P, _T] = None, *, generator: bool = False) -> PrimFunc[_P, _T] | PrimFuncCreater[_P, _T]: """ Decorator to create a primitive function (PrimFunc) for TileLang IR generation. This decorator transforms a Python function into a TileLang primitive function by analyzing @@ -903,7 +876,8 @@ def prim_func(func: Callable[_P, _T] = None, raise ValueError( f"Cannot create PrimFunc for `{func.__name__}`, some arguments are not compile-time known, \n" f"Annotations:\n{func_annot.annots}" - f"Unknown Args: {unknown_args}") + f"Unknown Args: {unknown_args}" + ) return prim_func_generator return impl(func) if func is not None else impl diff --git a/tilelang/language/v2/dtypes.py b/tilelang/language/v2/dtypes.py index 3b21587077b31e99720ec08dcdab6c9c75002c52..6ed56b48a0b006cd781ac819d7f843583f337ac2 100644 --- a/tilelang/language/v2/dtypes.py +++ b/tilelang/language/v2/dtypes.py @@ -6,14 +6,12 @@ from tvm import tir import tvm.script.ir_builder.tir._ffi_api as tb_ffi import numpy as np -_T = TypeVar('_T') +_T = TypeVar("_T") if TYPE_CHECKING: class dtype(Generic[_T]): - - def torch(self) -> torch.dtype: - ... + def torch(self) -> torch.dtype: ... 
else: dtype = tvm.DataType @@ -21,53 +19,53 @@ else: AnyDType = Union[ir.Type, str, type, torch.dtype, dtype] _PYTHON_DTYPE_TO_STR = { - bool: 'bool', - int: 'int32', - float: 'float32', + bool: "bool", + int: "int32", + float: "float32", } _NUMPY_DTYPE_TO_STR = { - np.bool_: 'bool', - np.short: 'int16', - np.int_: 'int64', - np.longlong: 'int64', - np.half: 'float16', - np.double: 'float64', - np.int8: 'int8', - np.int16: 'int16', - np.int32: 'int32', - np.int64: 'int64', - np.uint8: 'uint8', - np.uint16: 'uint16', - np.uint32: 'uint32', - np.uint64: 'uint64', - np.float16: 'float16', - np.float32: 'float32', - np.float64: 'float64', + np.bool_: "bool", + np.short: "int16", + np.int_: "int64", + np.longlong: "int64", + np.half: "float16", + np.double: "float64", + np.int8: "int8", + np.int16: "int16", + np.int32: "int32", + np.int64: "int64", + np.uint8: "uint8", + np.uint16: "uint16", + np.uint32: "uint32", + np.uint64: "uint64", + np.float16: "float16", + np.float32: "float32", + np.float64: "float64", } _NUMPY_DTYPE_TO_STR.update({np.dtype(k): v for k, v in _NUMPY_DTYPE_TO_STR.items()}) _TORCH_DTYPE_TO_STR = { - torch.bool: 'bool', - torch.short: 'int16', - torch.int: 'int32', - torch.long: 'int64', - torch.half: 'float16', - torch.float: 'float32', - torch.double: 'float64', - torch.int8: 'int8', - torch.int16: 'int16', - torch.int32: 'int32', - torch.int64: 'int64', - torch.uint8: 'uint8', - torch.uint16: 'uint16', - torch.uint32: 'uint32', - torch.uint64: 'uint64', - torch.float16: 'float16', - torch.float32: 'float32', - torch.float64: 'float64', - torch.bfloat16: 'bfloat16', + torch.bool: "bool", + torch.short: "int16", + torch.int: "int32", + torch.long: "int64", + torch.half: "float16", + torch.float: "float32", + torch.double: "float64", + torch.int8: "int8", + torch.int16: "int16", + torch.int32: "int32", + torch.int64: "int64", + torch.uint8: "uint8", + torch.uint16: "uint16", + torch.uint32: "uint32", + torch.uint64: "uint64", + torch.float16: "float16", + torch.float32: "float32", + torch.float64: "float64", + torch.bfloat16: "bfloat16", } # _STR_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_STR.items()} @@ -77,24 +75,24 @@ _TORCH_DTYPE_TO_STR = { _DTYPE_TO_STR = {**_PYTHON_DTYPE_TO_STR, **_NUMPY_DTYPE_TO_STR, **_TORCH_DTYPE_TO_STR} _STR_TO_TVM_DTYPE_CALL = { - 'bool': 'Boolean', - 'int8': 'Int8', - 'int32': 'Int32', - 'int64': 'Int64', - 'uint8': 'UInt8', - 'uint16': 'UInt16', - 'uint32': 'UInt32', - 'uint64': 'UInt64', - 'float16': 'Float16', - 'float32': 'Float32', - 'float64': 'Float64', - 'bfloat16': 'BFloat16', - 'float8_e4m3': 'Float8E4M3', - 'float8_e4m3fn': 'Float8E4M3FN', - 'float8_e4m3fnuz': 'Float8E4M3FNUZ', - 'float8_e5m2': 'Float8E5M2', - 'float8_e5m2fnuz': 'Float8E5M2FNUZ', - 'float8_e8m0fnu': 'Float8E8M0FNU' + "bool": "Boolean", + "int8": "Int8", + "int32": "Int32", + "int64": "Int64", + "uint8": "UInt8", + "uint16": "UInt16", + "uint32": "UInt32", + "uint64": "UInt64", + "float16": "Float16", + "float32": "Float32", + "float64": "Float64", + "bfloat16": "BFloat16", + "float8_e4m3": "Float8E4M3", + "float8_e4m3fn": "Float8E4M3FN", + "float8_e4m3fnuz": "Float8E4M3FNUZ", + "float8_e5m2": "Float8E5M2", + "float8_e5m2fnuz": "Float8E5M2FNUZ", + "float8_e8m0fnu": "Float8E8M0FNU", } int_ = int @@ -108,23 +106,24 @@ def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var call = getattr(tb_ffi, attr, None) return call(expr, is_size_var) # try to construct the ffi call - if self.startswith('uint'): - val = 'UInt' + self[4:] - elif 
self.startswith('int'): - val = 'Int' + self[3:] - elif self.startswith('float'): - val = 'Float' + self[5:] - elif self.startswith('bfloat'): - val = 'BFloat' + self[6:] + if self.startswith("uint"): + val = "UInt" + self[4:] + elif self.startswith("int"): + val = "Int" + self[3:] + elif self.startswith("float"): + val = "Float" + self[5:] + elif self.startswith("bfloat"): + val = "BFloat" + self[6:] else: - raise TypeError(f'Invalid type {self}') - if '_' in val: - first, second = val.split('_', maxsplit=1) + raise TypeError(f"Invalid type {self}") + if "_" in val: + first, second = val.split("_", maxsplit=1) val = first + second.upper() call = getattr(tb_ffi, val, None) if call is None: - raise TypeError(f"Convert to datatype `{self}` is not supported by tvm\n" - f"calling failed on `tvm.script.ir_builder.tir._ffi_api.{val}`") + raise TypeError( + f"Convert to datatype `{self}` is not supported by tvm\ncalling failed on `tvm.script.ir_builder.tir._ffi_api.{val}`" + ) return call(expr, is_size_var) @@ -152,7 +151,6 @@ def get_tvm_dtype(value: AnyDType) -> dtype: if TYPE_CHECKING: - # yapf: disable class bool(dtype): ... class short(dtype): ... @@ -319,336 +317,336 @@ if TYPE_CHECKING: # yapf: enable else: - bool = dtype('bool') - short = dtype('int16') - int = dtype('int32') - long = dtype('int64') - half = dtype('float16') - float = dtype('float32') - double = dtype('float64') - int8 = dtype('int8') - int16 = dtype('int16') - int32 = dtype('int32') - int64 = dtype('int64') - int8x2 = dtype('int8x2') - int16x2 = dtype('int16x2') - int32x2 = dtype('int32x2') - int64x2 = dtype('int64x2') - int8x4 = dtype('int8x4') - int16x4 = dtype('int16x4') - int32x4 = dtype('int32x4') - int64x4 = dtype('int64x4') - int8x8 = dtype('int8x8') - int16x8 = dtype('int16x8') - int32x8 = dtype('int32x8') - int64x8 = dtype('int64x8') - int8x16 = dtype('int8x16') - int16x16 = dtype('int16x16') - int32x16 = dtype('int32x16') - int64x16 = dtype('int64x16') - int8x32 = dtype('int8x32') - int16x32 = dtype('int16x32') - int32x32 = dtype('int32x32') - int64x32 = dtype('int64x32') - int8x64 = dtype('int8x64') - int16x64 = dtype('int16x64') - int32x64 = dtype('int32x64') - int64x64 = dtype('int64x64') - uint8 = dtype('uint8') - uint16 = dtype('uint16') - uint32 = dtype('uint32') - uint64 = dtype('uint64') - uint8x2 = dtype('uint8x2') - uint16x2 = dtype('uint16x2') - uint32x2 = dtype('uint32x2') - uint64x2 = dtype('uint64x2') - uint8x4 = dtype('uint8x4') - uint16x4 = dtype('uint16x4') - uint32x4 = dtype('uint32x4') - uint64x4 = dtype('uint64x4') - uint8x8 = dtype('uint8x8') - uint16x8 = dtype('uint16x8') - uint32x8 = dtype('uint32x8') - uint64x8 = dtype('uint64x8') - uint8x16 = dtype('uint8x16') - uint16x16 = dtype('uint16x16') - uint32x16 = dtype('uint32x16') - uint64x16 = dtype('uint64x16') - uint8x32 = dtype('uint8x32') - uint16x32 = dtype('uint16x32') - uint32x32 = dtype('uint32x32') - uint64x32 = dtype('uint64x32') - uint8x64 = dtype('uint8x64') - uint16x64 = dtype('uint16x64') - uint32x64 = dtype('uint32x64') - uint64x64 = dtype('uint64x64') - float16 = dtype('float16') - float32 = dtype('float32') - float64 = dtype('float64') - float16x2 = dtype('float16x2') - float32x2 = dtype('float32x2') - float64x2 = dtype('float64x2') - float16x4 = dtype('float16x4') - float32x4 = dtype('float32x4') - float64x4 = dtype('float64x4') - float16x8 = dtype('float16x8') - float32x8 = dtype('float32x8') - float64x8 = dtype('float64x8') - float16x16 = dtype('float16x16') - float32x16 = dtype('float32x16') - float64x16 = 
dtype('float64x16') - float16x32 = dtype('float16x32') - float32x32 = dtype('float32x32') - float64x32 = dtype('float64x32') - float16x64 = dtype('float16x64') - float32x64 = dtype('float32x64') - float64x64 = dtype('float64x64') - float8_e3m4 = dtype('float8_e3m4') - float8_e3m4x2 = dtype('float8_e3m4x2') - float8_e3m4x4 = dtype('float8_e3m4x4') - float8_e3m4x8 = dtype('float8_e3m4x8') - float8_e3m4x16 = dtype('float8_e3m4x16') - float8_e3m4x32 = dtype('float8_e3m4x32') - float8_e3m4x64 = dtype('float8_e3m4x64') - float8_e4m3 = dtype('float8_e4m3') - float8_e4m3x2 = dtype('float8_e4m3x2') - float8_e4m3x4 = dtype('float8_e4m3x4') - float8_e4m3x8 = dtype('float8_e4m3x8') - float8_e4m3x16 = dtype('float8_e4m3x16') - float8_e4m3x32 = dtype('float8_e4m3x32') - float8_e4m3x64 = dtype('float8_e4m3x64') - float8_e4m3b11fnuz = dtype('float8_e4m3b11fnuz') - float8_e4m3b11fnuzx2 = dtype('float8_e4m3b11fnuzx2') - float8_e4m3b11fnuzx4 = dtype('float8_e4m3b11fnuzx4') - float8_e4m3b11fnuzx8 = dtype('float8_e4m3b11fnuzx8') - float8_e4m3b11fnuzx16 = dtype('float8_e4m3b11fnuzx16') - float8_e4m3b11fnuzx32 = dtype('float8_e4m3b11fnuzx32') - float8_e4m3b11fnuzx64 = dtype('float8_e4m3b11fnuzx64') - float8_e4m3fn = dtype('float8_e4m3fn') - float8_e4m3fnx2 = dtype('float8_e4m3fnx2') - float8_e4m3fnx4 = dtype('float8_e4m3fnx4') - float8_e4m3fnx8 = dtype('float8_e4m3fnx8') - float8_e4m3fnx16 = dtype('float8_e4m3fnx16') - float8_e4m3fnx32 = dtype('float8_e4m3fnx32') - float8_e4m3fnx64 = dtype('float8_e4m3fnx64') - float8_e4m3fnuz = dtype('float8_e4m3fnuz') - float8_e4m3fnuzx2 = dtype('float8_e4m3fnuzx2') - float8_e4m3fnuzx4 = dtype('float8_e4m3fnuzx4') - float8_e4m3fnuzx8 = dtype('float8_e4m3fnuzx8') - float8_e4m3fnuzx16 = dtype('float8_e4m3fnuzx16') - float8_e4m3fnuzx32 = dtype('float8_e4m3fnuzx32') - float8_e4m3fnuzx64 = dtype('float8_e4m3fnuzx64') - float8_e5m2 = dtype('float8_e5m2') - float8_e5m2x2 = dtype('float8_e5m2x2') - float8_e5m2x4 = dtype('float8_e5m2x4') - float8_e5m2x8 = dtype('float8_e5m2x8') - float8_e5m2x16 = dtype('float8_e5m2x16') - float8_e5m2x32 = dtype('float8_e5m2x32') - float8_e5m2x64 = dtype('float8_e5m2x64') - float8_e5m2fnuz = dtype('float8_e5m2fnuz') - float8_e5m2fnuzx2 = dtype('float8_e5m2fnuzx2') - float8_e5m2fnuzx4 = dtype('float8_e5m2fnuzx4') - float8_e5m2fnuzx8 = dtype('float8_e5m2fnuzx8') - float8_e5m2fnuzx16 = dtype('float8_e5m2fnuzx16') - float8_e5m2fnuzx32 = dtype('float8_e5m2fnuzx32') - float8_e5m2fnuzx64 = dtype('float8_e5m2fnuzx64') - float8_e8m0fnu = dtype('float8_e8m0fnu') - float8_e8m0fnux2 = dtype('float8_e8m0fnux2') - float8_e8m0fnux4 = dtype('float8_e8m0fnux4') - float8_e8m0fnux8 = dtype('float8_e8m0fnux8') - float8_e8m0fnux16 = dtype('float8_e8m0fnux16') - float8_e8m0fnux32 = dtype('float8_e8m0fnux32') - float8_e8m0fnux64 = dtype('float8_e8m0fnux64') - float6_e2m3fn = dtype('float6_e2m3fn') - float6_e2m3fnx2 = dtype('float6_e2m3fnx2') - float6_e2m3fnx4 = dtype('float6_e2m3fnx4') - float6_e2m3fnx8 = dtype('float6_e2m3fnx8') - float6_e2m3fnx16 = dtype('float6_e2m3fnx16') - float6_e2m3fnx32 = dtype('float6_e2m3fnx32') - float6_e2m3fnx64 = dtype('float6_e2m3fnx64') - float6_e3m2fn = dtype('float6_e3m2fn') - float6_e3m2fnx2 = dtype('float6_e3m2fnx2') - float6_e3m2fnx4 = dtype('float6_e3m2fnx4') - float6_e3m2fnx8 = dtype('float6_e3m2fnx8') - float6_e3m2fnx16 = dtype('float6_e3m2fnx16') - float6_e3m2fnx32 = dtype('float6_e3m2fnx32') - float6_e3m2fnx64 = dtype('float6_e3m2fnx64') - float4_e2m1fn = dtype('float4_e2m1fn') - float4_e2m1fnx2 = dtype('float4_e2m1fnx2') - 
float4_e2m1fnx4 = dtype('float4_e2m1fnx4') - float4_e2m1fnx8 = dtype('float4_e2m1fnx8') - float4_e2m1fnx16 = dtype('float4_e2m1fnx16') - float4_e2m1fnx32 = dtype('float4_e2m1fnx32') - float4_e2m1fnx64 = dtype('float4_e2m1fnx64') - bfloat16 = dtype('bfloat16') + bool = dtype("bool") + short = dtype("int16") + int = dtype("int32") + long = dtype("int64") + half = dtype("float16") + float = dtype("float32") + double = dtype("float64") + int8 = dtype("int8") + int16 = dtype("int16") + int32 = dtype("int32") + int64 = dtype("int64") + int8x2 = dtype("int8x2") + int16x2 = dtype("int16x2") + int32x2 = dtype("int32x2") + int64x2 = dtype("int64x2") + int8x4 = dtype("int8x4") + int16x4 = dtype("int16x4") + int32x4 = dtype("int32x4") + int64x4 = dtype("int64x4") + int8x8 = dtype("int8x8") + int16x8 = dtype("int16x8") + int32x8 = dtype("int32x8") + int64x8 = dtype("int64x8") + int8x16 = dtype("int8x16") + int16x16 = dtype("int16x16") + int32x16 = dtype("int32x16") + int64x16 = dtype("int64x16") + int8x32 = dtype("int8x32") + int16x32 = dtype("int16x32") + int32x32 = dtype("int32x32") + int64x32 = dtype("int64x32") + int8x64 = dtype("int8x64") + int16x64 = dtype("int16x64") + int32x64 = dtype("int32x64") + int64x64 = dtype("int64x64") + uint8 = dtype("uint8") + uint16 = dtype("uint16") + uint32 = dtype("uint32") + uint64 = dtype("uint64") + uint8x2 = dtype("uint8x2") + uint16x2 = dtype("uint16x2") + uint32x2 = dtype("uint32x2") + uint64x2 = dtype("uint64x2") + uint8x4 = dtype("uint8x4") + uint16x4 = dtype("uint16x4") + uint32x4 = dtype("uint32x4") + uint64x4 = dtype("uint64x4") + uint8x8 = dtype("uint8x8") + uint16x8 = dtype("uint16x8") + uint32x8 = dtype("uint32x8") + uint64x8 = dtype("uint64x8") + uint8x16 = dtype("uint8x16") + uint16x16 = dtype("uint16x16") + uint32x16 = dtype("uint32x16") + uint64x16 = dtype("uint64x16") + uint8x32 = dtype("uint8x32") + uint16x32 = dtype("uint16x32") + uint32x32 = dtype("uint32x32") + uint64x32 = dtype("uint64x32") + uint8x64 = dtype("uint8x64") + uint16x64 = dtype("uint16x64") + uint32x64 = dtype("uint32x64") + uint64x64 = dtype("uint64x64") + float16 = dtype("float16") + float32 = dtype("float32") + float64 = dtype("float64") + float16x2 = dtype("float16x2") + float32x2 = dtype("float32x2") + float64x2 = dtype("float64x2") + float16x4 = dtype("float16x4") + float32x4 = dtype("float32x4") + float64x4 = dtype("float64x4") + float16x8 = dtype("float16x8") + float32x8 = dtype("float32x8") + float64x8 = dtype("float64x8") + float16x16 = dtype("float16x16") + float32x16 = dtype("float32x16") + float64x16 = dtype("float64x16") + float16x32 = dtype("float16x32") + float32x32 = dtype("float32x32") + float64x32 = dtype("float64x32") + float16x64 = dtype("float16x64") + float32x64 = dtype("float32x64") + float64x64 = dtype("float64x64") + float8_e3m4 = dtype("float8_e3m4") + float8_e3m4x2 = dtype("float8_e3m4x2") + float8_e3m4x4 = dtype("float8_e3m4x4") + float8_e3m4x8 = dtype("float8_e3m4x8") + float8_e3m4x16 = dtype("float8_e3m4x16") + float8_e3m4x32 = dtype("float8_e3m4x32") + float8_e3m4x64 = dtype("float8_e3m4x64") + float8_e4m3 = dtype("float8_e4m3") + float8_e4m3x2 = dtype("float8_e4m3x2") + float8_e4m3x4 = dtype("float8_e4m3x4") + float8_e4m3x8 = dtype("float8_e4m3x8") + float8_e4m3x16 = dtype("float8_e4m3x16") + float8_e4m3x32 = dtype("float8_e4m3x32") + float8_e4m3x64 = dtype("float8_e4m3x64") + float8_e4m3b11fnuz = dtype("float8_e4m3b11fnuz") + float8_e4m3b11fnuzx2 = dtype("float8_e4m3b11fnuzx2") + float8_e4m3b11fnuzx4 = dtype("float8_e4m3b11fnuzx4") + 
float8_e4m3b11fnuzx8 = dtype("float8_e4m3b11fnuzx8") + float8_e4m3b11fnuzx16 = dtype("float8_e4m3b11fnuzx16") + float8_e4m3b11fnuzx32 = dtype("float8_e4m3b11fnuzx32") + float8_e4m3b11fnuzx64 = dtype("float8_e4m3b11fnuzx64") + float8_e4m3fn = dtype("float8_e4m3fn") + float8_e4m3fnx2 = dtype("float8_e4m3fnx2") + float8_e4m3fnx4 = dtype("float8_e4m3fnx4") + float8_e4m3fnx8 = dtype("float8_e4m3fnx8") + float8_e4m3fnx16 = dtype("float8_e4m3fnx16") + float8_e4m3fnx32 = dtype("float8_e4m3fnx32") + float8_e4m3fnx64 = dtype("float8_e4m3fnx64") + float8_e4m3fnuz = dtype("float8_e4m3fnuz") + float8_e4m3fnuzx2 = dtype("float8_e4m3fnuzx2") + float8_e4m3fnuzx4 = dtype("float8_e4m3fnuzx4") + float8_e4m3fnuzx8 = dtype("float8_e4m3fnuzx8") + float8_e4m3fnuzx16 = dtype("float8_e4m3fnuzx16") + float8_e4m3fnuzx32 = dtype("float8_e4m3fnuzx32") + float8_e4m3fnuzx64 = dtype("float8_e4m3fnuzx64") + float8_e5m2 = dtype("float8_e5m2") + float8_e5m2x2 = dtype("float8_e5m2x2") + float8_e5m2x4 = dtype("float8_e5m2x4") + float8_e5m2x8 = dtype("float8_e5m2x8") + float8_e5m2x16 = dtype("float8_e5m2x16") + float8_e5m2x32 = dtype("float8_e5m2x32") + float8_e5m2x64 = dtype("float8_e5m2x64") + float8_e5m2fnuz = dtype("float8_e5m2fnuz") + float8_e5m2fnuzx2 = dtype("float8_e5m2fnuzx2") + float8_e5m2fnuzx4 = dtype("float8_e5m2fnuzx4") + float8_e5m2fnuzx8 = dtype("float8_e5m2fnuzx8") + float8_e5m2fnuzx16 = dtype("float8_e5m2fnuzx16") + float8_e5m2fnuzx32 = dtype("float8_e5m2fnuzx32") + float8_e5m2fnuzx64 = dtype("float8_e5m2fnuzx64") + float8_e8m0fnu = dtype("float8_e8m0fnu") + float8_e8m0fnux2 = dtype("float8_e8m0fnux2") + float8_e8m0fnux4 = dtype("float8_e8m0fnux4") + float8_e8m0fnux8 = dtype("float8_e8m0fnux8") + float8_e8m0fnux16 = dtype("float8_e8m0fnux16") + float8_e8m0fnux32 = dtype("float8_e8m0fnux32") + float8_e8m0fnux64 = dtype("float8_e8m0fnux64") + float6_e2m3fn = dtype("float6_e2m3fn") + float6_e2m3fnx2 = dtype("float6_e2m3fnx2") + float6_e2m3fnx4 = dtype("float6_e2m3fnx4") + float6_e2m3fnx8 = dtype("float6_e2m3fnx8") + float6_e2m3fnx16 = dtype("float6_e2m3fnx16") + float6_e2m3fnx32 = dtype("float6_e2m3fnx32") + float6_e2m3fnx64 = dtype("float6_e2m3fnx64") + float6_e3m2fn = dtype("float6_e3m2fn") + float6_e3m2fnx2 = dtype("float6_e3m2fnx2") + float6_e3m2fnx4 = dtype("float6_e3m2fnx4") + float6_e3m2fnx8 = dtype("float6_e3m2fnx8") + float6_e3m2fnx16 = dtype("float6_e3m2fnx16") + float6_e3m2fnx32 = dtype("float6_e3m2fnx32") + float6_e3m2fnx64 = dtype("float6_e3m2fnx64") + float4_e2m1fn = dtype("float4_e2m1fn") + float4_e2m1fnx2 = dtype("float4_e2m1fnx2") + float4_e2m1fnx4 = dtype("float4_e2m1fnx4") + float4_e2m1fnx8 = dtype("float4_e2m1fnx8") + float4_e2m1fnx16 = dtype("float4_e2m1fnx16") + float4_e2m1fnx32 = dtype("float4_e2m1fnx32") + float4_e2m1fnx64 = dtype("float4_e2m1fnx64") + bfloat16 = dtype("bfloat16") _all_dtypes = { - 'bool', - 'short', - 'int', - 'long', - 'half', - 'float', - 'double', - 'int8', - 'int16', - 'int32', - 'int64', - 'int8x2', - 'int16x2', - 'int32x2', - 'int64x2', - 'int8x4', - 'int16x4', - 'int32x4', - 'int64x4', - 'int8x8', - 'int16x8', - 'int32x8', - 'int64x8', - 'int8x16', - 'int16x16', - 'int32x16', - 'int64x16', - 'int8x32', - 'int16x32', - 'int32x32', - 'int64x32', - 'int8x64', - 'int16x64', - 'int32x64', - 'int64x64', - 'uint8', - 'uint16', - 'uint32', - 'uint64', - 'uint8x2', - 'uint16x2', - 'uint32x2', - 'uint64x2', - 'uint8x4', - 'uint16x4', - 'uint32x4', - 'uint64x4', - 'uint8x8', - 'uint16x8', - 'uint32x8', - 'uint64x8', - 'uint8x16', - 'uint16x16', - 'uint32x16', - 'uint64x16', 
- 'uint8x32', - 'uint16x32', - 'uint32x32', - 'uint64x32', - 'uint8x64', - 'uint16x64', - 'uint32x64', - 'uint64x64', - 'float16', - 'float32', - 'float64', - 'float16x2', - 'float32x2', - 'float64x2', - 'float16x4', - 'float32x4', - 'float64x4', - 'float16x8', - 'float32x8', - 'float64x8', - 'float16x16', - 'float32x16', - 'float64x16', - 'float16x32', - 'float32x32', - 'float64x32', - 'float16x64', - 'float32x64', - 'float64x64', - 'float8_e3m4', - 'float8_e3m4x2', - 'float8_e3m4x4', - 'float8_e3m4x8', - 'float8_e3m4x16', - 'float8_e3m4x32', - 'float8_e3m4x64', - 'float8_e4m3', - 'float8_e4m3x2', - 'float8_e4m3x4', - 'float8_e4m3x8', - 'float8_e4m3x16', - 'float8_e4m3x32', - 'float8_e4m3x64', - 'float8_e4m3b11fnuz', - 'float8_e4m3b11fnuzx2', - 'float8_e4m3b11fnuzx4', - 'float8_e4m3b11fnuzx8', - 'float8_e4m3b11fnuzx16', - 'float8_e4m3b11fnuzx32', - 'float8_e4m3b11fnuzx64', - 'float8_e4m3fn', - 'float8_e4m3fnx2', - 'float8_e4m3fnx4', - 'float8_e4m3fnx8', - 'float8_e4m3fnx16', - 'float8_e4m3fnx32', - 'float8_e4m3fnx64', - 'float8_e4m3fnuz', - 'float8_e4m3fnuzx2', - 'float8_e4m3fnuzx4', - 'float8_e4m3fnuzx8', - 'float8_e4m3fnuzx16', - 'float8_e4m3fnuzx32', - 'float8_e4m3fnuzx64', - 'float8_e5m2', - 'float8_e5m2x2', - 'float8_e5m2x4', - 'float8_e5m2x8', - 'float8_e5m2x16', - 'float8_e5m2x32', - 'float8_e5m2x64', - 'float8_e5m2fnuz', - 'float8_e5m2fnuzx2', - 'float8_e5m2fnuzx4', - 'float8_e5m2fnuzx8', - 'float8_e5m2fnuzx16', - 'float8_e5m2fnuzx32', - 'float8_e5m2fnuzx64', - 'float8_e8m0fnu', - 'float8_e8m0fnux2', - 'float8_e8m0fnux4', - 'float8_e8m0fnux8', - 'float8_e8m0fnux16', - 'float8_e8m0fnux32', - 'float8_e8m0fnux64', - 'float6_e2m3fn', - 'float6_e2m3fnx2', - 'float6_e2m3fnx4', - 'float6_e2m3fnx8', - 'float6_e2m3fnx16', - 'float6_e2m3fnx32', - 'float6_e2m3fnx64', - 'float6_e3m2fn', - 'float6_e3m2fnx2', - 'float6_e3m2fnx4', - 'float6_e3m2fnx8', - 'float6_e3m2fnx16', - 'float6_e3m2fnx32', - 'float6_e3m2fnx64', - 'float4_e2m1fn', - 'float4_e2m1fnx2', - 'float4_e2m1fnx4', - 'float4_e2m1fnx8', - 'float4_e2m1fnx16', - 'float4_e2m1fnx32', - 'float4_e2m1fnx64', - 'bfloat16', + "bool", + "short", + "int", + "long", + "half", + "float", + "double", + "int8", + "int16", + "int32", + "int64", + "int8x2", + "int16x2", + "int32x2", + "int64x2", + "int8x4", + "int16x4", + "int32x4", + "int64x4", + "int8x8", + "int16x8", + "int32x8", + "int64x8", + "int8x16", + "int16x16", + "int32x16", + "int64x16", + "int8x32", + "int16x32", + "int32x32", + "int64x32", + "int8x64", + "int16x64", + "int32x64", + "int64x64", + "uint8", + "uint16", + "uint32", + "uint64", + "uint8x2", + "uint16x2", + "uint32x2", + "uint64x2", + "uint8x4", + "uint16x4", + "uint32x4", + "uint64x4", + "uint8x8", + "uint16x8", + "uint32x8", + "uint64x8", + "uint8x16", + "uint16x16", + "uint32x16", + "uint64x16", + "uint8x32", + "uint16x32", + "uint32x32", + "uint64x32", + "uint8x64", + "uint16x64", + "uint32x64", + "uint64x64", + "float16", + "float32", + "float64", + "float16x2", + "float32x2", + "float64x2", + "float16x4", + "float32x4", + "float64x4", + "float16x8", + "float32x8", + "float64x8", + "float16x16", + "float32x16", + "float64x16", + "float16x32", + "float32x32", + "float64x32", + "float16x64", + "float32x64", + "float64x64", + "float8_e3m4", + "float8_e3m4x2", + "float8_e3m4x4", + "float8_e3m4x8", + "float8_e3m4x16", + "float8_e3m4x32", + "float8_e3m4x64", + "float8_e4m3", + "float8_e4m3x2", + "float8_e4m3x4", + "float8_e4m3x8", + "float8_e4m3x16", + "float8_e4m3x32", + "float8_e4m3x64", + "float8_e4m3b11fnuz", + 
"float8_e4m3b11fnuzx2", + "float8_e4m3b11fnuzx4", + "float8_e4m3b11fnuzx8", + "float8_e4m3b11fnuzx16", + "float8_e4m3b11fnuzx32", + "float8_e4m3b11fnuzx64", + "float8_e4m3fn", + "float8_e4m3fnx2", + "float8_e4m3fnx4", + "float8_e4m3fnx8", + "float8_e4m3fnx16", + "float8_e4m3fnx32", + "float8_e4m3fnx64", + "float8_e4m3fnuz", + "float8_e4m3fnuzx2", + "float8_e4m3fnuzx4", + "float8_e4m3fnuzx8", + "float8_e4m3fnuzx16", + "float8_e4m3fnuzx32", + "float8_e4m3fnuzx64", + "float8_e5m2", + "float8_e5m2x2", + "float8_e5m2x4", + "float8_e5m2x8", + "float8_e5m2x16", + "float8_e5m2x32", + "float8_e5m2x64", + "float8_e5m2fnuz", + "float8_e5m2fnuzx2", + "float8_e5m2fnuzx4", + "float8_e5m2fnuzx8", + "float8_e5m2fnuzx16", + "float8_e5m2fnuzx32", + "float8_e5m2fnuzx64", + "float8_e8m0fnu", + "float8_e8m0fnux2", + "float8_e8m0fnux4", + "float8_e8m0fnux8", + "float8_e8m0fnux16", + "float8_e8m0fnux32", + "float8_e8m0fnux64", + "float6_e2m3fn", + "float6_e2m3fnx2", + "float6_e2m3fnx4", + "float6_e2m3fnx8", + "float6_e2m3fnx16", + "float6_e2m3fnx32", + "float6_e2m3fnx64", + "float6_e3m2fn", + "float6_e3m2fnx2", + "float6_e3m2fnx4", + "float6_e3m2fnx8", + "float6_e3m2fnx16", + "float6_e3m2fnx32", + "float6_e3m2fnx64", + "float4_e2m1fn", + "float4_e2m1fnx2", + "float4_e2m1fnx4", + "float4_e2m1fnx8", + "float4_e2m1fnx16", + "float4_e2m1fnx32", + "float4_e2m1fnx64", + "bfloat16", } __all__ = list(_all_dtypes) + [ - 'dtype', - 'AnyDType', - 'get_tvm_dtype', + "dtype", + "AnyDType", + "get_tvm_dtype", ] diff --git a/tilelang/language/v2/utils.py b/tilelang/language/v2/utils.py index 022402dfac551f85bba18659f8ccb7d68680502b..207bd92ad3290fdad79a3c77ece32c9698163751 100644 --- a/tilelang/language/v2/utils.py +++ b/tilelang/language/v2/utils.py @@ -12,11 +12,12 @@ def disk_compile(source, name): cache_dir = env.TILELANG_CACHE_DIR if cache_dir is not None: import os + save_dir = os.path.join(cache_dir, "py-cache") os.makedirs(save_dir, exist_ok=True) - hash_sfx = sha256(source.encode('utf-8')).hexdigest()[:8] + hash_sfx = sha256(source.encode("utf-8")).hexdigest()[:8] path = os.path.join(save_dir, f"{name}.{hash_sfx}.py") - with open(path, 'w') as f: + with open(path, "w") as f: f.write(source) linecache.cache[path] = (len(source), None, source.splitlines(), path) return compile(source, path, "exec") @@ -59,29 +60,26 @@ def get_ast(func: Callable): filename = inspect.getsourcefile(func) or inspect.getfile(func) source = inspect.getsource(func) source = _remove_leading_ident(source) - source = '\n' * (start - 1) + source + source = "\n" * (start - 1) + source tree = ast.parse(source, filename=filename) return tree -CompileMethod = Literal['direct', 'disk'] +CompileMethod = Literal["direct", "disk"] -def get_compiled_object(source: str | ast.AST, - name: str, - filename: str = None, - globals: dict[str, Any] = None): +def get_compiled_object(source: str | ast.AST, name: str, filename: str = None, globals: dict[str, Any] = None): if isinstance(source, ast.AST): assert filename is not None, "filename must be provided when source is an AST" try: if isinstance(source, ast.AST): ast.fix_missing_locations(source) - compiled = compile(source, filename, 'exec') + compiled = compile(source, filename, "exec") else: compiled = disk_compile(source, name) except Exception as e: source_str = source if isinstance(source, str) else ast.unparse(source) - raise RuntimeError(f'Failed to compile source for {name}, Error: {e}:\n{source_str}') from e + raise RuntimeError(f"Failed to compile source for {name}, Error: {e}:\n{source_str}") from e 
locs = {} exec(compiled, globals, locs) return locs[name] @@ -95,7 +93,6 @@ def construct_strides(shape: tuple[Any, ...], allow_prim_expr: bool = True) -> t strides.append(stride) stride *= s if not allow_prim_expr and isinstance(stride, tir.PrimExpr): - raise ValueError( - "Cannot construct strides with PrimExpr when allow_prim_expr is False.") + raise ValueError("Cannot construct strides with PrimExpr when allow_prim_expr is False.") strides = tuple(reversed(strides)) return strides diff --git a/tilelang/language/warpgroup.py b/tilelang/language/warpgroup.py index bec768094545490e6e64a1dad3f5b8e9b98ecfb8..77cf6924583262aa4d34b708f1ade38a64347477 100644 --- a/tilelang/language/warpgroup.py +++ b/tilelang/language/warpgroup.py @@ -1,4 +1,5 @@ """The language interface for tl programs.""" + from tvm.script.ir_builder.tir.frame import TIRFrame from tvm.ffi import register_object from tilelang import _ffi_api diff --git a/tilelang/layout/fragment.py b/tilelang/layout/fragment.py index ff45f6d5a877614cf03581ebc6376f28bd7d15ad..256a7d5ee169c1545c38ad192131ed115810a2a8 100644 --- a/tilelang/layout/fragment.py +++ b/tilelang/layout/fragment.py @@ -1,4 +1,5 @@ """Wrapping Layouts.""" + # pylint: disable=invalid-name, unsupported-binary-operation import tvm import tvm_ffi @@ -20,12 +21,7 @@ class Fragment(Layout): # Disable the linter warning about not calling super().__init__() # because this object is created via TVM's FFI constructor mechanism. # pylint: disable=super-init-not-called - def __init__(self, - shape, - forward_fn=None, - forward_thread_fn=None, - replicate=1, - forward_index_fn=None): + def __init__(self, shape, forward_fn=None, forward_thread_fn=None, replicate=1, forward_index_fn=None): """ Initialize the Fragment with iteration variables and optional thread replication. @@ -119,10 +115,7 @@ class Fragment(Layout): """ return _ffi_api.Fragment_thread_size(self) - def repeat(self, - repeats, - repeat_on_thread: bool = False, - lower_dim_first: bool = True) -> 'Fragment': + def repeat(self, repeats, repeat_on_thread: bool = False, lower_dim_first: bool = True) -> "Fragment": """ Returns a new Fragment that repeats the iteration space a given number of times. @@ -142,7 +135,7 @@ class Fragment(Layout): """ return _ffi_api.Fragment_repeat(self, repeats, repeat_on_thread, lower_dim_first) - def replicate(self, replicate: int) -> 'Fragment': + def replicate(self, replicate: int) -> "Fragment": """ Replicate the Fragment across a new thread dimension. @@ -158,7 +151,7 @@ class Fragment(Layout): """ return _ffi_api.Fragment_replicate(self, replicate) - def condense_rep_var(self) -> 'Fragment': + def condense_rep_var(self) -> "Fragment": """ Condense or fold the replicate variable into the existing iteration space. 
This operation may be used to reduce dimensionality if the replicate variable @@ -190,8 +183,7 @@ class Fragment(Layout): # The thread dimension (IterVar) is accessed via the `thread` property forward_thread = self.thread # Construct an IndexMap to map the provided args into the final thread index - index_map = IndexMap( - initial_indices=forward_vars, final_indices=[forward_thread], inverse_index_map=None) + index_map = IndexMap(initial_indices=forward_vars, final_indices=[forward_thread], inverse_index_map=None) return index_map.map_indices(indices) def __repr__(self): @@ -206,7 +198,7 @@ class Fragment(Layout): return self._DebugOutput() # return f"Fragment<{self.get_input_shape()}->{self.get_output_shape()}, thread={self.thread}, index={self.index}>" - def is_equal(self, other: 'Fragment') -> bool: + def is_equal(self, other: "Fragment") -> bool: """ Check if the current fragment is equal to another fragment. """ diff --git a/tilelang/layout/gemm_sp.py b/tilelang/layout/gemm_sp.py index e5d190292888e0818e94bc21ac09dbf3adde20b9..e68c116746615f9fba52a87e6b9ba98e1cee6820 100644 --- a/tilelang/layout/gemm_sp.py +++ b/tilelang/layout/gemm_sp.py @@ -1,4 +1,5 @@ """Wrapping Layouts.""" + # pylint: disable=invalid-name, unsupported-binary-operation from __future__ import annotations import tvm @@ -114,8 +115,7 @@ def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str): if mma_dtype in ["float16", "bfloat16"] and buffer.dtype not in ["uint16", "int16"]: raise ValueError(f"metadata should be 16 bit, got {buffer.dtype}") - if mma_dtype in ["float8_e4m3", "float8_e5m2", "int8", "uint8" - ] and buffer.dtype not in ["uint32", "int32"]: + if mma_dtype in ["float8_e4m3", "float8_e5m2", "int8", "uint8"] and buffer.dtype not in ["uint32", "int32"]: raise ValueError(f"metadata should be 32 bit, got {buffer.dtype}") m, k = buffer.shape @@ -134,10 +134,7 @@ def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str): return T.Layout(buffer.shape, ColumnMajorInterleaved) -def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer, - mma_dtype: str = "float16", - arch: str | None = None, - **extra_args): +def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer, mma_dtype: str = "float16", arch: str | None = None, **extra_args): if arch is None: arch = nvcc.get_target_compute_version() diff --git a/tilelang/layout/layout.py b/tilelang/layout/layout.py index 87d2ee44bacfc9b5a1f7a45d87d8f38f8f2afa20..fbd39e8de748199a1d2de77d6fa6ae55f96b4438 100644 --- a/tilelang/layout/layout.py +++ b/tilelang/layout/layout.py @@ -1,4 +1,5 @@ """Wrapping Layouts.""" + # pylint: disable=invalid-name, unsupported-binary-operation import tvm_ffi from tvm.ir import Node, Range @@ -9,7 +10,6 @@ from tilelang import _ffi_api # Register the Layout class as a TVM object under the name "tl.Layout" @tvm_ffi.register_object("tl.Layout") class Layout(Node): - def __init__(self, shape, forward_fn): """ Initialize a Layout object. @@ -114,13 +114,13 @@ class Layout(Node): index_map = IndexMap( initial_indices=forward_vars, # The original iteration variables final_indices=forward_indexes, # The computed forward indices - inverse_index_map=None # No inverse mapping provided at this stage + inverse_index_map=None, # No inverse mapping provided at this stage ) # Map the provided indices using the constructed index mapping return index_map.map_indices(indices) - def inverse(self) -> 'Layout': + def inverse(self) -> "Layout": """ Compute the inverse of the current layout transformation. 
@@ -131,7 +131,7 @@ class Layout(Node): """ return _ffi_api.Layout_inverse(self) - def is_equal(self, other: 'Layout') -> bool: + def is_equal(self, other: "Layout") -> bool: """ Check if the current layout is equal to another layout. diff --git a/tilelang/layout/swizzle.py b/tilelang/layout/swizzle.py index 3a219c67c7569c74900d65ba3e7eb8b5e521bf43..e083d756db52aa130c78a1a31959b1914b3c71c4 100644 --- a/tilelang/layout/swizzle.py +++ b/tilelang/layout/swizzle.py @@ -1,4 +1,5 @@ """Wrapping Layouts.""" + # pylint: disable=invalid-name, unsupported-binary-operation from __future__ import annotations @@ -7,9 +8,7 @@ from tvm.tir import Buffer, BufferLoad, BufferRegion from tilelang import _ffi_api -def _get_buffer_info( - buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion -) -> tuple[Buffer, list[int], str]: +def _get_buffer_info(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[Buffer, list[int], str]: """ Extract buffer, shape, and dtype from Buffer, BufferLoad, or BufferRegion. @@ -25,12 +24,10 @@ def _get_buffer_info( buf = buffer_or_load_or_region.buffer return buf, buf.shape, buf.dtype else: - raise TypeError( - f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}") + raise TypeError(f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}") -def _get_stride_continuous( - buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[int, int]: +def _get_stride_continuous(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[int, int]: """ Get stride (last 2nd dimension) and continuous (last dimension) from Buffer, BufferLoad, or BufferRegion. @@ -62,9 +59,7 @@ def _get_element_size(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegi # Use a stable swizzled layout to ensure consistent memory access patterns. # Swizzling should be enabled or disabled based on whether TMA (Tensor Memory Access) is applied. 
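# A minimal usage sketch for make_swizzled_layout (defined just below), assuming the
# T.alloc_shared / T.annotate_layout calls seen elsewhere in this diff; the buffer name,
# shape, and dtype are illustrative only, and per the comment above the swizzle should
# be chosen with TMA usage in mind.
#   A_shared = T.alloc_shared((block_M, block_K), dtype)
#   T.annotate_layout({A_shared: make_swizzled_layout(A_shared, k_major=True, allow_pad=True)})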
-def make_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, - k_major: bool = True, - allow_pad: bool = True): +def make_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, k_major: bool = True, allow_pad: bool = True): stride, continuous = _get_stride_continuous(buffer) element_size = _get_element_size(buffer) return _ffi_api.make_swizzled_layout( @@ -77,9 +72,7 @@ def make_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, # for Volta Intrinsics -def make_volta_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, - is_a: bool = True, - k_inner: bool = True): +def make_volta_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, is_a: bool = True, k_inner: bool = True): stride, continuous = _get_stride_continuous(buffer) return _ffi_api.make_volta_swizzled_layout( stride, @@ -90,9 +83,7 @@ def make_volta_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, # for WGMMA Intrinsics -def make_wgmma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, - continuity: int = None, - k_major: bool = True): +def make_wgmma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, continuity: int = None, k_major: bool = True): stride, continuous = _get_stride_continuous(buffer) element_size = _get_element_size(buffer) if continuity is None: @@ -107,9 +98,7 @@ def make_wgmma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, # for TCGEN05MMA Intrinsics -def make_tcgen05mma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, - continuity: int = None, - k_major: bool = True): +def make_tcgen05mma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, continuity: int = None, k_major: bool = True): stride, continuous = _get_stride_continuous(buffer) element_size = _get_element_size(buffer) if continuity is None: diff --git a/tilelang/libinfo.py b/tilelang/libinfo.py index 5af8c84f45fde2faf4f84276197adeb99c1dd520..d82986b7534edc0d4cd0a92864e6da7d85818947 100644 --- a/tilelang/libinfo.py +++ b/tilelang/libinfo.py @@ -31,6 +31,5 @@ def find_lib_path(name: str, py_ext=False): if os.path.exists(lib_dll_path) and os.path.isfile(lib_dll_path): return lib_dll_path else: - message = (f"Cannot find libraries: {lib_name}\n" + "List of candidates:\n" + - "\n".join(TL_LIBS)) + message = f"Cannot find libraries: {lib_name}\n" + "List of candidates:\n" + "\n".join(TL_LIBS) raise RuntimeError(message) diff --git a/tilelang/primitives/__init__.py b/tilelang/primitives/__init__.py index 8eccc3e5ce55ff02090ade6c6755f782d13e393f..9d2a739a785b6e0e79f2afe17f2cd80814eef27c 100644 --- a/tilelang/primitives/__init__.py +++ b/tilelang/primitives/__init__.py @@ -1,3 +1,3 @@ -""" bootstrap the primitives module via tile language """ +"""bootstrap the primitives module via tile language""" from .gemm import gemm # noqa: F401 diff --git a/tilelang/primitives/gemm/__init__.py b/tilelang/primitives/gemm/__init__.py index 248437405c2ad265ab78ee72aabc93bbfbc3ed0f..7664a7b50b0423c36963579d54783e16b9ec978c 100644 --- a/tilelang/primitives/gemm/__init__.py +++ b/tilelang/primitives/gemm/__init__.py @@ -3,7 +3,8 @@ from tvm import tir from tilelang.utils import is_local, is_fragment, is_shared from tilelang.primitives.gemm.base import GemmWarpPolicy from tilelang.primitives.gemm.gemm_mma import ( - GemmPrimitiveMMA,) + GemmPrimitiveMMA, +) def gemm( @@ -20,12 +21,9 @@ def gemm( policy: GemmWarpPolicy = GemmWarpPolicy.Square, k_pack: int = 1, ): - assert is_local(A) or is_fragment(A) or is_shared(A), ( - f"Expected A to be a local, fragment, or shared buffer, but got 
{A.scope()}") - assert is_local(B) or is_fragment(B) or is_shared(B), ( - f"Expected B to be a local, fragment, or shared buffer, but got {B.scope()}") - assert is_local(C) or is_fragment(C), ( - f"Expected C to be a local, fragment, but got {C.scope()}") + assert is_local(A) or is_fragment(A) or is_shared(A), f"Expected A to be a local, fragment, or shared buffer, but got {A.scope()}" + assert is_local(B) or is_fragment(B) or is_shared(B), f"Expected B to be a local, fragment, or shared buffer, but got {B.scope()}" + assert is_local(C) or is_fragment(C), f"Expected C to be a local, fragment, but got {C.scope()}" # TODO(lei): Now we only support Nvidia GPUs # Must enhance the design to implement runtime lowering # for different targets (hip mfma for example) diff --git a/tilelang/primitives/gemm/base.py b/tilelang/primitives/gemm/base.py index 827ff78f9288a5d80c731956370d22a41acc75d6..b7fcdca92ce4fd814196c66136745bec44d0da51 100644 --- a/tilelang/primitives/gemm/base.py +++ b/tilelang/primitives/gemm/base.py @@ -131,7 +131,7 @@ class GemmWarpPolicy(IntEnum): # Try to find the best balanced partition best_m = 1 best_n = 1 - best_balance = float('inf') + best_balance = float("inf") # Try all possible combinations that satisfy the constraints for m in range(1, min(max_m_warps, num_warps) + 1): @@ -202,7 +202,7 @@ class GemmBaseParams: warp_row_tiles: int | None = None warp_col_tiles: int | None = None chunk: int | None = None - policy: GemmWarpPolicy = GemmWarpPolicy.Square, + policy: GemmWarpPolicy = (GemmWarpPolicy.Square,) k_pack: int = 1 def get_warp_size(self) -> int: @@ -267,17 +267,17 @@ class GemmBaseParams: # Determine whether block partition parameters need to be inferred require_infer = ( - block_row_warps is None or block_col_warps is None or warp_row_tiles is None or - warp_col_tiles is None or chunk is None) + block_row_warps is None or block_col_warps is None or warp_row_tiles is None or warp_col_tiles is None or chunk is None + ) A_shape, B_shape = A.shape, B.shape if require_infer: - assert (threads is not None), "threads must be provided for auto inference" + assert threads is not None, "threads must be provided for auto inference" # Auto-inference only supports 2D matrix multiplication - assert ( - len(A_shape) == 2 and len(B_shape) == 2 - ), f"Only support 2D matrix multiplication, got {len(A_shape)}D and {len(B_shape)}D" + assert len(A_shape) == 2 and len(B_shape) == 2, ( + f"Only support 2D matrix multiplication, got {len(A_shape)}D and {len(B_shape)}D" + ) # Analyze A/B shapes AM = A_shape[1] if transpose_A else A_shape[0] # M dimension @@ -291,8 +291,7 @@ class GemmBaseParams: num_warps = threads // warp_size # Infer block partition using a user-specified policy - block_row_warps, block_col_warps = policy.compute_warp_partition( - block_M, block_N, num_warps) + block_row_warps, block_col_warps = policy.compute_warp_partition(block_M, block_N, num_warps) warp_row_tiles = block_M // block_row_warps warp_col_tiles = block_N // block_col_warps chunk = int(AK) diff --git a/tilelang/primitives/gemm/gemm_mma.py b/tilelang/primitives/gemm/gemm_mma.py index 11e16838c07cd1612c381d8a729434ef590a551e..7ca3208be5214026d099184001a709929fc0730f 100644 --- a/tilelang/primitives/gemm/gemm_mma.py +++ b/tilelang/primitives/gemm/gemm_mma.py @@ -31,7 +31,6 @@ class GemmPrimitiveMMA(GemmBaseParams): C: tir.Buffer, mma_emitter: TensorCoreIntrinEmitter, ) -> tir.PrimExpr: - in_dtype = self.in_dtype warp_cols = mma_emitter.warp_cols local_size_b = mma_emitter.local_size_b @@ -53,21 +52,24 @@ 
class GemmPrimitiveMMA(GemmBaseParams): if a_is_fragment: # Annotate layout for A_local if it is a fragment. - T.annotate_layout({ - A_local: mma_emitter.make_mma_load_layout(A_local, "A"), - }) + T.annotate_layout( + { + A_local: mma_emitter.make_mma_load_layout(A_local, "A"), + } + ) if c_is_fragment: # Annotate layout for C_local if it is a fragment. - T.annotate_layout({ - C_local: mma_emitter.make_mma_store_layout(C_local), - }) + T.annotate_layout( + { + C_local: mma_emitter.make_mma_store_layout(C_local), + } + ) # Make default swizzle layout for shared memory # T.annotate_layout({ # B_shared: make_mma_swizzle_layout(B_shared), # }) for ki in T.serial(0, (block_K // micro_size_k)): - # Load B into fragment mma_emitter.ldmatrix_b( B_local, @@ -146,9 +148,11 @@ class GemmPrimitiveMMA(GemmBaseParams): if c_is_fragment: # Annotate layout for C_local if it is a fragment. - T.annotate_layout({ - C_local: mma_emitter.make_mma_store_layout(C_local), - }) + T.annotate_layout( + { + C_local: mma_emitter.make_mma_store_layout(C_local), + } + ) for ki in T.serial(0, (block_K // micro_size_k)): # Load A into fragment diff --git a/tilelang/profiler/__init__.py b/tilelang/profiler/__init__.py index 4750fa7d5d848b4837f8f12f2ba7caa68437c397..94d350153caffa38e00523e2b49c1224c9a4afed 100644 --- a/tilelang/profiler/__init__.py +++ b/tilelang/profiler/__init__.py @@ -1,4 +1,5 @@ """The profiler and convert to torch utils""" + from __future__ import annotations from typing import Callable, Any, Literal from functools import partial @@ -45,8 +46,7 @@ class Profiler: result_idx = [] elif isinstance(result_idx, int): if result_idx > len(params) or result_idx < -len(params): - raise ValueError( - f"result_idx should be an integer between {-len(params)} and {len(params) - 1}") + raise ValueError(f"result_idx should be an integer between {-len(params)} and {len(params) - 1}") if result_idx < 0: result_idx = len(params) + result_idx result_idx = [result_idx] @@ -113,8 +113,7 @@ class Profiler: ref_tensors = ins + ref_outs lib_tensors = ins + lib_outs - assert len(lib_tensors) == len( - ref_tensors), "len(lib_tensors) not equals to len(ref_tensors) !" + assert len(lib_tensors) == len(ref_tensors), "len(lib_tensors) not equals to len(ref_tensors) !" 
# torch.set_printoptions(edgeitems=torch.inf) for lhs, rhs in zip(lib_tensors, ref_tensors): # close_mask = torch.isclose(lhs, rhs, rtol=rtol, atol=atol) @@ -252,10 +251,9 @@ class Profiler: ) elif profiler == "tvm": assert func is not None, "func should not be None" - assert isinstance( - func, tvm.runtime.Module), f"func should be a TVM module, but got {type(func)}" + assert isinstance(func, tvm.runtime.Module), f"func should be a TVM module, but got {type(func)}" - ins = (self._get_inputs(with_output=True) if input_tensors is None else input_tensors) + ins = self._get_inputs(with_output=True) if input_tensors is None else input_tensors target = "cuda" with suppress(Exception): @@ -264,8 +262,7 @@ class Profiler: assert target in ["cuda", "hip"], f"Unknown target: {target}" device = tvm.cuda(0) if target == "cuda" else tvm.rocm(0) - time_evaluator = self.mod.time_evaluator( - self.mod.entry_name, device, number=rep, repeat=n_repeat) + time_evaluator = self.mod.time_evaluator(self.mod.entry_name, device, number=rep, repeat=n_repeat) # Transform Latency to ms return time_evaluator(*ins).mean * 1e3 else: diff --git a/tilelang/profiler/bench.py b/tilelang/profiler/bench.py index a851ceb3dca13e0e8d7749b18b0c9c9b1765f27e..bfcb5043debc093a1449992ac5592d03b8a768e9 100644 --- a/tilelang/profiler/bench.py +++ b/tilelang/profiler/bench.py @@ -1,4 +1,5 @@ """Profiler and benchmarking utilities for PyTorch functions.""" + from __future__ import annotations import os @@ -16,8 +17,8 @@ class suppress_stdout_stderr: def __enter__(self): # Open null device files - self.outnull_file = open(os.devnull, 'w') - self.errnull_file = open(os.devnull, 'w') + self.outnull_file = open(os.devnull, "w") + self.errnull_file = open(os.devnull, "w") # Save original file descriptors self.old_stdout_fileno_undup = sys.stdout.fileno() @@ -56,7 +57,7 @@ class suppress_stdout_stderr: IS_CUDA = torch.cuda.is_available() -device = 'cuda:0' if IS_CUDA else 'mps:0' +device = "cuda:0" if IS_CUDA else "mps:0" Event = torch.cuda.Event if IS_CUDA else torch.mps.Event @@ -93,8 +94,7 @@ def do_bench( Returns: Runtime in milliseconds (float) or list of quantile values if quantiles specified """ - assert return_mode in ["min", "max", "mean", "median"], \ - f"Invalid return_mode: {return_mode}" + assert return_mode in ["min", "max", "mean", "median"], f"Invalid return_mode: {return_mode}" # Initial function call and synchronization fn() diff --git a/tilelang/quantize/lop3.py b/tilelang/quantize/lop3.py index e4e7f7ee28d962c3b2f7c5e5c761a41b4c2cb334..e0788dab411754cac1909ab9c239d4f5de838187 100644 --- a/tilelang/quantize/lop3.py +++ b/tilelang/quantize/lop3.py @@ -1130,16 +1130,13 @@ def get_lop3_intrin_group( Dict[str, str] A dictionary mapping the names of the intrinsics to their corresponding implementations. """ - assert out_dtype in [ - "float16", "int8", "int4" - ], (f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8' or 'int4' .") + assert out_dtype in ["float16", "int8", "int4"], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8' or 'int4' ." dtype_mapping = {"float16": "f16", "int4": "i4", "int8": "i8", "int32": "i32"} target_dtype = dtype_mapping[out_dtype] if source_format not in ["int", "uint"]: - raise ValueError( - f"Invalid source_format. Expected 'int' or 'uint', but got {source_format}.") + raise ValueError(f"Invalid source_format. 
Expected 'int' or 'uint', but got {source_format}.") if with_zeros and source_format == "int": raise ValueError(f"Zeros are not supported for signed integers, but got {source_format}") diff --git a/tilelang/quantize/mxfp.py b/tilelang/quantize/mxfp.py index 80f3e0612b256c2f0d29f0c1ca455bf6cc70059e..e5c472cb1e303bc8da63567069ee0a6ad4ea88e5 100644 --- a/tilelang/quantize/mxfp.py +++ b/tilelang/quantize/mxfp.py @@ -80,13 +80,9 @@ def get_mxfp_intrin_group( AssertionError: if out_dtype, source_format, or storage_dtype are not supported. KeyError: if the constructed key does not match any available C source implementation. """ - assert out_dtype in ["float16", "bfloat16" - ], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'." - assert source_format in ["int", "uint" - ], f"Invalid source_format: {source_format}. Expected 'int' or 'uint'." - assert storage_dtype in [ - "int32", "int8", "uint8" - ], f"Invalid storage_dtype: {storage_dtype}. Expected 'int32' or 'int8' or 'uint8'." + assert out_dtype in ["float16", "bfloat16"], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'." + assert source_format in ["int", "uint"], f"Invalid source_format: {source_format}. Expected 'int' or 'uint'." + assert storage_dtype in ["int32", "int8", "uint8"], f"Invalid storage_dtype: {storage_dtype}. Expected 'int32' or 'int8' or 'uint8'." dtype_map = {"float16": "f16", "bfloat16": "bf16"} key = f"fp{source_bit}_to_{dtype_map[out_dtype]}" diff --git a/tilelang/quantize/utils.py b/tilelang/quantize/utils.py index 2447ca16724535b75e1cffd32c3ee966cb4d7397..2d092a0bab95b7c1a6e34fbbdff61ac76da340d2 100644 --- a/tilelang/quantize/utils.py +++ b/tilelang/quantize/utils.py @@ -1,6 +1,7 @@ def gen_quant4(k, n, groupsize=-1): import torch import torch.nn as nn + maxq = 2**4 w = torch.randn((k, n), dtype=torch.half, device="cpu") @@ -48,6 +49,7 @@ def gen_quant4(k, n, groupsize=-1): def general_compress(lowprecision_weight, source_bits=4, storage_dtype=None): import torch + if storage_dtype is None: storage_dtype = torch.int8 elems_per_byte = 8 // source_bits @@ -56,11 +58,11 @@ def general_compress(lowprecision_weight, source_bits=4, storage_dtype=None): int8_weight = torch.zeros( (*lowprecision_weight.shape[:-1], lowprecision_weight.shape[-1] // elems_per_byte), dtype=torch.int8, - device=lowprecision_weight.device) + device=lowprecision_weight.device, + ) for j in range(lowprecision_weight.shape[-1] // elems_per_byte): for k in range(elems_per_byte): - int8_weight[..., j] |= (lowprecision_weight[..., j * elems_per_byte + k] << - (source_bits * k)).to(torch.int8) + int8_weight[..., j] |= (lowprecision_weight[..., j * elems_per_byte + k] << (source_bits * k)).to(torch.int8) return int8_weight.to(storage_dtype) @@ -82,6 +84,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): interleave_weight(qweight, 4, "float16") """ import torch + assert target_dtype in ["float16", "int8"] # reinterpret the data type of qweight to int32 qweight = qweight.view(torch.int32) diff --git a/tilelang/testing/__init__.py b/tilelang/testing/__init__.py index 6a2031492fed1550c3c3cd9ce424b3b5349f8429..635fad365ce4d867f255531f01c0b9c2dfefe616 100644 --- a/tilelang/testing/__init__.py +++ b/tilelang/testing/__init__.py @@ -5,20 +5,19 @@ import random import torch import numpy as np from tilelang.contrib import nvcc -from tvm.testing.utils import (requires_cuda, requires_package, requires_llvm, requires_metal, - requires_rocm, _compose) +from tvm.testing.utils import requires_cuda, requires_package, 
requires_llvm, requires_metal, requires_rocm, _compose from tilelang.utils.tensor import torch_assert_close as torch_assert_close __all__ = [ - 'requires_package', - 'requires_cuda', - 'requires_metal', - 'requires_rocm', - 'requires_llvm', - 'main', - 'requires_cuda_compute_version', -] + [f'requires_cuda_compute_version_{op}' for op in ('ge', 'gt', 'le', 'lt', 'eq')] + "requires_package", + "requires_cuda", + "requires_metal", + "requires_rocm", + "requires_llvm", + "main", + "requires_cuda_compute_version", +] + [f"requires_cuda_compute_version_{op}" for op in ("ge", "gt", "le", "lt", "eq")] # pytest.main() wrapper to allow running single test file diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 90960904f05ee3907555e87d07403d5b56e99ecc..4d2caf8c5872f9d6d4a7455a06871dc0bcaa1894 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -23,8 +23,7 @@ def gemm_py_infer_layout(gemm_py: GemmMMA, target: Target, thread_bounds: Range) @tvm_ffi.register_global_func("tl.gemm_py.lower") -def gemm_py_lower(gemm_py: GemmMMA, layout_map, target: Target, thread_bounds: Range, - thread_var: tir.Var): +def gemm_py_lower(gemm_py: GemmMMA, layout_map, target: Target, thread_bounds: Range, thread_var: tir.Var): thread_nums = thread_bounds.extent stmt = gemm_py.lower(layout_map, target, thread_nums, thread_var) return stmt diff --git a/tilelang/tileop/gemm/gemm_mfma.py b/tilelang/tileop/gemm/gemm_mfma.py index 862ec725b71ad87da8cb96f538b79f2670d48c30..d827d8a2a3fa9ad901dadf214e576f942619f0af 100644 --- a/tilelang/tileop/gemm/gemm_mfma.py +++ b/tilelang/tileop/gemm/gemm_mfma.py @@ -1,7 +1,8 @@ from .gemm_base import GemmBase from tilelang.layout import make_swizzled_layout from tilelang.intrinsics.mfma_macro_generator import ( - MatrixCoreIntrinEmitter,) + MatrixCoreIntrinEmitter, +) from tilelang.utils.language import is_shared, is_fragment, is_full_region from tilelang import tvm as tvm from tvm.target import Target @@ -11,10 +12,8 @@ from tilelang.transform.simplify import _Simplify class GemmMFMA(GemmBase): - def infer_layout(self, target: Target, thread_nums: int): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mfma_emitter = MatrixCoreIntrinEmitter( @@ -56,12 +55,10 @@ class GemmMFMA(GemmBase): self.C: mfma_emitter.make_mfma_store_layout(self.C), } else: - raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mfma_emitter = MatrixCoreIntrinEmitter( @@ -153,7 +150,6 @@ class GemmMFMA(GemmBase): T.clear(C_buf) for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): - # Load A into fragment mfma_emitter.ldmatrix_a( A_local, @@ -183,7 +179,6 @@ class GemmMFMA(GemmBase): if clear_accum: T.clear(C_buf) for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): - # Load B into fragment 
mfma_emitter.ldmatrix_b( B_local, @@ -217,8 +212,7 @@ class GemmMFMA(GemmBase): # Must inline let statements to simplify the analysis return _Simplify(_gemm_rsr, inline_let=True) else: - raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def is_gemm_ss(self) -> bool: return is_shared(self.A) and is_shared(self.B) diff --git a/tilelang/tileop/gemm/gemm_mma.py b/tilelang/tileop/gemm/gemm_mma.py index ce27409bb60ae0d855345810ce38ce06e701604c..b15173483813aa28f0a72d74260fba4b23dab3e7 100644 --- a/tilelang/tileop/gemm/gemm_mma.py +++ b/tilelang/tileop/gemm/gemm_mma.py @@ -1,7 +1,8 @@ from .gemm_base import GemmBase from tilelang.layout import make_swizzled_layout from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.utils.language import is_shared, is_fragment, is_full_region from tilelang import tvm as tvm from tvm.target import Target @@ -11,10 +12,8 @@ from tilelang.transform.simplify import _Simplify class GemmMMA(GemmBase): - def infer_layout(self, target: Target, thread_nums: int): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( @@ -54,12 +53,10 @@ class GemmMMA(GemmBase): self.C: mma_emitter.make_mma_store_layout(self.C), } else: - raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( @@ -177,7 +174,6 @@ class GemmMMA(GemmBase): if clear_accum: T.clear(C_buf) for ki in T.serial(0, (block_K // micro_size_k)): - # Load B into fragment mma_emitter.ldmatrix_b( B_local, @@ -211,8 +207,7 @@ class GemmMMA(GemmBase): # Must inline let statements to simplify the analysis return _Simplify(_gemm_rrr, inline_let=True) else: - raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def is_gemm_ss(self) -> bool: return is_shared(self.A) and is_shared(self.B) diff --git a/tilelang/tileop/gemm/gemm_mma_sm70.py b/tilelang/tileop/gemm/gemm_mma_sm70.py index 12b729c27583421eb1a62618302b223cf1f3ab7f..52a4bf3262f0054be28158a0d3c0db7863512ddf 100644 --- a/tilelang/tileop/gemm/gemm_mma_sm70.py +++ b/tilelang/tileop/gemm/gemm_mma_sm70.py @@ -2,7 +2,8 @@ from .gemm_base import GemmBase from tilelang.layout import make_volta_swizzled_layout from tilelang.intrinsics.mma_sm70_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.utils.language import is_shared, is_fragment, is_full_region from tilelang import tvm as tvm from tvm.target import Target @@ -12,10 +13,8 @@ from tilelang.transform.simplify import _Simplify class GemmMMASm70(GemmBase): 
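# The per-warp tile arithmetic repeated in the infer_layout/lower hunks above and below,
# with illustrative numbers only (the actual m_warp / n_warp split is picked by the policy
# for the given target and thread count):
#   m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False)
#   warp_row_tiles = int(self.M // m_warp)   # e.g. M = 128, m_warp = 2 -> 64 rows per warp
#   warp_col_tiles = int(self.N // n_warp)   # e.g. N = 128, n_warp = 4 -> 32 cols per warp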
- def infer_layout(self, target: Target, thread_nums: int): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( @@ -45,12 +44,10 @@ class GemmMMASm70(GemmBase): self.C: mma_emitter.make_mma_store_layout(self.C), } else: - raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( @@ -140,7 +137,6 @@ class GemmMMASm70(GemmBase): T.clear(C_buf) for ki in T.serial(0, (block_K // micro_size_k)): - # Load B into fragment mma_emitter.ldmatrix_b( B_local, @@ -155,8 +151,7 @@ class GemmMMASm70(GemmBase): # Must inline let statements to simplify the analysis return _Simplify(_gemm_rsr, inline_let=True) else: - raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def is_gemm_ss(self) -> bool: return is_shared(self.A) and is_shared(self.B) diff --git a/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/tileop/gemm/gemm_tcgen05.py index 76f919e0f8d6870ba701e0a85fad17d9ff142683..f93a403ebb5be64af8c539718eb34eb70b712082 100644 --- a/tilelang/tileop/gemm/gemm_tcgen05.py +++ b/tilelang/tileop/gemm/gemm_tcgen05.py @@ -1,7 +1,8 @@ from .gemm_base import GemmBase from tilelang.layout import make_tcgen05mma_swizzled_layout from tilelang.intrinsics.tcgen05_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang import language as T from tilelang.transform.simplify import _Simplify from tvm import tir @@ -18,10 +19,8 @@ _FLOAT8_DTYPES = { class GemmTCGEN5(GemmBase): - def infer_layout(self, target: Target, thread_nums: int): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - True) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, True) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( @@ -40,27 +39,20 @@ class GemmTCGEN5(GemmBase): b_is_k_major = self.trans_B if self.is_gemm_ss(): - a_continuity = self.M if a_is_k_major else 4 * self.K // m_warp b_continuity = self.K if b_is_k_major else self.N // n_warp return { # WGMMA does not support padding - self.A: - make_tcgen05mma_swizzled_layout( - self.A, continuity=a_continuity, k_major=a_is_k_major), - self.B: - make_tcgen05mma_swizzled_layout( - self.B, continuity=b_continuity, k_major=b_is_k_major), - self.C: - mma_emitter.make_mma_store_layout(self.C), + self.A: make_tcgen05mma_swizzled_layout(self.A, continuity=a_continuity, k_major=a_is_k_major), + self.B: make_tcgen05mma_swizzled_layout(self.B, continuity=b_continuity, k_major=b_is_k_major), + self.C: mma_emitter.make_mma_store_layout(self.C), } # No special swizzle requirement; rely on existing layout. 
return {} def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - True) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, True) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( @@ -82,11 +74,9 @@ class GemmTCGEN5(GemmBase): mma_emitter._assign_b_shared_layout(layout_map[self.B]) if not self.is_gemm_ss(): - raise ValueError(f"TCGEN5MMA currently only supports gemm_ss, got " - f"A scope {self.A.scope()}, B scope {self.B.scope()}") + raise ValueError(f"TCGEN5MMA currently only supports gemm_ss, got A scope {self.A.scope()}, B scope {self.B.scope()}") - atom_m, atom_n, atom_k, enable_ws, enable_2cta = mma_emitter.get_tcgen5_mma_meta( - self.M, self.N, self.K) + atom_m, atom_n, atom_k, enable_ws, enable_2cta = mma_emitter.get_tcgen5_mma_meta(self.M, self.N, self.K) if self.A.scope() not in {"shared", "shared.dyn", "shared.tmem"}: raise ValueError(f"Unsupported A scope for TCGEN5MMA: {self.A.scope()}") @@ -108,7 +98,7 @@ class GemmTCGEN5(GemmBase): raise ValueError("TCGEN5MMA expects 2D coordinates for C buffer access") accum_dtype = str(self.C.dtype) - if accum_dtype not in ["float32", 'float16']: + if accum_dtype not in ["float32", "float16"]: raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}") A_shared = self.ARegion diff --git a/tilelang/tileop/gemm/gemm_wgmma.py b/tilelang/tileop/gemm/gemm_wgmma.py index 2325f45df65fb80e2ad3c29636cacb1f55aa8b6f..038aa2cd66692bb50386ccec669083c615882420 100644 --- a/tilelang/tileop/gemm/gemm_wgmma.py +++ b/tilelang/tileop/gemm/gemm_wgmma.py @@ -1,7 +1,8 @@ from .gemm_base import GemmBase from tilelang.layout import make_wgmma_swizzled_layout from tilelang.intrinsics.wgmma_macro_generator import ( - TensorCoreIntrinEmitter,) + TensorCoreIntrinEmitter, +) from tilelang.utils.language import is_shared, is_fragment from tilelang import tvm as tvm from tvm.target import Target @@ -11,10 +12,8 @@ from tilelang.transform.simplify import _Simplify class GemmWGMMA(GemmBase): - def infer_layout(self, target: Target, thread_nums: int): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - True) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, True) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( @@ -38,33 +37,22 @@ class GemmWGMMA(GemmBase): return { # WGMMA does not support padding - self.A: - make_wgmma_swizzled_layout( - self.A, continuity=a_continuity, k_major=a_is_k_major), - self.B: - make_wgmma_swizzled_layout( - self.B, continuity=b_continuity, k_major=b_is_k_major), - self.C: - mma_emitter.make_mma_store_layout(self.C), + self.A: make_wgmma_swizzled_layout(self.A, continuity=a_continuity, k_major=a_is_k_major), + self.B: make_wgmma_swizzled_layout(self.B, continuity=b_continuity, k_major=b_is_k_major), + self.C: mma_emitter.make_mma_store_layout(self.C), } elif self.is_gemm_rs(): b_continuity = self.N if b_is_k_major else 4 * self.K // n_warp return { - self.A: - mma_emitter.make_mma_load_layout(self.A, matrix="A"), - self.B: - make_wgmma_swizzled_layout( - self.B, continuity=b_continuity, k_major=b_is_k_major), - self.C: - mma_emitter.make_mma_store_layout(self.C), + self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), + self.B: 
make_wgmma_swizzled_layout(self.B, continuity=b_continuity, k_major=b_is_k_major), + self.C: mma_emitter.make_mma_store_layout(self.C), } else: - raise ValueError( - f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}") def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - True) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, True) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) @@ -133,8 +121,7 @@ class GemmWGMMA(GemmBase): # Simplify to optimize the index computing # Must inline let statements to simplify the analysis return _Simplify(_gemm_rsr, inline_let=True) - raise ValueError( - f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}") def is_gemm_ss(self) -> bool: return is_shared(self.A) and is_shared(self.B) diff --git a/tilelang/tileop/gemm_sp/__init__.py b/tilelang/tileop/gemm_sp/__init__.py index fdac694cedfa73eae048c3e5b59c9b97bb701af6..c22bca8d22fb845c5526acabef9a61af5d7cd00c 100644 --- a/tilelang/tileop/gemm_sp/__init__.py +++ b/tilelang/tileop/gemm_sp/__init__.py @@ -1,7 +1,8 @@ from tilelang import tvm as tvm from tvm import tir from tilelang.utils.target import ( - target_is_cuda,) + target_is_cuda, +) from tvm.target import Target from tvm.ir.base import Node from tvm.ir import Range @@ -18,8 +19,7 @@ def gemm_sp_py_infer_layout(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds @tvm_ffi.register_global_func("tl.gemm_sp_py.lower") -def gemm_sp_py_lower(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range, - thread_var: tir.Var): +def gemm_sp_py_lower(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range, thread_var: tir.Var): thread_nums = thread_bounds.extent stmt = gemm_sp_py.lower(target, thread_nums, thread_var) return stmt diff --git a/tilelang/tileop/gemm_sp/gemm_sp_mma.py b/tilelang/tileop/gemm_sp/gemm_sp_mma.py index 50a40bb91a32789cca812512fd4d8d0317e77952..76a0d4a9ed8a3800e0a2017e7ee3a9b7995af49e 100644 --- a/tilelang/tileop/gemm_sp/gemm_sp_mma.py +++ b/tilelang/tileop/gemm_sp/gemm_sp_mma.py @@ -10,10 +10,8 @@ from tilelang.transform.simplify import _Simplify class GemmSPMMA(GemmSPBase): - def infer_layout(self, target: Target, thread_nums: int): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = SparseTensorCoreIntrinEmitter( @@ -55,12 +53,10 @@ class GemmSPMMA(GemmSPBase): self.C: mma_emitter.make_mma_store_layout(self.C), } else: - raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N 
// n_warp) mma_emitter = SparseTensorCoreIntrinEmitter( @@ -146,7 +142,6 @@ class GemmSPMMA(GemmSPBase): E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype) for ki in T.serial(0, (self.K // micro_size_k)): - # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -231,8 +226,7 @@ class GemmSPMMA(GemmSPBase): # Must inline let statements to simplify the analysis return _Simplify(_gemm_rrr, inline_let=True) else: - raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") def is_gemm_ss(self) -> bool: return is_shared(self.A) and is_shared(self.B) diff --git a/tilelang/tools/Analyzer.py b/tilelang/tools/Analyzer.py index 205c647e3e54b2773c8a7f25e23aff6183ec9f9b..3af5222f29fb24cb2907d43da659a3efa28e7223 100644 --- a/tilelang/tools/Analyzer.py +++ b/tilelang/tools/Analyzer.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from tilelang import tvm from tvm.tir.stmt_functor import ir_transform import logging + # Configuration for different hardware architectures. # Each entry contains: (cores per SM, default clock (GHz), FLOPs per cycle, max SM count) ARCH_CONFIGS = {"80": (128, 1.41, 2, 108), "86": (128, 1.70, 2, 84), "89": (128, 2.52, 2, 128)} @@ -23,6 +24,7 @@ class AnalysisResult: tflops: Achieved TFLOPS (trillions of FLOPs per second). bandwidth_GBps: Achieved memory bandwidth in GB/s. """ + total_flops: int total_global_bytes: int estimated_time: float @@ -81,7 +83,7 @@ class Analyzer: # Account for loop and block dimensions loop_product = 1 for extent in self.loop_stack: - loop_product *= extent.value if hasattr(extent, 'value') else extent + loop_product *= extent.value if hasattr(extent, "value") else extent total_blocks = self.block_counts["blockIdx.x"] * self.block_counts["blockIdx.y"] total_bytes = bytes_transferred * loop_product * total_blocks self.total_global_bytes += total_bytes @@ -100,7 +102,7 @@ class Analyzer: # Account for loop and block dimensions loop_product = 1 for extent in self.loop_stack: - loop_product *= extent.value if hasattr(extent, 'value') else extent + loop_product *= extent.value if hasattr(extent, "value") else extent total_blocks = self.block_counts["blockIdx.x"] * self.block_counts["blockIdx.y"] self.total_flops += flops_per_call * loop_product * total_blocks @@ -127,8 +129,7 @@ class Analyzer: iter_var = stmt.node thread_tag = iter_var.thread_tag if thread_tag in self.block_counts: - extent = stmt.value.value if hasattr(stmt.value, - 'value') else stmt.value + extent = stmt.value.value if hasattr(stmt.value, "value") else stmt.value self.block_counts[thread_tag] = extent elif isinstance(stmt, tvm.tir.For): # Push loop extent onto the stack @@ -178,9 +179,7 @@ class Analyzer: """ arch_key = device.compute_capability[:2] if arch_key not in ARCH_CONFIGS: - logger.info( - f"Unsupported compute capability: {device.compute_capability}, theoretical peak tflops will be None" - ) + logger.info(f"Unsupported compute capability: {device.compute_capability}, theoretical peak tflops will be None") return None cores_per_sm, default_clock, flops_per_cycle, compute_max_core = ARCH_CONFIGS[arch_key] @@ -203,7 +202,8 @@ class Analyzer: total_global_bytes=self.total_global_bytes, estimated_time=estimated_time, expected_tflops=peak_tflops, - expected_bandwidth_GBps=bandwidth_GBps) + expected_bandwidth_GBps=bandwidth_GBps, + ) @classmethod def analysis(cls, fn, device): diff --git a/tilelang/tools/plot_layout.py 
b/tilelang/tools/plot_layout.py index 06e01f4895e3f9151defc409139ff9bd9420b16e..299c3e86b6325c84e43c575c6dc13729725b736b 100644 --- a/tilelang/tools/plot_layout.py +++ b/tilelang/tools/plot_layout.py @@ -2,12 +2,14 @@ from __future__ import annotations import tilelang.language as T -def plot_layout(layout: T.Fragment, - save_directory="./tmp", - name: str = "layout", - colormap: str = "RdPu", - verbose: bool = False, - formats: str | list[str] = "png") -> None: +def plot_layout( + layout: T.Fragment, + save_directory="./tmp", + name: str = "layout", + colormap: str = "RdPu", + verbose: bool = False, + formats: str | list[str] = "png", +) -> None: """ Plot the layout of a buffer. @@ -90,11 +92,13 @@ def plot_layout(layout: T.Fragment, # Warn if the number of threads is less than the warp size if num_threads < warp_size: import warnings + warnings.warn( f"Layout visualization has {num_threads} threads, which is less than the warp size ({warp_size}). " f"For the best viewing experience, it is recommended to have at least {warp_size} threads.", UserWarning, - stacklevel=2) + stacklevel=2, + ) spectral_camp = plt.get_cmap("hsv", warp_size * 6) for i in range(min(warp_size, num_threads)): @@ -118,12 +122,7 @@ def plot_layout(layout: T.Fragment, color = colors[thread_ids[0]] # Select color based on thread ID # Create a rectangle patch for visualization - rect = patches.Rectangle((j, i), - 1, - 1, - linewidth=0.5, - edgecolor='black', - facecolor=color) + rect = patches.Rectangle((j, i), 1, 1, linewidth=0.5, edgecolor="black", facecolor=color) ax.add_patch(rect) # Add the rectangle to the plot # Add text annotations inside the rectangles @@ -139,41 +138,19 @@ def plot_layout(layout: T.Fragment, thread_fontsize = min(font_size, font_size * (4 / len(thread_str))) # Add thread ID text with adjusted font size - ax.text( - j + 0.5, - i + 0.3, - thread_str, - ha='center', - va='center', - color='black', - fontsize=thread_fontsize) + ax.text(j + 0.5, i + 0.3, thread_str, ha="center", va="center", color="black", fontsize=thread_fontsize) # Add local ID text with original font size - ax.text( - j + 0.5, - i + 0.7, - f"L{local_id}", - ha='center', - va='center', - color='black', - fontsize=font_size) + ax.text(j + 0.5, i + 0.7, f"L{local_id}", ha="center", va="center", color="black", fontsize=font_size) # Add row labels to the left side of the plot for i in range(nrows): text = f"row {i}" - ax.text(-0.75, i + 0.5, text, ha='center', va='center', color='black', fontsize=font_size) + ax.text(-0.75, i + 0.5, text, ha="center", va="center", color="black", fontsize=font_size) # Add column labels at the top of the plot for j in range(ncols): text = f"col {j}" - ax.text( - j + 0.5, - -0.5, - text, - ha='center', - va='center', - color='black', - fontsize=font_size, - rotation=45) + ax.text(j + 0.5, -0.5, text, ha="center", va="center", color="black", fontsize=font_size, rotation=45) # Set the plot limits ax.set_xlim(0, ncols) @@ -189,17 +166,15 @@ def plot_layout(layout: T.Fragment, legend_x = 1.0 + (0.5 / fig_width) # Adjust x position based on figure width legend_y = 1.0 + (1.7 / fig_height) # Adjust y position based on figure height - legend_patches = [ - patches.Patch(color='black', label="T: Thread ID"), - patches.Patch(color='black', label="L: Local ID") - ] + legend_patches = [patches.Patch(color="black", label="T: Thread ID"), patches.Patch(color="black", label="L: Local ID")] ax.legend( handles=legend_patches, loc="upper right", fontsize=font_size - 4, frameon=False, bbox_to_anchor=(legend_x, legend_y), # 
Dynamic position - ncols=2) + ncols=2, + ) # Create the output directory if it does not exist tmp_directory = pathlib.Path(save_directory) @@ -211,28 +186,29 @@ def plot_layout(layout: T.Fragment, if isinstance(formats, str): formats_str = formats.strip().lower() - if formats_str == 'all': - formats_list = ['pdf', 'png', 'svg'] + if formats_str == "all": + formats_list = ["pdf", "png", "svg"] elif "," in formats_str: - formats_list = [f.strip() for f in formats_str.split(',')] + formats_list = [f.strip() for f in formats_str.split(",")] else: formats_list = [formats_str] else: - raise TypeError(f"Expected str, but got {type(formats).__name__}. " - f"Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'.") + raise TypeError( + f"Expected str, but got {type(formats).__name__}. Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'." + ) # Save the figure - if 'pdf' in formats_list: + if "pdf" in formats_list: pdf_path = tmp_directory / f"{name}.pdf" plt.savefig(pdf_path, bbox_inches="tight") print(f"Saved pdf format into {pdf_path}") - if 'png' in formats_list: + if "png" in formats_list: png_path = tmp_directory / f"{name}.png" plt.savefig(png_path, bbox_inches="tight", transparent=False, dpi=255) print(f"Saved png format into {png_path}") - if 'svg' in formats_list: + if "svg" in formats_list: svg_path = tmp_directory / f"{name}.svg" plt.savefig(svg_path, bbox_inches="tight", format="svg") print(f"Saved svg format into {svg_path}") diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index a86ffe21dbac0d5adfcced306e51895b414142b9..bb9202a31c16429b1e92d854c047d7d20c5fcb19 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -110,8 +110,7 @@ def LowerHopperIntrin(): fpass : tvm.transform.Pass The result pass """ - return (_ffi_api.LowerHopperIntrin() if hasattr(_ffi_api, "LowerHopperIntrin") else lambda f: f - ) # type: ignore + return _ffi_api.LowerHopperIntrin() if hasattr(_ffi_api, "LowerHopperIntrin") else lambda f: f # type: ignore def WarpSpecializedPipeline(): @@ -365,8 +364,7 @@ def FlattenBuffer(): def EliminateStorageSyncForMBarrier(): - """EliminateStorageSyncForMBarrier - """ + """EliminateStorageSyncForMBarrier""" return _ffi_api.EliminateStorageSyncForMBarrier() # type: ignore @@ -378,19 +376,16 @@ def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False, align_by fpass : tvm.transform.Pass The result pass """ - return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge, - align_bytes) # type: ignore + return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge, align_bytes) # type: ignore def LowerL2Persistent(): - """LowerL2Persistent - """ + """LowerL2Persistent""" return _ffi_api.LowerL2Persistent() # type: ignore def PersistThreadblock(): - """PersistThreadblock - """ + """PersistThreadblock""" return _ffi_api.PersistThreadblock() # type: ignore @@ -409,8 +404,7 @@ def AlignDynamicSharedMemoryAllocations(align_bytes: int = 16): def LowerSharedBarrier(): - """LowerSharedBarrier - """ + """LowerSharedBarrier""" return _ffi_api.LowerSharedBarrier() # type: ignore @@ -437,20 +431,17 @@ def StorageRewrite(): def LowerOpaqueBlock(): - """LowerOpaqueBlock - """ + """LowerOpaqueBlock""" return _ffi_api.LowerOpaqueBlock() # type: ignore def LowerThreadAllreduce(): - """LowerThreadAllreduce - """ + """LowerThreadAllreduce""" return _ffi_api.LowerThreadAllreduce() # type: ignore def LowerIntrin(): - """LowerIntrin - """ + """LowerIntrin""" return _ffi_api.LowerIntrin() # 
type: ignore @@ -468,8 +459,7 @@ def LowerDeviceKernelLaunch(): def LowerSharedTmem(): - """LowerSharedTmem - """ + """LowerSharedTmem""" return _ffi_api.LowerSharedTmem() # type: ignore diff --git a/tilelang/transform/add_bufstore_wrapper.py b/tilelang/transform/add_bufstore_wrapper.py index d8457f990002e9e69ecfbc61f65ab44ac9bd2af2..c1dd41e0dde17c7817416ea88fee9f63e17f1185 100644 --- a/tilelang/transform/add_bufstore_wrapper.py +++ b/tilelang/transform/add_bufstore_wrapper.py @@ -1,4 +1,4 @@ -from tvm.tir import (BufferStore, For, AttrStmt, ForKind, Var, PrimFunc, BufferLoad, Buffer, IntImm) +from tvm.tir import BufferStore, For, AttrStmt, ForKind, Var, PrimFunc, BufferLoad, Buffer, IntImm from tvm.tir.stmt_functor import ir_transform, post_order_visit from tvm.tir.transform import prim_func_pass @@ -97,7 +97,7 @@ def AddWrapperForSingleBufStore(): Returns: True if the loop is a tile operation (parallel or has num_stages annotation) """ - return loop.kind == ForKind.PARALLEL or 'num_stages' in loop.annotations + return loop.kind == ForKind.PARALLEL or "num_stages" in loop.annotations def pre_visit(statement): """ @@ -105,7 +105,7 @@ def AddWrapperForSingleBufStore(): """ nonlocal tile_operation_depth - if isinstance(statement, AttrStmt) and statement.attr_key == 'thread_extent': + if isinstance(statement, AttrStmt) and statement.attr_key == "thread_extent": thread_binding_vars.add(statement.node.var) elif isinstance(statement, For) and is_tile_operation_loop(statement): tile_operation_depth += 1 @@ -139,7 +139,8 @@ def AddWrapperForSingleBufStore(): if isinstance(index, IntImm) and index != 0: raise ValueError( f"Fragment buffer access with non-zero index [{index}] is not supported. " - "Only fragment[0] access is allowed.") + "Only fragment[0] access is allowed." + ) # Wrap fragment[0] access with T.Parallel loop return For(Var("_", "int32"), 0, 1, ForKind.PARALLEL, statement) diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index 92adcb42c080a81430d785850f82b03a6168bb2f..92a7313b4c68b8741e0f304ba4b49d76a96a9a52 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -5,6 +5,7 @@ from enum import Enum class PassConfigKey(str, Enum): """Pass configuration keys for TileLang compiler.""" + # TileLang specific configs TL_SIMPLIFY = "tl.Simplify" """Enable/disable TileLang simplification passes. 
Default: True""" diff --git a/tilelang/transform/simplify.py b/tilelang/transform/simplify.py index 7e0c5062b7a35c511d6bbe2ec77c8a55d1d28c22..c5e577d036a3b4553d66c19d1b04d79108249cbf 100644 --- a/tilelang/transform/simplify.py +++ b/tilelang/transform/simplify.py @@ -51,7 +51,6 @@ def _Simplify(stmt: PrimFunc | IRModule, inline_let: bool = False) -> PrimFunc | # Decorator to simplify the output of a function def simplify_prim_func(func: Callable) -> Callable: - def wrapper(*args, **kwargs): stmt: PrimFunc | IRModule = (func)(*args, **kwargs) return _Simplify(stmt) diff --git a/tilelang/utils/deprecated.py b/tilelang/utils/deprecated.py index 2aff08b59601049ac6fb53588ee3d0edd0d49bac..2944f292b308dd88ae2006433e96b09e0af5069e 100644 --- a/tilelang/utils/deprecated.py +++ b/tilelang/utils/deprecated.py @@ -1,11 +1,10 @@ def deprecated_warning(method_name: str, new_method_name: str, phaseout_version: str = None): - """A function to indicate that a method is deprecated - """ + """A function to indicate that a method is deprecated""" import warnings # pylint: disable=import-outside-toplevel, import-error warnings.warn( - f"{method_name} is deprecated, use {new_method_name} instead" + - (f" and will be removed in {phaseout_version}" if phaseout_version else ""), + f"{method_name} is deprecated, use {new_method_name} instead" + + (f" and will be removed in {phaseout_version}" if phaseout_version else ""), DeprecationWarning, stacklevel=2, ) @@ -30,7 +29,6 @@ def deprecated( import functools # pylint: disable=import-outside-toplevel def _deprecate(func): - @functools.wraps(func) def _wrapper(*args, **kwargs): deprecated_warning(method_name, new_method_name, phaseout_version) diff --git a/tilelang/utils/language.py b/tilelang/utils/language.py index 41da8ab0a2c72dfc84f4bb83863a8dcd6ae6ff90..584e9998d0743e1b4b8cf1aa815215829d1a520d 100644 --- a/tilelang/utils/language.py +++ b/tilelang/utils/language.py @@ -24,8 +24,7 @@ def _get_buffer(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> elif isinstance(buffer_or_load_or_region, (tir.BufferLoad, tir.BufferRegion)): return buffer_or_load_or_region.buffer else: - raise TypeError( - f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}") + raise TypeError(f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}") def is_global(buffer: Buffer | BufferLoad | BufferRegion) -> bool: @@ -153,14 +152,12 @@ def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc: """ if not isinstance(ir_module, IRModule): raise ValueError("Not supported type: ", type(ir_module)) - assert len(ir_module.get_global_vars()) == 1, ( - "The optimized module should only have one global variable for default schedule.") + assert len(ir_module.get_global_vars()) == 1, "The optimized module should only have one global variable for default schedule." func = list(ir_module.functions.values())[0] return func -def get_buffer_region_from_load(buffer_load: tir.BufferLoad, - extents: list[PrimExpr] | None = None) -> tir.BufferRegion | None: +def get_buffer_region_from_load(buffer_load: tir.BufferLoad, extents: list[PrimExpr] | None = None) -> tir.BufferRegion | None: """ Get the buffer region from a buffer load. 
@@ -193,9 +190,9 @@ def get_buffer_region_from_load(buffer_load: tir.BufferLoad, return None -def to_buffer_region(obj: Buffer | BufferLoad | BufferRegion | tir.Var, - access_type: str = "rw", - extents: list[PrimExpr] | None = None) -> PrimExpr | BufferRegion: +def to_buffer_region( + obj: Buffer | BufferLoad | BufferRegion | tir.Var, access_type: str = "rw", extents: list[PrimExpr] | None = None +) -> PrimExpr | BufferRegion: """ Convert to/from the tl.region representation. @@ -203,6 +200,7 @@ def to_buffer_region(obj: Buffer | BufferLoad | BufferRegion | tir.Var, - tl.region Call -> returns the decoded BufferRegion for analysis """ from tilelang.language.frame import has_let_value, get_let_value + if isinstance(obj, tir.Var) and has_let_value(obj): obj = get_let_value(obj) # Encode into tl.region call (when extents is provided), otherwise return BufferRegion for analysis @@ -279,8 +277,7 @@ def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list: return strides -def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion, - access_type: str = "r") -> PrimExpr: +def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion, access_type: str = "r") -> PrimExpr: if isinstance(buffer_or_load_or_region, Buffer): return buffer_or_load_or_region.access_ptr(access_type) elif isinstance(buffer_or_load_or_region, BufferLoad): diff --git a/tilelang/utils/sparse.py b/tilelang/utils/sparse.py index a7b17ad93d036413a0bc5af4eaca1aab2c3b6c9e..26a8e345cd668d6690657205d177d9a477cf03ff 100644 --- a/tilelang/utils/sparse.py +++ b/tilelang/utils/sparse.py @@ -15,7 +15,7 @@ os.makedirs(_CACHE_DIR, exist_ok=True) def _get_cached_lib(): - name = 'compress_lib' + name = "compress_lib" if os.path.exists(os.path.join(_CACHE_DIR, f"{name}.so")): try: @@ -32,24 +32,22 @@ def _get_cached_lib(): name=name, sources=[compress_util], extra_cuda_cflags=[ - '-O2', - '-std=c++17', - '-lineinfo', - f'-I{env.CUTLASS_INCLUDE_DIR}', - f'-I{env.CUTLASS_INCLUDE_DIR}/../tools/util/include', - '-arch=sm_90', + "-O2", + "-std=c++17", + "-lineinfo", + f"-I{env.CUTLASS_INCLUDE_DIR}", + f"-I{env.CUTLASS_INCLUDE_DIR}/../tools/util/include", + "-arch=sm_90", ], build_directory=_CACHE_DIR, ) -def compress_sm90(A: torch.Tensor, block_k: int, - transposed: bool) -> tuple[torch.Tensor, torch.Tensor]: +def compress_sm90(A: torch.Tensor, block_k: int, transposed: bool) -> tuple[torch.Tensor, torch.Tensor]: if block_k > 128: block_k = 128 # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 - warnings.warn( - f"block_k {block_k} is too large, set to 128 for sm90 compression.", stacklevel=2) + warnings.warn(f"block_k {block_k} is too large, set to 128 for sm90 compression.", stacklevel=2) # Load the library (will use cache if available) compress_lib = _get_cached_lib() @@ -60,8 +58,9 @@ def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torc try: from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor except ImportError as err: - raise ImportError("SparseSemiStructuredTensor is not available in this version of PyTorch. " - "Please install a compatible version.") from err + raise ImportError( + "SparseSemiStructuredTensor is not available in this version of PyTorch. Please install a compatible version." 
+ ) from err orig_val = SparseSemiStructuredTensor._FORCE_CUTLASS try: SparseSemiStructuredTensor._FORCE_CUTLASS = True @@ -73,10 +72,7 @@ def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torc SparseSemiStructuredTensor._FORCE_CUTLASS = orig_val -def compress(A: torch.Tensor, - transposed: bool, - arch: str | None = None, - **kwargs) -> tuple[torch.Tensor, torch.Tensor]: +def compress(A: torch.Tensor, transposed: bool, arch: str | None = None, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: """ Compress a tensor using the appropriate method based on the CUDA architecture. """ @@ -101,11 +97,10 @@ def compress(A: torch.Tensor, A_sp = A_sp.t().contiguous() return A_sp, E else: - raise ValueError(f"Unsupported CUDA compute version: {compute_version}. " - "Supported versions are sm_80 and sm_90.") + raise ValueError(f"Unsupported CUDA compute version: {compute_version}. Supported versions are sm_80 and sm_90.") -def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transposed: bool = False): +def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device="cuda", transposed: bool = False): """ Generate a random semi-sparse tensor. The generated tensor will have 2:4 sparsity along the K dimension. Args: @@ -127,13 +122,7 @@ def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transp return tensor.to(dtype) # dtype like float8 might not have randn kernel -def randint_semi_sparse(M: int, - K: int, - low: int, - high: int, - dtype=torch.int32, - device='cuda', - transposed: bool = False): +def randint_semi_sparse(M: int, K: int, low: int, high: int, dtype=torch.int32, device="cuda", transposed: bool = False): """ Generate a random semi-sparse integer tensor. The generated tensor will have 2:4 sparsity along the K dimension. Args: @@ -157,11 +146,7 @@ def randint_semi_sparse(M: int, return tensor -def arange_semi_sparse(M: int, - K: int, - dtype=torch.float16, - device='cuda', - transposed: bool = False): +def arange_semi_sparse(M: int, K: int, dtype=torch.float16, device="cuda", transposed: bool = False): """ Generate a semi-sparse tensor with values from 0 to M*K-1. The generated tensor will have 2:4 sparsity along the K dimension. Args: diff --git a/tilelang/utils/target.py b/tilelang/utils/target.py index 094c099fe37ca8d68c36478d426c2bd580e7acb9..4ead7efd06762a7fc9602804c5f8978c9e59d7bf 100644 --- a/tilelang/utils/target.py +++ b/tilelang/utils/target.py @@ -56,11 +56,10 @@ def check_metal_availability() -> bool: if not mac_release: return False # todo: check torch version? - return arch == 'arm64' + return arch == "arm64" -def determine_target(target: str | Target | Literal["auto"] = "auto", - return_object: bool = False) -> str | Target: +def determine_target(target: str | Target | Literal["auto"] = "auto", return_object: bool = False) -> str | Target: """ Determine the appropriate target for compilation (CUDA, HIP, or manual selection). 
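Note (illustrative sketch, not part of the patch): the reformatted helpers in tilelang/utils/sparse.py keep their public signatures, so existing call sites are unaffected. A minimal usage sketch, assuming a CUDA GPU with 2:4 structured-sparsity support; the shapes and dtype below are arbitrary examples:

    import torch
    from tilelang.utils.sparse import randn_semi_sparse, compress

    # Generate an fp16 matrix with 2:4 sparsity along the K dimension.
    A = randn_semi_sparse(256, 128, dtype=torch.float16, device="cuda")
    # compress() dispatches to compress_sm80 or compress_sm90 based on the
    # detected CUDA compute version and returns (compressed values, metadata).
    # On sm_90, a block_k keyword may also be relevant (see compress_sm90).
    A_sp, E = compress(A, transposed=False)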
diff --git a/tilelang/utils/tensor.py b/tilelang/utils/tensor.py index b2905fb1b42ea882249aca69557d7045549baea4..f1d4fc7304ed49c7b02d554e6e4713fb18cbc50d 100644 --- a/tilelang/utils/tensor.py +++ b/tilelang/utils/tensor.py @@ -1,4 +1,5 @@ """The profiler and convert to torch utils""" + from enum import Enum import torch from tvm import tir @@ -17,7 +18,7 @@ def is_float8_dtype(dtype: torch.dtype) -> bool: def fp8_remove_negative_zeros_(tensor: torch.Tensor): assert is_float8_dtype(tensor.dtype), "Input tensor must be of float8 dtype" bits = tensor.view(torch.uint8) - zeros_mask = (tensor == 0) + zeros_mask = tensor == 0 bits[zeros_mask] = 0x00 @@ -33,26 +34,21 @@ class TensorSupplyType(Enum): def map_torch_type(intype: str) -> torch.dtype: if intype == "float8_e4m3": - assert hasattr(torch, "float8_e4m3fn"), \ - "torch.float8_e4m3fn is not supported in this version of torch" \ - "Please upgrade torch >= 2.1.0" + assert hasattr(torch, "float8_e4m3fn"), "torch.float8_e4m3fn is not supported in this version of torchPlease upgrade torch >= 2.1.0" return torch.float8_e4m3fn elif intype == "float8_e5m2": - assert hasattr(torch, "float8_e5m2"), \ - "torch.float8_e5m2 is not supported in this version of torch" \ - "Please upgrade torch >= 2.1.0" + assert hasattr(torch, "float8_e5m2"), "torch.float8_e5m2 is not supported in this version of torchPlease upgrade torch >= 2.1.0" return torch.float8_e5m2 elif intype == "e4m3fnuz_float8": - assert hasattr(torch, "float8_e4m3fnuz"), \ - "torch.float8_e4m3fnuz is not supported in this version of torch" \ - "Please upgrade torch >= 2.2.0" + assert hasattr(torch, "float8_e4m3fnuz"), ( + "torch.float8_e4m3fnuz is not supported in this version of torchPlease upgrade torch >= 2.2.0" + ) return torch.float8_e4m3fnuz else: return getattr(torch, intype) def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer): - from tilelang.engine.param import KernelParam from .device import get_current_device @@ -63,7 +59,8 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer): if hasattr(param, "shape") and not param.shape: raise ValueError( f"TensorType must have a shape, but got {type(param)}, " - "likely you are trying to generate a random tensor with a dynamic symbolic shape.") + "likely you are trying to generate a random tensor with a dynamic symbolic shape." 
+ ) # Check if with dynamic symbolic shape for shape in param.shape: @@ -81,8 +78,7 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer): if is_unsigned: return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype) elif is_float8: - return torch.randint( - low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype) + return torch.randint(low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype) elif is_boolean: return torch.randint(low=0, high=2, size=shape, device=device, dtype=dtype) elif dtype in {torch.float16, torch.float32, torch.bfloat16}: @@ -91,8 +87,8 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer): return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype) if dtype == torch.int8 and supply_type in [ - TensorSupplyType.Uniform, - TensorSupplyType.Normal, + TensorSupplyType.Uniform, + TensorSupplyType.Normal, ]: return torch.ones(*shape, device=device, dtype=dtype) @@ -103,18 +99,15 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer): if is_unsigned: return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype) elif is_float8: - return torch.randint( - low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype) + return torch.randint(low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype) elif is_boolean: return torch.randint(low=0, high=2, size=shape, device=device, dtype=dtype) else: return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype) elif supply_type == TensorSupplyType.Uniform: - return torch.empty( - *shape, device=device, dtype=torch.float32).uniform_(-1.0, 1.0).to(dtype) + return torch.empty(*shape, device=device, dtype=torch.float32).uniform_(-1.0, 1.0).to(dtype) elif supply_type == TensorSupplyType.Normal: - return torch.empty( - *shape, device=device, dtype=torch.float32).normal_(-1.0, 1.0).to(dtype) + return torch.empty(*shape, device=device, dtype=torch.float32).normal_(-1.0, 1.0).to(dtype) elif supply_type == TensorSupplyType.Randn: return torch.randn(*shape, device=device).to(dtype) elif supply_type == TensorSupplyType.Zero: @@ -150,9 +143,7 @@ def _compare_attributes( """ def raise_mismatch_error(attribute_name: str, actual_value, expected_value): - raise AssertionError( - f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}." 
- ) + raise AssertionError(f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}.") if actual.shape != expected.shape: raise_mismatch_error("shape", actual.shape, expected.shape) @@ -163,7 +154,7 @@ def _compare_attributes( if actual.layout != expected.layout: if check_layout: raise_mismatch_error("layout", actual.layout, expected.layout) - elif (actual.layout == torch.strided and check_stride and actual.stride() != expected.stride()): + elif actual.layout == torch.strided and check_stride and actual.stride() != expected.stride(): raise_mismatch_error("stride()", actual.stride(), expected.stride()) if check_device and actual.device != expected.device: raise_mismatch_error("device", actual.device, expected.device) @@ -171,8 +162,7 @@ def _compare_attributes( raise_mismatch_error("dtype", actual.dtype, expected.dtype) -def _equalize_attributes(actual: torch.Tensor, - expected: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def _equalize_attributes(actual: torch.Tensor, expected: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Equalizes some attributes of two tensors for value comparison. If ``actual`` and ``expected`` are ... - ... not on the same :attr:`~torch.Tensor.device`, they are moved CPU memory. @@ -210,7 +200,7 @@ def _equalize_attributes(actual: torch.Tensor, if actual.layout != expected.layout: # These checks are needed, since Tensor.to_dense() fails on tensors that are already strided actual = actual.to_dense() if actual.layout != torch.strided else actual - expected = (expected.to_dense() if expected.layout != torch.strided else expected) + expected = expected.to_dense() if expected.layout != torch.strided else expected return actual, expected @@ -254,12 +244,8 @@ def torch_assert_close( """ _compare_attributes( - tensor_a, - tensor_b, - check_device=check_device, - check_dtype=check_dtype, - check_layout=check_layout, - check_stride=check_stride) + tensor_a, tensor_b, check_device=check_device, check_dtype=check_dtype, check_layout=check_layout, check_stride=check_stride + ) tensor_a, tensor_b = _equalize_attributes(tensor_a, tensor_b) mismatched = ~torch.isclose(tensor_a, tensor_b, rtol=rtol, atol=atol, equal_nan=equal_nan) @@ -276,8 +262,7 @@ def torch_assert_close( # Print debug information about the mismatch if verbose: - print(f"Number of mismatched elements: {num_mismatched} / {total_elements} " - f"(allowed: {max_allowed_mismatched})") + print(f"Number of mismatched elements: {num_mismatched} / {total_elements} (allowed: {max_allowed_mismatched})") # If there are mismatched elements, print the first mismatch if num_mismatched > 0: @@ -289,9 +274,9 @@ def torch_assert_close( b_val = tensor_b.reshape(-1)[flat_idx].item() abs_diff = abs(a_val - b_val) rel_diff = abs_diff / (abs(b_val) + 1e-12) - mismatch_info = (f"\nFirst mismatch at index {idx}: " - f"lhs={a_val:.6f}, rhs={b_val:.6f}, " - f"abs_diff={abs_diff:.6f}, rel_diff={rel_diff:.6f}") + mismatch_info = ( + f"\nFirst mismatch at index {idx}: lhs={a_val:.6f}, rhs={b_val:.6f}, abs_diff={abs_diff:.6f}, rel_diff={rel_diff:.6f}" + ) else: mismatch_info = "" @@ -304,6 +289,7 @@ def torch_assert_close( f"\nGreatest absolute difference: {diff.max().item()}, " f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}" f"\n{base_name}: {tensor_a}" - f"\n{ref_name}: {tensor_b}") + f"\n{ref_name}: {tensor_b}" + ) else: return True diff --git a/version_provider.py b/version_provider.py index 
3eb45aac90116c60747390199dd39961ea8d1172..c2ca929ae60d4c800a267aa4c51cbf55a450617c 100644 --- a/version_provider.py +++ b/version_provider.py @@ -8,29 +8,26 @@ from functools import lru_cache ROOT = Path(__file__).parent -base_version = (ROOT / 'VERSION').read_text().strip() +base_version = (ROOT / "VERSION").read_text().strip() # When installing a sdist, # the installed version needs to match the sdist version, # so pip will complain when we install `tilelang-0.1.6.post2+gitxxxx.tar.gz`. # To workaround that, when building sdist, # we do not add version label and use a file to store the git hash instead. -git_pin = ROOT / '.git_commit.txt' +git_pin = ROOT / ".git_commit.txt" def _read_cmake_bool(i: str | None, default=False): if i is None: return default - return i.lower() not in ('0', 'false', 'off', 'no', 'n', '') + return i.lower() not in ("0", "false", "off", "no", "n", "") @lru_cache(maxsize=1) def get_git_commit_id() -> str | None: """Get the current git commit hash by running git in the current file's directory.""" - r = subprocess.run(['git', 'rev-parse', 'HEAD'], - cwd=ROOT, - capture_output=True, - encoding='utf-8') + r = subprocess.run(["git", "rev-parse", "HEAD"], cwd=ROOT, capture_output=True, encoding="utf-8") if r.returncode == 0: _git = r.stdout.strip() git_pin.write_text(_git) @@ -41,51 +38,48 @@ def get_git_commit_id() -> str | None: return None -def dynamic_metadata( - field: str, - settings: dict[str, object] | None = None, -) -> str: - assert field == 'version' +def dynamic_metadata(field: str, settings: dict[str, object] | None = None) -> str: + assert field == "version" version = base_version # generate git version for sdist get_git_commit_id() - if not _read_cmake_bool(os.environ.get('NO_VERSION_LABEL')): + if not _read_cmake_bool(os.environ.get("NO_VERSION_LABEL")): exts = [] backend = None - if _read_cmake_bool(os.environ.get('NO_TOOLCHAIN_VERSION')): + if _read_cmake_bool(os.environ.get("NO_TOOLCHAIN_VERSION")): pass - elif platform.system() == 'Darwin': + elif platform.system() == "Darwin": # only on macosx_11_0_arm64, not necessary # backend = 'metal' pass - elif _read_cmake_bool(os.environ.get('USE_ROCM', '')): - backend = 'rocm' - elif 'USE_CUDA' in os.environ and not _read_cmake_bool(os.environ.get('USE_CUDA')): - backend = 'cpu' + elif _read_cmake_bool(os.environ.get("USE_ROCM", "")): + backend = "rocm" + elif "USE_CUDA" in os.environ and not _read_cmake_bool(os.environ.get("USE_CUDA")): + backend = "cpu" else: # cuda # Read nvcc version from env. # This is not exactly how it should be, # but works for now if building in a nvidia/cuda image. - if cuda_version := os.environ.get('CUDA_VERSION'): - major, minor, *_ = cuda_version.split('.') - backend = f'cu{major}{minor}' + if cuda_version := os.environ.get("CUDA_VERSION"): + major, minor, *_ = cuda_version.split(".") + backend = f"cu{major}{minor}" else: - backend = 'cuda' + backend = "cuda" if backend: exts.append(backend) - if _read_cmake_bool(os.environ.get('NO_GIT_VERSION')): + if _read_cmake_bool(os.environ.get("NO_GIT_VERSION")): pass elif git_hash := get_git_commit_id(): - exts.append(f'git{git_hash[:8]}') + exts.append(f"git{git_hash[:8]}") else: - exts.append('gitunknown') + exts.append("gitunknown") if exts: - version += '+' + '.'.join(exts) + version += "+" + ".".join(exts) return version
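Note (illustrative sketch, not part of the patch): the reformatted dynamic_metadata() keeps the same version-label logic as before. A minimal sketch of how the label is assembled, run from the repository root; the CUDA version and commit hash below are made-up examples:

    import os
    from version_provider import dynamic_metadata

    os.environ["CUDA_VERSION"] = "12.4"  # hypothetical toolchain version
    # With a base VERSION of e.g. 0.1.6 and HEAD at commit 1a2b3c4d...,
    # this prints something like "0.1.6+cu124.git1a2b3c4d".
    # NO_VERSION_LABEL=1 suppresses the "+..." local label entirely,
    # and NO_GIT_VERSION=1 drops just the git component.
    print(dynamic_metadata("version"))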