From eac96cd7a2741bf0fb343d2e857487b1832fc4ec Mon Sep 17 00:00:00 2001 From: Zhengju Tang <97930865+tzj-fxz@users.noreply.github.com> Date: Fri, 14 Nov 2025 20:51:16 +0800 Subject: [PATCH 001/139] [BugFix] Add autotune and exp2 for GDN kernel (#1258) * [BugFix] Add autotune and exp2 for GDN kernel * [Lint] * [Lint] --- examples/gdn/example_chunk_delta_h.py | 54 ++++++++++++++++++--------- 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/examples/gdn/example_chunk_delta_h.py b/examples/gdn/example_chunk_delta_h.py index 4d6b657f..61c2abd3 100644 --- a/examples/gdn/example_chunk_delta_h.py +++ b/examples/gdn/example_chunk_delta_h.py @@ -3,6 +3,7 @@ import sys # noqa: F401 import tilelang import tilelang.language as T +from tilelang.autotuner import autotune # Add your fla repository path to sys.path # Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae @@ -80,7 +81,25 @@ def prepare_output( return h, final_state, V_new -@tilelang.jit(out_idx=[-3, -2, -1]) +def get_configs(): + import itertools + block_DK = [32, 64, 128] + block_DV = [32, 64, 128] + threads = [128, 256] + num_stages = [1, 2, 3] + _configs = list(itertools.product(block_DK, block_DV, threads, num_stages)) + + configs = [{ + 'block_DK': c[0], + 'block_DV': c[1], + 'threads': c[2], + 'num_stages': c[3] + } for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=3, rep=5) +@tilelang.jit(out_idx=[-3, -2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) def tilelang_chunk_gated_delta_rule_fwd_h( # task config B, @@ -94,15 +113,15 @@ def tilelang_chunk_gated_delta_rule_fwd_h( gate_dtype, state_dtype, chunk_size, - use_g=True, - use_initial_state=True, - store_final_state=True, - save_new_value=True, + use_g, + use_initial_state, + store_final_state, + save_new_value, # kernel config block_DK=64, - block_DV=64, - threads=256, - num_stages=0, + block_DV=32, + threads=128, + num_stages=1, ): block_S = 
chunk_size BS = S // block_S @@ -193,11 +212,11 @@ def tilelang_chunk_gated_delta_rule_fwd_h( for i_s2, i_v in T.Parallel(block_S, block_DV): with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0): with T.Then(): - V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp( - G_last_local[0] - G_fragment[i_s2, i_v]) + V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp2( + (G_last_local[0] - G_fragment[i_s2, i_v]) * 1.442695) with T.Else(): V_new_fragment[i_s2, i_v] = 0 - G_last_local[0] = T.exp(G_last_local[0]) + G_last_local[0] = T.exp2(G_last_local[0] * 1.442695) for i_k, i_v in T.Parallel(DK, block_DV): b_h_fragment[i_k, i_v] *= G_last_local[0] @@ -281,8 +300,7 @@ def run_test( kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype, chunk_size, use_g, use_initial_state, store_final_state, - save_new_value, block_DK, block_DV, threads, - num_stages) + save_new_value) h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) # (zhengju) If you want to print the generated cuda code, you can uncomment the following line # print("CUDA Code:\n", kernel.get_kernel_source()) @@ -352,13 +370,13 @@ def main(): state_dtype="float32", chunk_size=64, use_g=True, - use_initial_state=True, - store_final_state=True, - save_new_value=True, - block_DK=64, + use_initial_state=False, + store_final_state=False, + save_new_value=False, + block_DK=32, block_DV=32, threads=128, - num_stages=1, + num_stages=2, ) -- GitLab From 0af3fd7c70711f1c78da9bc087293826ecba451e Mon Sep 17 00:00:00 2001 From: Tong WU <109033598+Rachmanino@users.noreply.github.com> Date: Sat, 15 Nov 2025 09:36:16 +0800 Subject: [PATCH 002/139] [BugFix] Refactor attention kernel to handle OOB positions by filling with `-inf` instead of clearing accumulators. (#1222) * Refactor attention kernel to handle OOB positions by filling with `-inf` instead of clearing accumulators. 
* lint * pre-commit * Update imports in flash attention test file to use new backward and forward examples for better clarity and consistency. --- examples/flash_attention/example_gqa_bwd.py | 4 +++- examples/flash_attention/example_gqa_bwd_tma_reduce.py | 4 +++- .../flash_attention/example_gqa_bwd_wgmma_pipelined.py | 4 +++- examples/flash_attention/example_gqa_fwd_bshd.py | 4 +++- .../example_gqa_fwd_bshd_wgmma_pipelined.py | 4 +++- examples/flash_attention/example_mha_bwd_bhsd.py | 6 +++++- .../{example_mha_bwd.py => example_mha_bwd_bshd.py} | 8 ++++++-- ...pelined.py => example_mha_bwd_bshd_wgmma_pipelined.py} | 6 +++++- examples/flash_attention/example_mha_fwd_bhsd.py | 6 ++++-- .../example_mha_fwd_bhsd_wgmma_pipelined.py | 4 +++- examples/flash_attention/example_mha_fwd_bshd.py | 5 ++++- .../example_mha_fwd_bshd_wgmma_pipelined.py | 5 ++++- examples/flash_attention/test_example_flash_attention.py | 8 ++++---- 13 files changed, 50 insertions(+), 18 deletions(-) rename examples/flash_attention/{example_mha_bwd.py => example_mha_bwd_bshd.py} (97%) rename examples/flash_attention/{example_mha_bwd_wgmma_pipelined.py => example_mha_bwd_bshd_wgmma_pipelined.py} (97%) diff --git a/examples/flash_attention/example_gqa_bwd.py b/examples/flash_attention/example_gqa_bwd.py index 907a121d..dd9c8f7c 100644 --- a/examples/flash_attention/example_gqa_bwd.py +++ b/examples/flash_attention/example_gqa_bwd.py @@ -54,7 +54,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, + -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) diff --git 
a/examples/flash_attention/example_gqa_bwd_tma_reduce.py b/examples/flash_attention/example_gqa_bwd_tma_reduce.py index 615c2e19..2af06e4b 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce.py @@ -59,7 +59,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, T.Cast(accum_dtype, -1e30)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, + -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) diff --git a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py index ed07e7d9..02421249 100644 --- a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py @@ -54,7 +54,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, + -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) diff --git a/examples/flash_attention/example_gqa_fwd_bshd.py b/examples/flash_attention/example_gqa_fwd_bshd.py index 4d9d06a4..3d4bfe45 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd.py +++ b/examples/flash_attention/example_gqa_fwd_bshd.py @@ -96,7 +96,9 @@ def flashattn(batch, acc_s[i, j] = T.if_then_else(bx * block_M + i >= 
k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), + 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @T.macro diff --git a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py index 1c1fc12d..21f5e9a9 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py @@ -63,7 +63,9 @@ def flashattn( acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), + 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @T.macro diff --git a/examples/flash_attention/example_mha_bwd_bhsd.py b/examples/flash_attention/example_mha_bwd_bhsd.py index 1595ae76..8247b265 100644 --- a/examples/flash_attention/example_mha_bwd_bhsd.py +++ b/examples/flash_attention/example_mha_bwd_bhsd.py @@ -56,7 +56,9 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, + -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) T.copy(scores_max, scores_max_prev) @@ -213,6 +215,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + # 
We don't need to handle OOB positions for non-causal cases, + # since OOB values won't affect other positions here. T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) diff --git a/examples/flash_attention/example_mha_bwd.py b/examples/flash_attention/example_mha_bwd_bshd.py similarity index 97% rename from examples/flash_attention/example_mha_bwd.py rename to examples/flash_attention/example_mha_bwd_bshd.py index 543c2c0e..414061ff 100644 --- a/examples/flash_attention/example_mha_bwd.py +++ b/examples/flash_attention/example_mha_bwd_bshd.py @@ -52,7 +52,9 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, + -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) T.copy(scores_max, scores_max_prev) @@ -206,6 +208,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + # We don't need to handle OOB positions for non-causal cases, + # since OOB values won't affect other positions here. 
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) T.clear(dsT) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -340,7 +344,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=8, help='Batch size') parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') + parser.add_argument('--n_ctx', type=int, default=1048, help='Context size') parser.add_argument('--d_head', type=int, default=64, help='Head dimension') parser.add_argument('--causal', type=bool, default=False, help='Causal flag') args = parser.parse_args() diff --git a/examples/flash_attention/example_mha_bwd_wgmma_pipelined.py b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py similarity index 97% rename from examples/flash_attention/example_mha_bwd_wgmma_pipelined.py rename to examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py index 7ad417ef..e10ef581 100644 --- a/examples/flash_attention/example_mha_bwd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py @@ -53,7 +53,9 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, + -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) T.copy(scores_max, scores_max_prev) @@ -193,6 +195,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + # We don't need to handle OOB positions for non-causal cases, + 
# since OOB values won't affect other positions here. T.wait_wgmma(0) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) diff --git a/examples/flash_attention/example_mha_fwd_bhsd.py b/examples/flash_attention/example_mha_fwd_bhsd.py index f07f7a61..e936cee3 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd.py +++ b/examples/flash_attention/example_mha_fwd_bhsd.py @@ -55,7 +55,9 @@ def flashattn(batch, k_idx = k * block_N + j acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + # We shall fill -inf for OOB positions + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @T.macro @@ -226,7 +228,7 @@ if __name__ == "__main__": parser.add_argument('--seq_q', type=int, default=256, help='query sequence length') parser.add_argument('--seq_kv', type=int, default=256, help='key/value sequence length') parser.add_argument('--dim', type=int, default=64, help='dim') - parser.add_argument('--is_causal', action='store_true', help='causal') + parser.add_argument('--is_causal', action='store_true', help='causal', default=False) parser.add_argument('--tune', action='store_true', help='tune configs') args = parser.parse_args() main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py index 26167b34..e1d0130a 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py @@ -55,7 +55,9 @@ def flashattn(batch, k_idx = k * block_N + j acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + # We shall fill -inf for OOB positions + 
for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @T.macro diff --git a/examples/flash_attention/example_mha_fwd_bshd.py b/examples/flash_attention/example_mha_fwd_bshd.py index 6a1f707e..a9268019 100644 --- a/examples/flash_attention/example_mha_fwd_bshd.py +++ b/examples/flash_attention/example_mha_fwd_bshd.py @@ -49,7 +49,10 @@ def flashattn(batch, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + # We shall fill -inf for OOB positions + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), + 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @T.macro diff --git a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py index 3928db4c..d7023a20 100644 --- a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py @@ -49,7 +49,10 @@ def flashattn(batch, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) else: - T.clear(acc_s) + # We shall fill -inf for OOB positions + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), + 0) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @T.macro diff --git a/examples/flash_attention/test_example_flash_attention.py b/examples/flash_attention/test_example_flash_attention.py index f4932aee..b184fc60 100644 --- a/examples/flash_attention/test_example_flash_attention.py +++ b/examples/flash_attention/test_example_flash_attention.py @@ -2,7 +2,7 @@ import tilelang.testing import 
example_gqa_bwd import example_gqa_bwd_wgmma_pipelined -import example_mha_bwd +import example_mha_bwd_bshd import example_mha_bwd_bhsd import example_mha_fwd_bhsd_wgmma_pipelined import example_gqa_fwd_bshd @@ -10,7 +10,7 @@ import example_mha_fwd_bshd import example_gqa_fwd_bshd_wgmma_pipelined import example_mha_fwd_bshd_wgmma_pipelined import example_mha_fwd_varlen -import example_mha_bwd_wgmma_pipelined +import example_mha_bwd_bshd_wgmma_pipelined import example_mha_fwd_bhsd import example_gqa_bwd_tma_reduce_varlen @@ -33,7 +33,7 @@ def test_example_gqa_bwd_wgmma_pipelined(): @tilelang.testing.requires_cuda def test_example_mha_bwd(): - example_mha_bwd.main( + example_mha_bwd_bshd.main( BATCH=1, H=16, N_CTX=512, @@ -56,7 +56,7 @@ def test_example_mha_bwd_bhsd(): @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_mha_bwd_wgmma_pipelined(): - example_mha_bwd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False) + example_mha_bwd_bshd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False) @tilelang.testing.requires_cuda -- GitLab From eb41574431608e2a96d3d8941f9c1e6d775f228e Mon Sep 17 00:00:00 2001 From: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com> Date: Sat, 15 Nov 2025 11:43:03 +0800 Subject: [PATCH 003/139] [fix] NVRTC execution backend (#1256) * [fix] NVRTC execution backend * [fmt] run pre-commit * [fix] coderabbit reviews * [test] add cuda-python to test dep * [fix] coderabbit reviews * [fix] CUDA 13 compatibility * [fix] sm90 * [fix] CUDA 13 compatibility * [fix] pre-commit * [fix] always use cuda::std::__atomic_ref_impl * [fix] restore to external API * Revert "[fix] restore to external API" This reverts commit 49bd875638fb631d270015f408991d38fd1e9a5d. 
* [fmt] use space instead tabs for py codegen * [fix] im2col API * [fix] revert atomic.h * [fix] dynamic shape * [refactor] extract common utils * [feat] support L2 persistent map * [fix] l2 persistent map * [fix] pre-commit * [fix] restore _TYPE_MAP * [fix] pre-commit * [fix] avoid duplicate TMA descs * [docs] add docstring * [fix] coderabbit * [fix] coderabbit * [fix] coderabbit * [fix] coderabbit --- requirements-test-cuda.txt | 1 + src/tl_templates/cuda/instruction/mma.h | 2 + src/tl_templates/cuda/instruction/mma_sm70.h | 2 + src/tl_templates/cuda/instruction/wgmma.h | 2 + src/tl_templates/cuda/nvrtc_std.h | 53 ++ src/tl_templates/cuda/reduce.h | 3 + testing/python/jit/test_tilelang_jit_nvrtc.py | 585 ++++++++++++++++++ tilelang/jit/adapter/libgen.py | 102 --- tilelang/jit/adapter/nvrtc/__init__.py | 25 +- tilelang/jit/adapter/nvrtc/adapter.py | 7 +- tilelang/jit/adapter/nvrtc/libgen.py | 235 +++++++ tilelang/jit/adapter/nvrtc/wrapper.py | 563 +++++++++++++++++ tilelang/jit/adapter/utils.py | 251 +++++++- tilelang/jit/adapter/wrapper.py | 432 +------------ tilelang/jit/kernel.py | 4 +- tilelang/language/annotations.py | 3 +- 16 files changed, 1747 insertions(+), 523 deletions(-) create mode 100644 testing/python/jit/test_tilelang_jit_nvrtc.py create mode 100644 tilelang/jit/adapter/nvrtc/libgen.py create mode 100644 tilelang/jit/adapter/nvrtc/wrapper.py diff --git a/requirements-test-cuda.txt b/requirements-test-cuda.txt index 5413ad51..12232023 100644 --- a/requirements-test-cuda.txt +++ b/requirements-test-cuda.txt @@ -6,3 +6,4 @@ # CUDA specific requirements flash-attn==2.5.8 +cuda-python==12.9.4 diff --git a/src/tl_templates/cuda/instruction/mma.h b/src/tl_templates/cuda/instruction/mma.h index ed561285..869fa777 100644 --- a/src/tl_templates/cuda/instruction/mma.h +++ b/src/tl_templates/cuda/instruction/mma.h @@ -4,8 +4,10 @@ #include #include +#ifndef __CUDACC_RTC__ #include #include +#endif namespace tl { diff --git 
a/src/tl_templates/cuda/instruction/mma_sm70.h b/src/tl_templates/cuda/instruction/mma_sm70.h index 65674175..7a44b921 100644 --- a/src/tl_templates/cuda/instruction/mma_sm70.h +++ b/src/tl_templates/cuda/instruction/mma_sm70.h @@ -2,8 +2,10 @@ #include "../common.h" +#ifndef __CUDACC_RTC__ #include #include +#endif namespace tl { diff --git a/src/tl_templates/cuda/instruction/wgmma.h b/src/tl_templates/cuda/instruction/wgmma.h index b5ef59c2..3af2d79f 100644 --- a/src/tl_templates/cuda/instruction/wgmma.h +++ b/src/tl_templates/cuda/instruction/wgmma.h @@ -4,8 +4,10 @@ #include #include +#ifndef __CUDACC_RTC__ #include #include +#endif namespace tl { diff --git a/src/tl_templates/cuda/nvrtc_std.h b/src/tl_templates/cuda/nvrtc_std.h index 9930c220..1e6800e5 100644 --- a/src/tl_templates/cuda/nvrtc_std.h +++ b/src/tl_templates/cuda/nvrtc_std.h @@ -19,6 +19,11 @@ #ifdef __CUDACC_RTC__ +// Disable problematic CUDA standard library headers in NVRTC environment +// Vector types (float4, uchar, etc.) 
are built-in to NVRTC and don't need these +// headers +#define _LIBCUDACXX___TUPLE_VECTOR_TYPES_H // Prevent vector_types.h inclusion + using int8_t = signed char; using uint8_t = unsigned char; using int16_t = signed short; @@ -67,6 +72,24 @@ template struct is_same : true_type {}; template inline constexpr bool is_same_v = is_same::value; +template struct is_void : false_type {}; + +template <> struct is_void : true_type {}; +template <> struct is_void : true_type {}; +template <> struct is_void : true_type {}; +template <> struct is_void : true_type {}; + +template inline constexpr bool is_void_v = is_void::value; + +template struct is_pointer : false_type {}; + +template struct is_pointer : true_type {}; +template struct is_pointer : true_type {}; +template struct is_pointer : true_type {}; +template struct is_pointer : true_type {}; + +template inline constexpr bool is_pointer_v = is_pointer::value; + namespace index_sequence_impl { // Based on https://stackoverflow.com/a/32223343/11717224 @@ -118,6 +141,36 @@ template struct enable_if {}; template struct enable_if { using type = T; }; + +template struct remove_extent { + using type = T; +}; + +template struct remove_extent { + using type = T; +}; + +template struct remove_extent { + using type = T; +}; + +template using remove_extent_t = typename remove_extent::type; + +template +struct extent : integral_constant {}; + +template struct extent : integral_constant {}; + +template struct extent : extent {}; + +template +struct extent : integral_constant {}; + +template +struct extent : extent {}; + +template +inline constexpr size_t extent_v = extent::value; } // namespace std #endif \ No newline at end of file diff --git a/src/tl_templates/cuda/reduce.h b/src/tl_templates/cuda/reduce.h index 0009b9b9..a083c711 100644 --- a/src/tl_templates/cuda/reduce.h +++ b/src/tl_templates/cuda/reduce.h @@ -1,8 +1,11 @@ #pragma once #include "common.h" + +#ifndef __CUDACC_RTC__ #include #include +#endif namespace tl { diff 
--git a/testing/python/jit/test_tilelang_jit_nvrtc.py b/testing/python/jit/test_tilelang_jit_nvrtc.py new file mode 100644 index 00000000..c7076861 --- /dev/null +++ b/testing/python/jit/test_tilelang_jit_nvrtc.py @@ -0,0 +1,585 @@ +from tilelang import tvm as tvm +import tilelang.language as T +import tilelang.testing +import tilelang +import torch +from tilelang.utils.tensor import map_torch_type + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + stramp = "&*(XS)" + + 
@tvm.register_global_func("tilelang_callback_cuda_postproc", override=True) + def tilelang_callback_cuda_postproc(code, _): + code = f"// {stramp}\n" + code + return code + + matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc") + + kernel_source = matmul_kernel.get_kernel_source() + + assert stramp in kernel_source, f"Expected {stramp} in the kernel source" + + +def test_gemm_f16f16f16_nn(): + run_gemm( + 512, + 1024, + 768, + False, + False, + "float16", + "float16", + "float16", + 128, + 256, + 32, + 2, + ) + + +def matmu_jit_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_jit_kernel( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmu_jit_kernel( + M, 
+ N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc") + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + + A = torch.randn(M, K, dtype=in_dtype).cuda() + B = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + A = A.T + if trans_B: + B = B.T + + def ref_program(A, B): + import torch + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(out_dtype) + return C + + ref_C = ref_program(A, B) + C = matmul_kernel(A, B) + + tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_gemm_jit_kernel(): + run_gemm_jit_kernel( + 512, + 1024, + 768, + False, + False, + "float16", + "float16", + "float16", + 128, + 256, + 32, + 2, + ) + + +def run_nvrtc_kernel_do_bench(M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, execution_backend="nvrtc") + + profiler = matmul_kernel.get_profiler() + + nvrtc_latency = profiler.do_bench(func=matmul_kernel) + print(f"NVRTC Latency: {nvrtc_latency} ms") + + assert nvrtc_latency is not None + + tvm_latency = profiler.do_bench() + print(f"TVM Latency: {tvm_latency} ms") + + assert tvm_latency is not None + + +def test_nvrtc_kernel_do_bench(): + run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, + 256, 32, 2) + + +def run_nvrtc_kernel_multi_stream(M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128): + program = matmul( + M, + N, + K, + block_M, + block_N, 
+ block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, execution_backend="nvrtc") + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() + tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + tensor_a = tensor_a.T + if trans_B: + tensor_b = tensor_b.T + tensor_c = torch.randn(M, N, dtype=out_dtype).cuda() + + num_streams = 4 + for _ in range(num_streams): + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + matmul_kernel(tensor_a, tensor_b, tensor_c) + + +def test_nvrtc_kernel_multi_stream(): + run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", + 128, 256, 32, 2) + + +def run_nvrtc_dynamic_shape(M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, execution_backend="nvrtc") + if isinstance(M, T.Var): + M = 1024 + if isinstance(N, T.Var): + N = 1024 + if isinstance(K, T.Var): + K = 768 + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + + tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() + tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + tensor_a = tensor_a.T + if trans_B: + tensor_b = tensor_b.T + tensor_c = torch.randn(M, N, dtype=out_dtype).cuda() + + matmul_kernel(tensor_a, tensor_b, tensor_c) + + tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) + tilelang.testing.torch_assert_close( + tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_nvrtc_dynamic_shape(): + run_nvrtc_dynamic_shape( + 
T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + run_nvrtc_dynamic_shape( + T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, + 256, 32, 2) + + run_nvrtc_dynamic_shape( + T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", + "float16", 128, 256, 32, 2) + + +def check_hopper(): + if not torch.cuda.is_available(): + return False + props = torch.cuda.get_device_properties(0) + compute_capability = props.major, props.minor + return compute_capability == (9, 0) + + +def convolution_im2col(N, + C, + H, + W, + F, + K, + S, + D, + P, + block_M, + block_N, + block_K, + num_stages, + threads, + dtype="float16", + accum_dtype="float"): + KH, KW = K, K + OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 + OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 + + @T.prim_func + def main( + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), + ): + with T.Kernel( + T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), + threads=threads) as (bx, by): + data_shared = T.alloc_shared((block_M, block_K), dtype) + kernel_shared = T.alloc_shared((block_K, block_N), dtype) + out_local = T.alloc_fragment((block_M, block_N), accum_dtype) + out_shared = T.alloc_shared((block_M, block_N), dtype) + + kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) + out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) + + T.annotate_layout({ + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + data_shared: tilelang.layout.make_swizzled_layout(data_shared), + kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), + }) + + T.clear(out_local) + for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): + T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P) + T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) + T.gemm(data_shared, kernel_shared, out_local) + + 
T.copy(out_local, out_shared) + T.copy(out_shared, out_flat[by * block_M, bx * block_N]) + + return main + + +def run_nvrtc_im2col_tma_desc(N, + C, + H, + W, + F, + K, + S, + D, + P, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=256): + """Test im2col TMA descriptor functionality in NVRTC backend.""" + program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, + num_threads) + + conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc") + + a = torch.randn(N, H, W, C).cuda().half() + b = torch.randn(K, K, C, F).cuda().half() + + out_c = conv_kernel(a, b) + + # Reference implementation using torch.conv2d + def ref_program(A, B): + A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W + B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W + C = torch.conv2d(A, B, stride=S, padding=P, dilation=D) + C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C + return C + + ref_c = ref_program(a, b) + tilelang.testing.torch_assert_close( + out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_nvrtc_im2col_tma_desc(): + """Test im2col TMA descriptor with NVRTC backend.""" + if not check_hopper(): + import pytest + pytest.skip("Test requires Hopper GPU (compute capability 9.0)") + + # Small test case for im2col TMA descriptor + run_nvrtc_im2col_tma_desc( + N=4, + C=64, + H=32, + W=32, + F=64, + K=3, + S=1, + D=1, + P=1, + block_M=64, + block_N=128, + block_K=32, + num_stages=3, + num_threads=256) + + +def test_nvrtc_l2_persistent_map(): + """Test L2 persistent cache annotation with elementwise add.""" + from tilelang.language import annotate_l2_hit_ratio + + M = 1024 + N = 1024 + + @tilelang.jit(out_idx=[-1], execution_backend="nvrtc") + def elementwise_add_with_l2_cache( + M, + N, + block_size=256, + dtype="float32", + ): + + @T.prim_func + def kernel( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(M * N // block_size, 
threads=block_size) as bx: + # Annotate L2 persistent cache for buffer B + # B will be accessed multiple times and benefit from L2 caching + annotate_l2_hit_ratio({B: 0.8}) + + for i in T.serial(block_size): + idx = bx * block_size + i + if idx < M * N: + row = idx // N + col = idx % N + C[row, col] = A[row, col] + B[row, col] + + return kernel + + # Compile the kernel + kernel = elementwise_add_with_l2_cache(M, N) + + # Create test tensors + a = torch.randn(M, N, dtype=torch.float32).cuda() + b = torch.randn(M, N, dtype=torch.float32).cuda() + + # Run kernel with out_idx=[-1], C is returned not passed in + c = kernel(a, b) + + # Verify correctness + ref_c = a + b + tilelang.testing.torch_assert_close(c, ref_c, atol=1e-5, rtol=1e-5) + + print("L2 persistent map test passed!") + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/jit/adapter/libgen.py b/tilelang/jit/adapter/libgen.py index 1e33ec04..208370b0 100644 --- a/tilelang/jit/adapter/libgen.py +++ b/tilelang/jit/adapter/libgen.py @@ -1,9 +1,7 @@ from __future__ import annotations import ctypes -import importlib import logging import os -import os.path as osp import subprocess import tempfile from typing import Any @@ -21,14 +19,6 @@ from .utils import is_cpu_target, is_cuda_target, is_hip_target logger = logging.getLogger(__name__) -try: - from tilelang.jit.adapter.nvrtc import is_nvrtc_available - if is_nvrtc_available: - import cuda.bindings.driver as cuda - from tilelang.contrib.nvrtc import compile_cuda -except ImportError: - is_nvrtc_available = False - class LibraryGenerator: srcpath: str | None = None @@ -183,95 +173,3 @@ class LibraryGenerator: def set_src_path(self, srcpath): self.srcpath = srcpath - - -class PyLibraryGenerator(LibraryGenerator): - host_func: str | None = None - culib = None - pymodule = None - - def __init__(self, target: Target, verbose: bool = False): - if not is_nvrtc_available: - raise ImportError("cuda-python is not available, nvrtc backend cannot be 
used. " - "Please install cuda-python via `pip install cuda-python` " - "if you want to use the nvrtc backend.") - super().__init__(target, verbose) - - @staticmethod - def import_from_file(module_name, file_path): - spec = importlib.util.spec_from_file_location(module_name, file_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module - - def update_host_func(self, host_func: str): - self.host_func = host_func - - def load_lib(self, lib_path: str | None = None): - if lib_path is None: - lib_path = self.libpath - - pypath = lib_path.replace(".cubin", ".py") - self.pymodule = self.import_from_file("kernel", pypath) - - # Ensure the context is valid - ctx = cuda.cuCtxGetCurrent()[1] - if cuda.cuCtxGetApiVersion(ctx)[0] != cuda.CUresult.CUDA_SUCCESS: - import torch - torch.cuda.synchronize() - - result, self.culib = cuda.cuLibraryLoadFromFile( - bytes(lib_path, "utf-8"), [], [], 0, [], [], 0) - assert result == cuda.CUresult.CUDA_SUCCESS, f"Failed to load library: {lib_path}" - - def compile_lib(self, timeout: float = None): - target = self.target - verbose = self.verbose - if is_cuda_target(target): - from tilelang.env import (CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH) - src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) # noqa: SIM115 - libpath = src.name.replace(".cu", ".cubin") - - project_root = osp.join(osp.dirname(__file__), "..", "..") - if CUTLASS_INCLUDE_DIR is None: - cutlass_path = osp.abspath(osp.join(project_root, "3rdparty/cutlass/include")) - else: - cutlass_path = CUTLASS_INCLUDE_DIR - - if TILELANG_TEMPLATE_PATH is None: - tl_template_path = osp.abspath(osp.join(project_root, "src")) - else: - tl_template_path = TILELANG_TEMPLATE_PATH - - cuda_home = CUDA_HOME if CUDA_HOME else "/usr/local/cuda" - - options = [f"-I{tl_template_path}", f"-I{cutlass_path}", f"-I{cuda_home}/include"] - if self.compile_flags: - options += [ - item for flag in self.compile_flags for item in 
flag.split() - if item not in options - ] - - cubin_bytes = compile_cuda( - self.lib_code, target_format="cubin", options=options, verbose=verbose) - with open(libpath, "wb") as f: - f.write(cubin_bytes) - - src.write(self.lib_code) - src.flush() - - self.srcpath = src.name - self.libpath = libpath - - pypath = src.name.replace(".cu", ".py") - with open(pypath, "w") as f: - f.write(self.host_func) - else: - raise ValueError(f"Unsupported target: {target}") - - def __del__(self): - if self.culib: - result = cuda.cuLibraryUnload(self.culib)[0] - if result != cuda.CUresult.CUDA_SUCCESS: - logger.warning(f"Failed to unload library: {self.libpath}") - self.culib = None diff --git a/tilelang/jit/adapter/nvrtc/__init__.py b/tilelang/jit/adapter/nvrtc/__init__.py index c9068faf..faa08c19 100644 --- a/tilelang/jit/adapter/nvrtc/__init__.py +++ b/tilelang/jit/adapter/nvrtc/__init__.py @@ -5,7 +5,10 @@ This module provides runtime compilation support using NVIDIA's NVRTC API. import logging -__all__ = ['NVRTCKernelAdapter', 'is_nvrtc_available', 'check_nvrtc_available'] +__all__ = [ + 'NVRTCKernelAdapter', 'TLNVRTCSourceWrapper', 'NVRTCLibraryGenerator', 'is_nvrtc_available', + 'check_nvrtc_available' +] logger = logging.getLogger(__name__) @@ -37,7 +40,9 @@ def check_nvrtc_available(): # Conditionally import the adapter if is_nvrtc_available: - from .adapter import NVRTCKernelAdapter # noqa: F401 + from .adapter import NVRTCKernelAdapter + from .wrapper import TLNVRTCSourceWrapper + from .libgen import NVRTCLibraryGenerator else: # Provide a dummy class that raises error on instantiation class NVRTCKernelAdapter: @@ -45,3 +50,19 @@ else: def __init__(self, *args, **kwargs): raise ImportError(NVRTC_UNAVAILABLE_MESSAGE) + + @classmethod + def from_database(cls, *args, **kwargs): + raise ImportError(NVRTC_UNAVAILABLE_MESSAGE) + + class TLNVRTCSourceWrapper: + """Dummy TLNVRTCSourceWrapper that raises ImportError on instantiation.""" + + def __init__(self, *args, **kwargs): + 
raise ImportError(NVRTC_UNAVAILABLE_MESSAGE) + + class NVRTCLibraryGenerator: + """Dummy NVRTCLibraryGenerator that raises ImportError on instantiation.""" + + def __init__(self, *args, **kwargs): + raise ImportError(NVRTC_UNAVAILABLE_MESSAGE) diff --git a/tilelang/jit/adapter/nvrtc/adapter.py b/tilelang/jit/adapter/nvrtc/adapter.py index d6723a03..5f8a2827 100644 --- a/tilelang/jit/adapter/nvrtc/adapter.py +++ b/tilelang/jit/adapter/nvrtc/adapter.py @@ -9,12 +9,13 @@ from tvm.target import Target from tilelang import tvm as tvm from tilelang.engine.param import KernelParam from tilelang.jit.adapter.wrapper import TLPyWrapper -from tilelang.jit.adapter.libgen import PyLibraryGenerator from tilelang.utils.language import retrieve_func_from_module from tilelang.utils.target import determine_target from tilelang.jit.adapter.base import BaseKernelAdapter from tilelang.jit.adapter.nvrtc import is_nvrtc_available, check_nvrtc_available +from .libgen import NVRTCLibraryGenerator + logger = logging.getLogger(__name__) # Import cuda bindings if available @@ -75,7 +76,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter): self.wrapper.assign_device_module(device_mod) self.host_func, self.function_names = self.wrapper.wrap(kernel_global_source) - self.lib_generator = PyLibraryGenerator(self.target, self.verbose) + self.lib_generator = NVRTCLibraryGenerator(self.target, self.verbose) self.lib_generator.update_lib_code(self.kernel_global_source) self.lib_generator.update_host_func(self.host_func) self.lib_generator.assign_compile_flags(compile_flags) @@ -130,7 +131,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter): adapter.target = Target.canon_target(determine_target(target)) adapter.verbose = verbose - adapter.lib_generator = PyLibraryGenerator(adapter.target, adapter.verbose) + adapter.lib_generator = NVRTCLibraryGenerator(adapter.target, adapter.verbose) adapter.lib_generator.assign_compile_flags(compile_flags) adapter.lib_generator.load_lib(lib_path=kernel_lib_path) 
adapter.pymodule = adapter.lib_generator.pymodule diff --git a/tilelang/jit/adapter/nvrtc/libgen.py b/tilelang/jit/adapter/nvrtc/libgen.py new file mode 100644 index 00000000..50a587a5 --- /dev/null +++ b/tilelang/jit/adapter/nvrtc/libgen.py @@ -0,0 +1,235 @@ +"""NVRTC Library Generator for TileLang. + +Compiles CUDA kernels at runtime using NVRTC and manages resulting binaries. + +Why NVRTC instead of nvcc: +- No offline compilation step, enables true JIT workflows +- Works without CUDA toolkit installed (only requires driver) +- Allows kernel specialization based on runtime parameters + +Key responsibilities: +- Compile CUDA source to cubin using NVRTC API +- Generate accompanying Python launcher code +- Load compiled cubin and extract kernel handles +- Manage library lifecycle (load/unload) +""" +from __future__ import annotations +import importlib +import logging +import os.path as osp +import platform +import tempfile +from types import ModuleType + +from tvm.target import Target + +from tilelang import tvm as tvm +from tilelang.jit.adapter.libgen import LibraryGenerator +from tilelang.jit.adapter.utils import is_cuda_target +from tilelang.jit.adapter.nvrtc import is_nvrtc_available, NVRTC_UNAVAILABLE_MESSAGE + +logger = logging.getLogger(__name__) + +if is_nvrtc_available: + import cuda.bindings.driver as cuda + from tilelang.contrib.nvrtc import compile_cuda +else: + raise ImportError(NVRTC_UNAVAILABLE_MESSAGE) + + +class NVRTCLibraryGenerator(LibraryGenerator): + """Runtime compiler and loader for NVRTC-compiled CUDA kernels. + + Lifecycle: + 1. compile_lib(): CUDA source → cubin + Python launcher + 2. load_lib(): cubin → loaded library + kernel handles + 3. pymodule.call(): Execute kernels via Python launcher + 4. 
__del__: Cleanup (unload library) + + Why three files (cu, cubin, py): + - .cu: Source for debugging, kept in temp directory + - .cubin: Compiled binary, loaded by CUDA driver + - .py: Launch code, imported as Python module + + Attributes: + host_func: Generated Python launch code (from wrapper) + culib: CUDA library handle (CUlibrary) + pymodule: Imported Python module containing call() function + """ + host_func: str | None = None + culib: cuda.CUlibrary | None = None + pymodule: ModuleType | None = None + pypath: str | None = None + + def __init__(self, target: Target, verbose: bool = False): + """Initialize NVRTC library generator. + + Args: + target: Compilation target (must be CUDA) + verbose: Enable verbose compilation output + """ + super().__init__(target, verbose) + + @staticmethod + def import_from_file(module_name, file_path): + """Dynamically import Python module from file path. + + Standard importlib pattern for loading modules outside sys.path. + Used to import generated .py launcher code from temp directory. + + Args: + module_name: Name to assign to imported module + file_path: Absolute path to .py file + + Returns: + Imported module object + """ + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None or spec.loader is None: + raise ImportError(f"Failed to import module from file: {file_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + def update_host_func(self, host_func: str): + """Store generated Python launch code for later file write. + + Called by adapter after wrapper generates the launch code. + This is the bridge between code generation and file output. + + Args: + host_func: Python source code containing call() function + """ + self.host_func = host_func + + def load_lib(self, lib_path: str | None = None): + """Load compiled cubin and Python launcher into memory. + + Why two loads: + 1. Import Python module for launch logic + 2. 
Load cubin via CUDA Driver API for kernel handles + + Context synchronization: CUDA context must be current before loading. + If not, use torch.cuda.synchronize() to establish context. + + Args: + lib_path: Path to .cubin file (optional, uses self.libpath if None) + + Side effects: + - Sets self.pymodule to imported Python module + - Sets self.culib to CUDA library handle + """ + if lib_path is None: + lib_path = self.libpath + else: + self.libpath = lib_path + + self.pypath = lib_path.replace(".cubin", ".py") + self.pymodule = self.import_from_file("kernel", self.pypath) + + # Ensure the context is valid + ctx = cuda.cuCtxGetCurrent()[1] + if cuda.cuCtxGetApiVersion(ctx)[0] != cuda.CUresult.CUDA_SUCCESS: + import torch + torch.cuda.synchronize() + + result, self.culib = cuda.cuLibraryLoadFromFile( + bytes(lib_path, "utf-8"), [], [], 0, [], [], 0) + if result != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to load library: {lib_path}, error: {result}") + + def compile_lib(self, timeout: float | None = None): + """Compile CUDA source to cubin using NVRTC and write output files. + + Output artifacts (all in temp directory): + - .cu: Source code (for debugging) + - .cubin: Compiled binary (for execution) + - .py: Python launcher (for calling kernels) + + Include paths setup: + - TileLang templates: kernel primitives and utilities + - CUTLASS: optimized GEMM/tensor ops + - CUDA headers: driver/runtime APIs + + Why architecture detection: + ARM64 servers (SBSA) have different header paths than x86_64. 
+ + Args: + timeout: Compilation timeout in seconds (currently unsupported by NVRTC compiler) + + Side effects: + - Writes .cu, .cubin, .py files to temp directory + - Sets self.srcpath, self.libpath, self.pypath + """ + target = self.target + verbose = self.verbose + if is_cuda_target(target): + from tilelang.env import (CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH) + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) + libpath = src.name.replace(".cu", ".cubin") + + project_root = osp.join(osp.dirname(__file__), "..", "..") + if CUTLASS_INCLUDE_DIR is None: + cutlass_path = osp.abspath(osp.join(project_root, "3rdparty/cutlass/include")) + else: + cutlass_path = CUTLASS_INCLUDE_DIR + + if TILELANG_TEMPLATE_PATH is None: + tl_template_path = osp.abspath(osp.join(project_root, "src")) + else: + tl_template_path = TILELANG_TEMPLATE_PATH + + cuda_home = CUDA_HOME if CUDA_HOME else "/usr/local/cuda" + __CUDACC_VER_MAJOR__ = cuda.CUDA_VERSION // 1000 + + # Determine target architecture + machine = platform.machine() + target_arch = "sbsa-linux" if machine in ("aarch64", "arm64") else "x86_64-linux" + + options = [ + f"-I{tl_template_path}", + f"-I{cutlass_path}", + f"-I{cuda_home}/include", + f"-I{cuda_home}/targets/{target_arch}/include", + f"-I{cuda_home}/targets/{target_arch}/include/cccl", + f"-D__CUDACC_VER_MAJOR__={__CUDACC_VER_MAJOR__}", + ] + if self.compile_flags: + options += [ + item for flag in self.compile_flags for item in flag.split() + if item not in options + ] + + cubin_bytes = compile_cuda( + self.lib_code, target_format="cubin", options=options, verbose=verbose) + with open(libpath, "wb") as f: + f.write(cubin_bytes) + + src.write(self.lib_code) + src.flush() + + self.srcpath = src.name + self.libpath = libpath + self.pypath = src.name.replace(".cu", ".py") + if self.host_func is None: + raise RuntimeError( + "Host function is not set, please call update_host_func() first.") + with open(self.pypath, "w") as f: + 
f.write(self.host_func) + else: + raise ValueError(f"Unsupported target: {target}") + + def __del__(self): + """Cleanup: unload CUDA library when object is destroyed. + + Critical for resource management - CUDA libraries consume GPU memory. + Failure to unload is logged but not raised (destructor can't fail). + + Why explicit unload: + Python GC doesn't know about GPU resources, must release manually. + """ + if self.culib: + result = cuda.cuLibraryUnload(self.culib)[0] + if result != cuda.CUresult.CUDA_SUCCESS: + logger.warning(f"Failed to unload library: {self.libpath}") + self.culib = None diff --git a/tilelang/jit/adapter/nvrtc/wrapper.py b/tilelang/jit/adapter/nvrtc/wrapper.py new file mode 100644 index 00000000..1a29adef --- /dev/null +++ b/tilelang/jit/adapter/nvrtc/wrapper.py @@ -0,0 +1,563 @@ +"""NVRTC Source Wrapper for TileLang. + +Generates Python runtime code for launching CUDA kernels compiled via NVRTC. + +Why this exists: +- NVRTC compiles kernels at runtime, needs Python launch code (not C++) +- TMA descriptors must be initialized once per unique buffer, not per kernel +- L2 cache policies require explicit CUDA Driver API setup/teardown + +Key design: +- Two-pass generation: collect all descriptors first, then generate launches +- Dict-based deduplication ensures TMA descriptors created only once +- Generates pure Python using cuda.bindings.driver for zero C++ dependency +""" +from __future__ import annotations +from typing import Any, ClassVar + +from tvm import IRModule +from tvm.target import Target +from tvm.tir.stmt_functor import post_order_visit + +from tilelang import tvm as tvm +from tilelang.jit.adapter.wrapper import TLCUDASourceWrapper +from tilelang.jit.adapter.utils import (match_declare_kernel, pythonic_expr, + parse_function_call_args, parse_tma_descriptor_args) + +PREDEF_HOST_FUNC_PY = """ +from cuda.bindings.driver import ( + CUtensorMapDataType, + CUtensorMapInterleave, + CUtensorMapSwizzle, + CUtensorMapL2promotion, + 
CUtensorMapFloatOOBfill, + cuTensorMapEncodeTiled, + cuTensorMapEncodeIm2col, + CUresult, + cuKernelSetAttribute, + CUfunction_attribute, + CUdevice, + CUlaunchConfig, + cuLaunchKernelEx, + cuuint64_t, + cuuint32_t, + CUkernel, +) +import ctypes + +_function_names = {} + +def call({}): + {} +""" + +TMA_DESC_INIT_FUNC_PY = """ + {0}_type = CUtensorMapDataType({1}) + {0}_tensorRank = {2} + {0}_globalAddress = {3}.data_ptr() + {0}_globalDim = [{4}] + {0}_globalStride = [{5}][1:] + {0}_boxDim = [{6}] + {0}_elementStrides = [{7}] + {0}_interleave = CUtensorMapInterleave({8}) + {0}_swizzle = CUtensorMapSwizzle({9}) + {0}_l2Promotion = CUtensorMapL2promotion({10}) + {0}_oobFill = CUtensorMapFloatOOBfill({11}) + + res, {0} = cuTensorMapEncodeTiled( + {0}_type, + {0}_tensorRank, + {0}_globalAddress, + {0}_globalDim, + {0}_globalStride, + {0}_boxDim, + {0}_elementStrides, + {0}_interleave, + {0}_swizzle, + {0}_l2Promotion, + {0}_oobFill, + ) + + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to initialize the TMA descriptor {0}: {{res}}") +""" + +TMA_IM2COL_DESC_INIT_FUNC_PY = """ + {0}_type = CUtensorMapDataType({1}) + {0}_tensorRank = {2} + {0}_globalAddress = {3}.data_ptr() + {0}_globalDim = [{4}] + {0}_globalStride = [{5}][1:] + {0}_elementStrides = [{6}] + {0}_lowerCorner = [{7}] + {0}_upperCorner = [{8}] + {0}_channelsPerPixel = {9} + {0}_pixelsPerColumn = {10} + {0}_interleave = CUtensorMapInterleave({11}) + {0}_swizzle = CUtensorMapSwizzle({12}) + {0}_l2Promotion = CUtensorMapL2promotion({13}) + {0}_oobFill = CUtensorMapFloatOOBfill({14}) + + res, {0} = cuTensorMapEncodeIm2col( + {0}_type, + {0}_tensorRank, + {0}_globalAddress, + {0}_globalDim, + {0}_globalStride, + {0}_lowerCorner, + {0}_upperCorner, + {0}_channelsPerPixel, + {0}_pixelsPerColumn, + {0}_elementStrides, + {0}_interleave, + {0}_swizzle, + {0}_l2Promotion, + {0}_oobFill, + ) + + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to initialize the TMA descriptor {0}: {{res}}") 
+""" + +L2_PERSISTENT_MAP_CREATE_HANDLE_PY = """ + from cuda.bindings.driver import ( + CUstreamAttrValue, + CUstreamAttrID, + CUlimit, + CUaccessProperty, + cuCtxGetLimit, + cuCtxSetLimit, + cuStreamSetAttribute, + cuCtxResetPersistingL2Cache, + ) + + stream_attribute = CUstreamAttrValue() + res, init_persisting_l2_cache_size = cuCtxGetLimit(CUlimit.CU_LIMIT_PERSISTING_L2_CACHE_SIZE) + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to get L2 cache size limit: {{res}}") +""" + +L2_PERSISTENT_MAP_INIT_FUNC_PY = """ + stream_attribute.accessPolicyWindow.hitRatio = {1} + stream_attribute.accessPolicyWindow.hitProp = CUaccessProperty.CU_ACCESS_PROPERTY_PERSISTING + stream_attribute.accessPolicyWindow.missProp = CUaccessProperty.CU_ACCESS_PROPERTY_STREAMING + + res = cuCtxSetLimit(CUlimit.CU_LIMIT_PERSISTING_L2_CACHE_SIZE, {2})[0] + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to set L2 cache size limit: {{res}}") + + stream_attribute.accessPolicyWindow.base_ptr = {0}.data_ptr() + stream_attribute.accessPolicyWindow.num_bytes = {2} + + res = cuStreamSetAttribute(stream, CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW, stream_attribute)[0] + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to set stream L2 access policy: {{res}}") +""" + +L2_PERSISTENT_MAP_RESET_HANDLE_PY = """ + stream_attribute.accessPolicyWindow.num_bytes = 0 + res = cuStreamSetAttribute(stream, CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW, stream_attribute)[0] + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to reset stream L2 access policy: {{res}}") + + res = cuCtxResetPersistingL2Cache()[0] + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to reset L2 cache: {{res}}") + + res = cuCtxSetLimit(CUlimit.CU_LIMIT_PERSISTING_L2_CACHE_SIZE, init_persisting_l2_cache_size)[0] + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to restore L2 cache size limit: {{res}}") +""" + +KERNEL_LAUNCH_FUNC_PY 
= """ + res = cuKernelSetAttribute( + CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + {7}, + kernels["{0}"], + CUdevice({10}) + )[0] + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to set max dynamic shared memory size to {7} for kernel {0}: {{res}}") + + config = CUlaunchConfig() + config.gridDimX = {1} + config.gridDimY = {2} + config.gridDimZ = {3} + config.blockDimX = {4} + config.blockDimY = {5} + config.blockDimZ = {6} + config.sharedMemBytes = {7} + config.hStream = stream + + arg_values = {8} + arg_types = {9} + + res = cuLaunchKernelEx(config, kernels["{0}"], (arg_values, arg_types), 0)[0] + if res != CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Failed to launch kernel {0}: {{res}}") +""" + + +class TLNVRTCSourceWrapper(TLCUDASourceWrapper): + """NVRTC backend wrapper: generates Python kernel launch code. + + Core responsibility: transform TVM IRModule into executable Python function + that initializes resources (TMA descriptors, L2 cache) and launches kernels + via CUDA Driver API. + + Data flow: + IRModule → collect kernel metadata → deduplicate resources → + generate Python code → executable function + + Why Python generation instead of C++: + NVRTC workflow requires runtime compilation, Python is the natural host. + Using cuda.bindings.driver eliminates C++ wrapper complexity. 
+ """ + + _TYPE_MAP: ClassVar[dict[str, str]] = { + "float32": "ctypes.c_float", + "float16": "ctypes.c_uint16", + "bfloat16": "ctypes.c_uint16", + "float8_e4m3": "ctypes.c_uint8", + "float8_e4m3fn": "ctypes.c_uint8", + "float8_e5m2": "ctypes.c_uint8", + "float64": "ctypes.c_double", + "int64": "ctypes.c_int64", + "int32": "ctypes.c_int32", + "uint32": "ctypes.c_uint32", + "bool": "ctypes.c_bool", + "int8": "ctypes.c_int8", + "uint8": "ctypes.c_uint8", + "int16": "ctypes.c_int16", + "uint16": "ctypes.c_uint16", + "uchar": "ctypes.c_uint8", + } + + _generated_host_func: str | None = None + + def __init__(self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None): + """Initialize NVRTC wrapper with compiled IR modules. + + Args: + scheduled_ir_module: TVM IR after scheduling passes + source: Generated CUDA C++ source code + target: Compilation target (should be NVRTC-compatible) + device_mod: Device-side IR module (kernel functions) + host_mod: Host-side IR module (launch logic) + pass_configs: Optional compiler pass configurations + """ + super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs) + + @property + def host_func(self): + """Override parent's host_func to return generated Python code.""" + if self._generated_host_func is not None: + return self._generated_host_func + return super().host_func + + @host_func.setter + def host_func(self, value): + """Allow setting generated host function code.""" + self._generated_host_func = value + + def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str: + """Convert TVM expression to Python string, ignoring casts. + + Casts are noise in generated Python code - Python is dynamically typed. 
+ """ + return pythonic_expr(expr, self._TYPE_MAP, ignore_cast=True) + + def create_dispatch_func(self, code, function_informations): + """Generate Python dispatch function that launches multiple CUDA kernels. + + Why two-pass design: + Pass 1: Collect TMA descriptors from all kernels into shared dicts + Pass 2: Generate code - descriptors first (deduplicated), then launches + + Single-pass would create duplicate descriptors for each kernel. + Dict naturally deduplicates by descriptor name. + + Args: + code: CUDA C++ source containing kernel declarations + function_informations: Dict mapping kernel names to metadata + (grid/block dims, params, shared memory size) + + Returns: + Python source code defining a call() function that: + 1. Initializes L2 cache policies (if needed) + 2. Creates TMA descriptors once per unique buffer + 3. Launches each kernel with cuLaunchKernelEx + 4. Resets L2 cache policies (if needed) + """ + # Extract the set of dynamic symbolic names used in the primary function + dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) + + function_args = [{"name": "kernels", "type": "dict[str, CUkernel]"}] + # Collect function arguments based on primary function's parameters and buffer mappings + for param in self.prim_func.params: + if param in self.prim_func.buffer_map: + buffer = self.prim_func.buffer_map[param] + function_args.append({ + "name": buffer.data.name, + "type": "ctypes.c_void_p", + }) + elif isinstance(param, tvm.tir.Var): + function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)}) + else: + raise ValueError( + f"Parameter {param} is not in the buffer map of the primary function.") + # Add dynamic symbols as integer arguments + for dyn_sym in dynamic_symbolic_set: + if dyn_sym not in [arg["name"] for arg in function_args]: + function_args.append({"name": dyn_sym, "type": "ctypes.c_int"}) + + function_args.append(self.get_stream_type()) + + # Format the function arguments for declaration + 
def_args = ", ".join([f"{arg['name']}" for arg in function_args]) + + # Check if any function needs L2 Persistent Map + has_l2_persistent_map = False + for function_name, _ in function_informations.items(): + if function_name in self.l2_persistent_map: + has_l2_persistent_map = True + break + + desc_name_map: dict[str, str] = {} + desc_name_var_map: dict[str, tvm.tir.Var] = {} + device_index = 0 + kernel_launch_code = """""" + if has_l2_persistent_map: + kernel_launch_code += L2_PERSISTENT_MAP_CREATE_HANDLE_PY + + # First pass: collect all TMA descriptors from all kernels to avoid duplication + kernel_info_list = [] + for function_name, function_info in function_informations.items(): + block_info = function_info["block_info"] + grid_info = function_info["grid_info"] + dynamic_smem_buf = function_info["dynamic_smem_buf"] + function_params = function_info["function_params"] + + # Find the location of the global kernel function in the code + index = match_declare_kernel(code, function_name + "(") + + # Analyze the function declaration to prepare for argument extraction + declaration = code[index:].split(";")[0] + + # Identify the start of the function body to insert arguments + index = code.index("{", index) + + # Transform function for NVRTC: returns (arg_value, arg_type) tuples + def transform_nvrtc_arg(name: str, arg_type: str): + if arg_type == "ctypes.c_void_p": + return (f"{name}.data_ptr()", arg_type) + return (name, arg_type) + + call_args = parse_function_call_args(declaration, function_args, function_params, + desc_name_map, desc_name_var_map, + transform_nvrtc_arg) + + for arg_name, arg_type in call_args: + if arg_type == "ctypes.c_void_p": + device_index = f"{arg_name.replace('.data_ptr()', '')}.device.index" + break + + # Store kernel info for second pass + kernel_info_list.append({ + 'function_name': function_name, + 'block_info': block_info, + 'grid_info': grid_info, + 'dynamic_smem_buf': dynamic_smem_buf, + 'call_args': call_args, + 'device_index': 
device_index, + }) + + # Generate TMA descriptor initialization code once for all kernels + kernel_launch_code += self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map) + + # Second pass: generate kernel launch code for each kernel + for kernel_info in kernel_info_list: + function_name = kernel_info['function_name'] + block_info = kernel_info['block_info'] + grid_info = kernel_info['grid_info'] + dynamic_smem_buf = kernel_info['dynamic_smem_buf'] + call_args = kernel_info['call_args'] + device_index = kernel_info['device_index'] + + arg_names = ", ".join([arg[0] for arg in call_args]) + arg_types = ", ".join([arg[1] for arg in call_args]) + smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf + + # Generate L2 persistent map initialization for this function + init_l2_persistent_map = self.generate_l2_persistent_map(function_name) + kernel_launch_code += init_l2_persistent_map + + # Generate kernel launch code + kernel_launch_code += KERNEL_LAUNCH_FUNC_PY.format(function_name, + self._pythonic_expr(grid_info[0]), + self._pythonic_expr(grid_info[1]), + self._pythonic_expr(grid_info[2]), + self._pythonic_expr(block_info[0]), + self._pythonic_expr(block_info[1]), + self._pythonic_expr(block_info[2]), + smem_str, arg_names, arg_types, + device_index) + + # Reset L2 persistent map after all kernel execution + if has_l2_persistent_map: + kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE_PY + + # Wrap the kernel dispatch logic in an external C function + host_func = PREDEF_HOST_FUNC_PY.format( + repr(list(function_informations.keys())), def_args, kernel_launch_code) + return host_func + + def generate_l2_persistent_map(self, function_name: str) -> str: + """Generate Python code to configure L2 cache persistence for a kernel. + + L2 persistence pins frequently-accessed data in L2 cache to reduce + memory bandwidth. Requires explicit setup via CUDA stream attributes. 
+ + Args: + function_name: Kernel name to check for L2 persistence config + + Returns: + Python code that sets stream access policy window, or empty + string if no L2 persistence configured for this kernel. + """ + if function_name not in self.l2_persistent_map: + return "" + init_l2_persistent_map = "" + for buffer_name, (hit_ratio, + size_in_bytes) in self.l2_persistent_map[function_name].items(): + # Get persisting_l2_cache_max_size + from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size + persisting_l2_cache_max_size = get_persisting_l2_cache_max_size() + try: + num_bytes = min(size_in_bytes, persisting_l2_cache_max_size) + except TypeError: + # as size_in_bytes may be a symbolic expression + num_bytes = persisting_l2_cache_max_size + init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC_PY.format( + buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes)) + + return init_l2_persistent_map + + def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], + desc_name_var_map: dict[str, tvm.tir.Var]) -> str: + """Generate Python code to initialize TMA descriptors. + + TMA (Tensor Memory Accelerator) descriptors are opaque CUDA objects + that describe memory layout for async copies. Must be created on host + before kernel launch. + + Args: + desc_name_map: Maps descriptor variable names to buffer names + desc_name_var_map: Maps descriptor names to TVM variables + + Returns: + Python code that calls cuTensorMapEncodeTiled/Im2col for each + unique descriptor. Empty string if no TMA descriptors needed. 
+ """ + tma_descriptor_init = "" + if self.tma_descriptor_args is None: + return tma_descriptor_init + + # Parse TMA descriptor arguments using the common utility + parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, + desc_name_var_map, self._pythonic_expr) + + # Generate Python code from parsed parameters + for params in parsed_params: + if not params.is_img2col: + tma_descriptor_init += TMA_DESC_INIT_FUNC_PY.format( + params.handle_name, params.dtype, params.tensor_rank, params.global_address, + ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)), + ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)), + ", ".join(map(lambda x: f"cuuint32_t({x})", params.box_dim)), + ", ".join(map(lambda x: f"cuuint32_t({x})", params.element_strides)), + params.interleave, params.swizzle, params.l2_promotion, params.oob_fill) + else: + tma_descriptor_init += TMA_IM2COL_DESC_INIT_FUNC_PY.format( + params.handle_name, params.dtype, params.tensor_rank, params.global_address, + ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)), + ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)), + ", ".join(map(lambda x: f"cuuint32_t({x})", + params.element_strides)), ", ".join(params.lower_corner), + ", ".join(params.upper_corner), params.smem_box_channel, params.smem_box_pixel, + params.interleave, params.swizzle, params.l2_promotion, params.oob_fill) + + return tma_descriptor_init + + def update_lib_code(self, code: str): + """Update library code and generate host dispatch function. + + Entry point for code generation. Walks the host IR to extract kernel + call sites, matches them with device kernels, then generates Python + dispatch code via create_dispatch_func(). + + Args: + code: CUDA C++ source code containing compiled kernels + + Returns: + The same code string (stored in self.lib_code). Side effect: + sets self.host_func to generated Python dispatcher. 
+ """ + # Update the library code with the given code string + self.lib_code = code + + # Organize function information for code generation + function_informations = {} + for function_name in self.function_names: + # Do not update function with dispatch host function + if (function_name not in self.block_info) or (function_name not in self.grid_info): + continue + + assert function_name in self.device_mod, f"Function {function_name} not found in device module" + device_func = self.device_mod[function_name] + kernel_params_cnt = len(device_func.params) + function_params: list[str] | None = None + + def visitor(node, fn=function_name, param_cnt=kernel_params_cnt): + nonlocal function_params + if isinstance(node, tvm.tir.Call): + if not (hasattr(node, "op") and + node.op == tvm.ir.Op.get("tir.tvm_call_packed")): + return + args = node.args + if not args or args[0] != fn: + return + if len(args) < 1 + param_cnt: + raise AssertionError( + "tvm_call_packed should have at least 1 argument and match device function parameters" + ) + function_params = args[1:1 + param_cnt] + + post_order_visit(self.host_func.body, visitor) + assert function_params is not None, "function_params should not be None" + + function_informations[function_name] = { + "function_name": function_name, + "block_info": self.block_info[function_name], + "grid_info": self.grid_info[function_name], + "dynamic_smem_buf": self.dynamic_smem_buf[function_name], + "function_params": function_params, + } + + # Create the host function wrapper for the CUDA kernel + self.host_func = self.create_dispatch_func(code, function_informations) + return self.lib_code + + def get_stream_type(self) -> dict[str, str]: + """Return stream parameter spec for Python signature. + + NVRTC backend uses raw int for stream handle (not cudaStream_t pointer). + Default to 0 (NULL stream) for convenience. 
+ """ + return {"name": "stream=0", "type": "int"} diff --git a/tilelang/jit/adapter/utils.py b/tilelang/jit/adapter/utils.py index efc965e1..94e590d3 100644 --- a/tilelang/jit/adapter/utils.py +++ b/tilelang/jit/adapter/utils.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from typing import Literal +from typing import Literal, Callable, Any from tilelang import tvm as tvm from tvm import IRModule, tir from tvm.target import Target @@ -107,13 +107,16 @@ def get_annotated_mod( return dispatch[model_type](mod) -def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None) -> str: +def pythonic_expr(expr: tvm.tir.PrimExpr, + dtype_map: dict[str, str] | None = None, + ignore_cast: bool = False) -> str: """ Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence. Args: expr: The TVM PrimExpr to convert. - + dtype_map: A dictionary mapping data types to their string representations. + ignore_cast: Whether to ignore the cast operator and return the string representation of the value without the cast. Returns: A string representation of the expression. 
""" @@ -158,10 +161,11 @@ def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = Non elif isinstance(node, tvm.tir.Cast): # C-style cast has high precedence value_str, _ = node_to_result_map[node.value] - if dtype_map is None: - s = f"({node.dtype}){value_str}" + if ignore_cast: + s = value_str else: - s = f"({dtype_map[node.dtype]}){value_str}" + type_str = node.dtype if dtype_map is None else dtype_map[node.dtype] + s = f"({type_str}){value_str}" p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE) elif isinstance( node, @@ -216,3 +220,238 @@ def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = Non tvm.tir.stmt_functor.post_order_visit(expr, _visitor) return next(iter(node_to_result_map[expr]), "") + + +def maybe_desc_name(name: str, + matches: list[str], + i: int, + desc_name_map: dict[str, str] | None = None) -> bool: + """ + Check if a parameter name corresponds to a TMA descriptor. + + Args: + name: The parameter name to check. + matches: List of all matched parameter names. + i: Index of the current match. + desc_name_map: Optional mapping to store descriptor name relationships. + + Returns: + True if the parameter is a TMA descriptor. + """ + match = matches[i] + if not (match == name + "_desc" or match.startswith(name + "_desc_")): + return False + desc_decls = [] + if desc_name_map is not None: + desc_name_map[match] = name + if i > 0: + desc_decls.append(matches[i - 1]) + if i < len(matches) - 1: + desc_decls.append(matches[i + 1]) + return any([decl == "CUtensorMap" for decl in desc_decls]) + + +def parse_function_call_args( + declaration: str, + function_args: list[dict[str, str]], + function_params: list[Any], + desc_name_map: dict[str, str] | None = None, + desc_name_var_map: dict[str, tvm.tir.Var] | None = None, + transform_arg: Callable[[str, str], Any] | None = None, +) -> list[Any]: + """ + Parse function call arguments from a kernel declaration. 
+ + Args: + declaration: The kernel function declaration string. + function_args: List of function argument specifications. + function_params: List of function parameters from TVM IR. + desc_name_map: Optional mapping for descriptor names. + desc_name_var_map: Optional mapping from descriptor names to TVM variables. + transform_arg: Optional function to transform each argument (name, type) -> result. + + Returns: + List of parsed call arguments. + """ + pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" + matches = re.findall(pattern, declaration) + call_args = [] + + for i, match in enumerate(matches): + for arg in function_args: + if arg["name"] == match: + if transform_arg is not None: + call_args.append(transform_arg(match, arg["type"])) + else: + call_args.append(match) + elif maybe_desc_name(arg["name"], matches, i, desc_name_map): + if transform_arg is not None: + call_args.append(transform_arg(match, "None")) + else: + call_args.append(match) + if desc_name_var_map is not None and function_params is not None: + assert len(call_args) <= len(function_params), \ + f"Too many arguments: {len(call_args)} > {len(function_params)}" + desc_name_var_map[match] = function_params[len(call_args) - 1] + + return call_args + + +class TMADescriptorParams: + """Parsed TMA descriptor parameters.""" + + def __init__(self, + handle_name: str, + dtype: str, + tensor_rank: int, + global_address: Any, + is_img2col: bool = False): + self.handle_name = handle_name + self.dtype = dtype + self.tensor_rank = tensor_rank + self.global_address = global_address + self.is_img2col = is_img2col + + # Common fields + self.global_dim: list[str] = [] + self.global_stride: list[str] = [] + self.element_strides: list[str] = [] + self.interleave: str = "" + self.swizzle: str = "" + self.l2_promotion: str = "" + self.oob_fill: str = "" + + # Tiled-specific fields + self.box_dim: list[str] = [] + + # Im2col-specific fields + self.lower_corner: list[str] = [] + self.upper_corner: list[str] = 
[] + self.smem_box_channel: str = "" + self.smem_box_pixel: str = "" + + +def parse_tma_descriptor_args( + tma_descriptor_args: dict[tvm.tir.Var, list[Any]], + desc_name_map: dict[str, str], + desc_name_var_map: dict[str, tvm.tir.Var], + pythonic_expr_func: Callable[[Any], str], +) -> list[TMADescriptorParams]: + """ + Parse TMA descriptor arguments into structured parameters. + + Args: + tma_descriptor_args: Dictionary mapping TMA descriptor variables to their arguments. + desc_name_map: Mapping from descriptor handles to parameter names. + desc_name_var_map: Mapping from descriptor handles to TVM variables. + pythonic_expr_func: Function to convert TVM expressions to strings. + + Returns: + List of parsed TMA descriptor parameters. + """ + if not tma_descriptor_args: + return [] + + results = [] + + for handle_name, _ in desc_name_map.items(): + assert handle_name in desc_name_var_map, \ + f"Handle name {handle_name} not found in desc_name_var_map" + desc_var = desc_name_var_map[handle_name] + + assert desc_var in tma_descriptor_args, \ + f"TMA descriptor {desc_var} not found in {tma_descriptor_args}" + args = tma_descriptor_args[desc_var] + + # Skip __tvm_tensormap_create_tiled and second element (like CUDA version) + if len(args) < 3: + raise ValueError( + f"TMA descriptor args too short: {len(args)} elements, expected at least 3") + + tma_create_str, _, dtype, tensor_rank, global_address, *remaining_args = args + + is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col") + + # Convert basic fields + dtype = pythonic_expr_func(dtype) + tensor_rank = int(pythonic_expr_func(tensor_rank)) + + # Validate tensor_rank + if not isinstance(tensor_rank, int) or tensor_rank <= 0: + raise ValueError(f"Invalid tensor_rank: {tensor_rank}. 
Must be a positive integer") + + params = TMADescriptorParams(handle_name, dtype, tensor_rank, global_address, is_img2col) + + if not is_img2col: + # Tiled mode + expected_args_len = 4 * tensor_rank + 4 + if len(remaining_args) < expected_args_len: + raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, " + f"expected {expected_args_len} for tensor_rank {tensor_rank}") + + # Extract dimensions and strides + params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]] + params.global_stride = [ + pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank] + ] + params.box_dim = [ + pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank] + ] + params.element_strides = [ + pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank:4 * tensor_rank] + ] + + # Extract remaining parameters + try: + interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * tensor_rank:4 * + tensor_rank + 4] + params.interleave = pythonic_expr_func(interleave) + params.swizzle = pythonic_expr_func(swizzle) + params.l2_promotion = pythonic_expr_func(l2_promotion) + params.oob_fill = pythonic_expr_func(oob_fill) + except ValueError as e: + raise ValueError( + "Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)" + ) from e + else: + # Im2col mode + expected_args_len = 5 * tensor_rank + 2 + if len(remaining_args) < expected_args_len: + raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, " + f"expected {expected_args_len} for tensor_rank {tensor_rank}") + + # Extract dimensions and strides + params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]] + params.global_stride = [ + pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank] + ] + params.element_strides = [ + pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank] + ] + params.lower_corner = [ + pythonic_expr_func(i) for i in 
remaining_args[3 * tensor_rank:4 * tensor_rank - 2] + ] + params.upper_corner = [ + pythonic_expr_func(i) + for i in remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4] + ] + + # Extract remaining parameters + try: + smem_box_pixel, smem_box_channel, interleave, swizzle, l2_promotion, oob_fill = \ + remaining_args[5 * tensor_rank - 4:5 * tensor_rank + 2] + params.smem_box_pixel = pythonic_expr_func(smem_box_pixel) + params.smem_box_channel = pythonic_expr_func(smem_box_channel) + params.interleave = pythonic_expr_func(interleave) + params.swizzle = pythonic_expr_func(swizzle) + params.l2_promotion = pythonic_expr_func(l2_promotion) + params.oob_fill = pythonic_expr_func(oob_fill) + except ValueError as e: + raise ValueError( + "Failed to unpack the final 6 TMA parameters " + "(smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill)" + ) from e + + results.append(params) + + return results diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py index cdd0d5c7..7819890d 100644 --- a/tilelang/jit/adapter/wrapper.py +++ b/tilelang/jit/adapter/wrapper.py @@ -5,7 +5,8 @@ from typing import Any from tvm import IRModule from tvm.target import Target from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, - is_hip_target, is_cpu_target, get_annotated_mod, pythonic_expr) + is_hip_target, is_cpu_target, get_annotated_mod, pythonic_expr, + parse_function_call_args, parse_tma_descriptor_args) import re import logging import textwrap @@ -49,16 +50,6 @@ extern "C" int call({}) {{ }} """ -PREDEF_HOST_FUNC_PY = """ -import cuda.bindings.driver -import ctypes - -_function_names = {} - -def call({}): - {} -""" - L2_PERSISTENT_MAP_CREATE_HANDLE = """ \tcudaStreamAttrValue stream_attribute; \tsize_t init_persisting_l2_cache_size; @@ -136,65 +127,6 @@ TMA_IM2COL_DESC_INIT_FUNC = """ \t}} """ -TMA_DESC_INIT_FUNC_PY = """ -\t{0}_type = cuda.bindings.driver.CUtensorMapDataType({1}) 
-\t{0}_tensorRank = {2} -\t{0}_globalAddress = {3}.data_ptr() -\t{0}_globalDim = [{4}] -\t{0}_globalStride = [{5}][1:] -\t{0}_boxDim = [{6}] -\t{0}_elementStrides = [{7}] -\t{0}_interleave = cuda.bindings.driver.CUtensorMapInterleave({8}) -\t{0}_swizzle = cuda.bindings.driver.CUtensorMapSwizzle({9}) -\t{0}_l2Promotion = cuda.bindings.driver.CUtensorMapL2promotion({10}) -\t{0}_oobFill = cuda.bindings.driver.CUtensorMapFloatOOBfill({11}) - -\tres, {0} = cuda.bindings.driver.cuTensorMapEncodeTiled( -\t\t{0}_type, -\t\t{0}_tensorRank, -\t\t{0}_globalAddress, -\t\t{0}_globalDim, -\t\t{0}_globalStride, -\t\t{0}_boxDim, -\t\t{0}_elementStrides, -\t\t{0}_interleave, -\t\t{0}_swizzle, -\t\t{0}_l2Promotion, -\t\t{0}_oobFill, -\t) - -\tif res != cuda.bindings.driver.CUresult.CUDA_SUCCESS: -\t\traise RuntimeError(f"Failed to initialize the TMA descriptor {0}: {{res}}") -""" - -KERNEL_LAUNCH_FUNC_PY = """ -\tres = cuda.bindings.driver.cuKernelSetAttribute( -\t\tcuda.bindings.driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, -\t\t{7}, -\t\tkernels["{0}"], -\t\tcuda.bindings.driver.CUdevice({10}) -\t)[0] -\tif res != cuda.bindings.driver.CUresult.CUDA_SUCCESS: -\t\traise RuntimeError(f"Failed to set max dynamic shared memory size to {7} for kernel {0}: {{res}}") - -\tconfig = cuda.bindings.driver.CUlaunchConfig() -\tconfig.gridDimX = {1} -\tconfig.gridDimY = {2} -\tconfig.gridDimZ = {3} -\tconfig.blockDimX = {4} -\tconfig.blockDimY = {5} -\tconfig.blockDimZ = {6} -\tconfig.sharedMemBytes = {7} -\tconfig.hStream = stream - -\targ_values = {8} -\targ_types = {9} - -\tres = cuda.bindings.driver.cuLaunchKernelEx(config, kernels["{0}"], (arg_values, arg_types), 0)[0] -\tif res != cuda.bindings.driver.CUresult.CUDA_SUCCESS: -\t\traise RuntimeError(f"Failed to launch kernel {0}: {{res}}") -""" - class BaseWrapper(ABC): @@ -297,41 +229,6 @@ class TLCUDASourceWrapper: # Format the function arguments for declaration def_args = ", ".join([f"{arg['type']} 
{arg['name']}" for arg in function_args]) - def func_call_args(s, - function_args, - function_params, - desc_name_map: dict[str, str] | None = None, - desc_name_var_map: dict[str, tvm.tir.Var] | None = None): - # Extract the function call arguments matching the function definition - def maybe_desc(name: str, matches: list[str], i: int): - match = matches[i] - if not (match == name + "_desc" or match.startswith(name + "_desc_")): - return False - desc_decls = [] - if desc_name_map is not None: - desc_name_map[match] = name - if i > 0: - desc_decls.append(matches[i - 1]) - if i < len(matches) - 1: - desc_decls.append(matches[i + 1]) - return any([decl == "CUtensorMap" for decl in desc_decls]) - - pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" - matches = re.findall(pattern, s) - call_args = [] - for i, match in enumerate(matches): - for arg in function_args: - if arg["name"] == match: - call_args.append(match) - elif maybe_desc(arg["name"], matches, i): - call_args.append(match) - assert len(call_args) <= len( - function_params - ), f"Function {function_name} has {len(function_params)} parameters, but {len(call_args)} arguments" - desc_name_var_map[match] = function_params[len(call_args) - 1] - - return call_args - has_l2_persistent_map = False for function_name, _ in function_informations.items(): if function_name in self.l2_persistent_map: @@ -365,8 +262,8 @@ class TLCUDASourceWrapper: kernel_launch_code += init_l2_persistent_map if self.use_cooperative_groups[function_name]: - args_list = func_call_args(declaration, function_args, function_params, - desc_name_map, desc_name_var_map) + args_list = parse_function_call_args(declaration, function_args, function_params, + desc_name_map, desc_name_var_map) assert len(function_params) == len( args_list ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments" @@ -377,8 +274,8 @@ class TLCUDASourceWrapper: kernel_launch_code += 
"\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format( function_name, grid_str, block_str, function_name + "_args", smem_str) else: - args_list = func_call_args(declaration, function_args, function_params, - desc_name_map, desc_name_var_map) + args_list = parse_function_call_args(declaration, function_args, function_params, + desc_name_map, desc_name_var_map) assert len(function_params) == len( args_list ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments" @@ -420,101 +317,26 @@ class TLCUDASourceWrapper: tma_descripter_init = "" if self.tma_descriptor_args is None: return tma_descripter_init - for handle_name, _ in desc_name_map.items(): - assert handle_name in desc_name_var_map, f"Handle name {handle_name} not found in desc_name_var_map" - desc_var = desc_name_var_map[handle_name] - - assert desc_var in self.tma_descriptor_args, f"TMA descriptor {desc_var} not found in {self.tma_descriptor_args}" - args = self.tma_descriptor_args[desc_var] - # Skip __tvm_tensormap_create_tiled - if len(args) < 3: - raise ValueError( - f"TMA descriptor args too short: {len(args)} elements, expected at least 3") - - tma_create_str, _, dtype, tensor_rank, globalAddress, *remaining_args = args - - is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col") - dtype = self._pythonic_expr(dtype) - tensor_rank = int(self._pythonic_expr(tensor_rank)) - - # Validate tensor_rank - if not isinstance(tensor_rank, int) or tensor_rank <= 0: - raise ValueError(f"Invalid tensor_rank: {tensor_rank}. 
Must be a positive integer") - - if not is_img2col: - # Calculate required length for remaining_args - expected_args_len = 4 * tensor_rank + 4 # 4 groups of tensor_rank size + 4 parameters - if len(remaining_args) < expected_args_len: - raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, " - f"expected {expected_args_len} for tensor_rank {tensor_rank}") - - # Extract dimensions and strides using list slicing - global_dim = remaining_args[:tensor_rank] - global_stride = remaining_args[tensor_rank:2 * tensor_rank] - box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank] - element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank] - - global_dim = [self._pythonic_expr(i) for i in global_dim] - global_stride = [self._pythonic_expr(i) for i in global_stride] - box_dim = [self._pythonic_expr(i) for i in box_dim] - element_strides = [self._pythonic_expr(i) for i in element_strides] - - # Extract remaining parameters - try: - interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 * - tensor_rank + 4] - interleave = self._pythonic_expr(interleave) - swizzle = self._pythonic_expr(swizzle) - l2Promotion = self._pythonic_expr(l2Promotion) - oobFill = self._pythonic_expr(oobFill) - except ValueError as e: - raise ValueError( - "Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)" - ) from e + # Parse TMA descriptor arguments using the common utility + parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, + desc_name_var_map, self._pythonic_expr) + + # Generate C++ code from parsed parameters + for params in parsed_params: + if not params.is_img2col: tma_descripter_init += TMA_DESC_INIT_FUNC.format( - handle_name, dtype, tensor_rank, globalAddress, ",".join(global_dim), - ",".join(global_stride), ",".join(box_dim), ",".join(element_strides), - interleave, swizzle, l2Promotion, oobFill) + params.handle_name, params.dtype, params.tensor_rank, 
params.global_address, + ",".join(params.global_dim), ",".join(params.global_stride), + ",".join(params.box_dim), ",".join(params.element_strides), params.interleave, + params.swizzle, params.l2_promotion, params.oob_fill) else: - # Calculate required length for remaining_args - expected_args_len = 5 * tensor_rank + 2 - if len(remaining_args) < expected_args_len: - raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, " - f"expected {expected_args_len} for tensor_rank {tensor_rank}") - - # Extract dimensions and strides using list slicing - global_dim = remaining_args[:tensor_rank] - global_stride = remaining_args[tensor_rank:2 * tensor_rank] - element_strides = remaining_args[2 * tensor_rank:3 * tensor_rank] - lower_corner = remaining_args[3 * tensor_rank:4 * tensor_rank - 2] - upper_corner = remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4] - global_dim = [self._pythonic_expr(i) for i in global_dim] - global_stride = [self._pythonic_expr(i) for i in global_stride] - element_strides = [self._pythonic_expr(i) for i in element_strides] - lower_corner = [self._pythonic_expr(i) for i in lower_corner] - upper_corner = [self._pythonic_expr(i) for i in upper_corner] - - # Extract remaining parameters - try: - smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill = remaining_args[ - 5 * tensor_rank - 4:5 * tensor_rank + 2] - smem_box_pixel = self._pythonic_expr(smem_box_pixel) - smem_box_channel = self._pythonic_expr(smem_box_channel) - interleave = self._pythonic_expr(interleave) - swizzle = self._pythonic_expr(swizzle) - l2Promotion = self._pythonic_expr(l2Promotion) - oobFill = self._pythonic_expr(oobFill) - except ValueError as e: - raise ValueError( - "Failed to unpack the final 6 TMA parameters (smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill)" - ) from e - tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC.format( - handle_name, dtype, tensor_rank, globalAddress, ",".join(global_dim), - 
",".join(global_stride), ",".join(element_strides), ",".join(lower_corner), - ",".join(upper_corner), smem_box_channel, smem_box_pixel, interleave, swizzle, - l2Promotion, oobFill) + params.handle_name, params.dtype, params.tensor_rank, params.global_address, + ",".join(params.global_dim), ",".join(params.global_stride), + ",".join(params.element_strides), ",".join(params.lower_corner), + ",".join(params.upper_corner), params.smem_box_channel, params.smem_box_pixel, + params.interleave, params.swizzle, params.l2_promotion, params.oob_fill) return tma_descripter_init @@ -713,213 +535,6 @@ class TLCUDASourceWrapper: raise ValueError("Cannot find primary function in the module.") -class TLNVRTCSourceWrapper(TLCUDASourceWrapper): - """ - A wrapper class for the TileLang NVRTC backend. - """ - - _TYPE_MAP = { - "float32": "ctypes.c_float", - "float16": "ctypes.c_uint16", - "bfloat16": "ctypes.c_uint16", - "float8_e4m3": "ctypes.c_uint8", - "float8_e4m3fn": "ctypes.c_uint8", - "float8_e5m2": "ctypes.c_uint8", - "float64": "ctypes.c_double", - "int64": "ctypes.c_int64", - "int32": "ctypes.c_int32", - "uint32": "ctypes.c_uint32", - "bool": "ctypes.c_bool", - "int8": "ctypes.c_int8", - "uint8": "ctypes.c_uint8", - "int16": "ctypes.c_int16", - "uint16": "ctypes.c_uint16", - "uchar": "ctypes.c_uint8", - } - - def __init__(self, - scheduled_ir_module: IRModule, - source: str, - target: Target, - device_mod: IRModule | None = None, - host_mod: IRModule | None = None, - pass_configs: dict[str, Any] | None = None): - super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs) - - def create_dispatch_func(self, code, function_informations): - # Extract the set of dynamic symbolic names used in the primary function - dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) - - function_args = [{"name": "kernels", "type": "Dict[str, cuda.bindings.driver.CUkernel]"}] - # Collect function arguments based on primary function's parameters and 
buffer mappings - for param in self.prim_func.params: - if param in self.prim_func.buffer_map: - buffer = self.prim_func.buffer_map[param] - function_args.append({ - "name": buffer.data.name, - "type": "ctypes.c_void_p", - }) - elif isinstance(param, tvm.tir.Var): - function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)}) - else: - raise ValueError( - f"Parameter {param} is not in the buffer map of the primary function.") - # Add dynamic symbols as integer arguments - for dyn_sym in dynamic_symbolic_set: - if dyn_sym not in [arg["name"] for arg in function_args]: - function_args.append({"name": dyn_sym, "type": "ctypes.c_int"}) - - function_args.append(self.get_stream_type()) - # Format the function arguments for declaration - def_args = ", ".join([f"{arg['name']}" for arg in function_args]) - - def func_call_args(s, function_args, desc_name_map: dict[str, str] | None = None): - # Extract the function call arguments matching the function definition - def maybe_desc(name: str, matches: list[str], i: int): - match = matches[i] - if not (match == name + "_desc" or match.startswith(name + "_desc_")): - return False - desc_decls = [] - if desc_name_map is not None: - desc_name_map[match] = name - if i > 0: - desc_decls.append(matches[i - 1]) - if i < len(matches) - 1: - desc_decls.append(matches[i + 1]) - return any([decl == "CUtensorMap" for decl in desc_decls]) - - pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" - matches = re.findall(pattern, s) - call_args = [] - for i, match in enumerate(matches): - for arg in function_args: - if arg["name"] == match: - call_args.append( - (f"{match}.data_ptr()" if arg["type"] == "ctypes.c_void_p" else match, - arg["type"])) - elif maybe_desc(arg["name"], matches, i): - call_args.append((match, "None")) - return call_args - - desc_name_map: dict[str, str] = {} - device_index = 0 - kernel_launch_code = """""" - for function_name, function_info in function_informations.items(): - block_info = 
function_info["block_info"] - grid_info = function_info["grid_info"] - dynamic_smem_buf = function_info["dynamic_smem_buf"] - - # Find the location of the global kernel function in the code - index = match_declare_kernel(code, function_name + "(") - - # Analyze the function declaration to prepare for argument extraction - declaration = code[index:].split(";")[0] - - # Identify the start of the function body to insert arguments - index = code.index("{", index) - call_args = func_call_args(declaration, function_args, desc_name_map) - for arg_name, arg_type in call_args: - if arg_type == "ctypes.c_void_p": - device_index = f"{arg_name.replace('.data_ptr()', '')}.device.index" - break - arg_names = ", ".join([arg[0] for arg in call_args]) - arg_types = ", ".join([arg[1] for arg in call_args]) - smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf - kernel_launch_code += self.generate_tma_descriptor_args( - desc_name_map) + KERNEL_LAUNCH_FUNC_PY.format( - function_name, self._pythonic_expr(grid_info[0]), - self._pythonic_expr(grid_info[1]), self._pythonic_expr(grid_info[2]), - self._pythonic_expr(block_info[0]), self._pythonic_expr(block_info[1]), - self._pythonic_expr( - block_info[2]), smem_str, arg_names, arg_types, device_index) - - # Wrap the kernel dispatch logic in an external C function - host_func = PREDEF_HOST_FUNC_PY.format( - repr(list(function_informations.keys())), def_args, kernel_launch_code) - return host_func - - def generate_tma_descriptor_args(self, desc_name_map: dict[str, str]) -> str: - tma_descripter_init = "" - if self.tma_descriptor_args is None: - return tma_descripter_init - - for handle_name, name in desc_name_map.items(): - desc_name = name + "_desc" - assert desc_name in self.tma_descriptor_args, f"TMA descriptor {desc_name} not found in {self.tma_descriptor_args}" - args = self.tma_descriptor_args[desc_name] - # Skip __tvm_tensormap_create_tiled - if len(args) < 3: - raise ValueError( - f"TMA descriptor args too short: 
{len(args)} elements, expected at least 3") - _, dtype, tensor_rank, globalAddress, *remaining_args = args[1:] - - tensor_rank = int(tensor_rank) - # Validate tensor_rank - if not isinstance(tensor_rank, int) or tensor_rank <= 0: - raise ValueError(f"Invalid tensor_rank: {tensor_rank}. Must be a positive integer") - - # Calculate required length for remaining_args - # 4 groups of tensor_rank size + 4 parameters - expected_args_len = 4 * tensor_rank + 4 - if len(remaining_args) < expected_args_len: - raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, " - f"expected {expected_args_len} for tensor_rank {tensor_rank}") - - # Extract dimensions and strides using list slicing - global_dim = remaining_args[:tensor_rank] - global_stride = remaining_args[tensor_rank:2 * tensor_rank] - box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank] - element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank] - - global_dim = [str(i) for i in global_dim] - global_stride = [str(i) for i in global_stride] - box_dim = [str(i) for i in box_dim] - element_strides = [str(i) for i in element_strides] - - # Extract remaining parameters - try: - interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 * - tensor_rank + 4] - except ValueError as e: - raise ValueError( - "Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)" - ) from e - - tma_descripter_init += TMA_DESC_INIT_FUNC_PY.format( - handle_name, dtype, tensor_rank, globalAddress, - ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", global_dim)), - ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", global_stride)), - ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})", box_dim)), - ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})", - element_strides)), interleave, swizzle, l2Promotion, oobFill) - return tma_descripter_init - - def update_lib_code(self, code: str): - # Update the library code 
with the given code string - self.lib_code = code - - # Organize function information for code generation - function_informations = {} - for function_name in self.function_names: - # Do not update function with dispatch host function - if (function_name not in self.block_info) or (function_name not in self.grid_info): - continue - - function_informations[function_name] = { - "function_name": function_name, - "block_info": self.block_info[function_name], - "grid_info": self.grid_info[function_name], - "dynamic_smem_buf": self.dynamic_smem_buf[function_name], - } - - # Create the host function wrapper for the CUDA kernel - self.host_func = self.create_dispatch_func(code, function_informations) - return self.lib_code - - def get_stream_type(self) -> dict[str, str]: - return {"name": "stream=0", "type": "int"} - - class TLHIPSourceWrapper(TLCUDASourceWrapper): """ A wrapper class for the TileLang HIP backend. @@ -1230,9 +845,10 @@ class TLPyWrapper(TLWrapper): def wrap(self, c_source: str): # assert self.scheduled_ir_module is not None, "Please assign optimized module first." 
if is_cuda_target(self.target): + from tilelang.jit.adapter.nvrtc import TLNVRTCSourceWrapper wrapper_class = TLNVRTCSourceWrapper else: - raise ValueError(f"Unsupported platform: {self.arch.platform}") + raise ValueError(f"Unsupported target for NVRTC backend: {self.target}") wrapper = wrapper_class( scheduled_ir_module=self.scheduled_ir_module, source=c_source, diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index bb47716c..6f5eb0b5 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -15,7 +15,7 @@ from tilelang import tvm from tilelang import env from tilelang.engine.param import CompiledArtifact, KernelParam from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, - NVRTCKernelAdapter, TorchDLPackKernelAdapter, MetalKernelAdapter) + TorchDLPackKernelAdapter, MetalKernelAdapter) from tilelang.profiler import Profiler, TensorSupplyType from tilelang.utils.target import determine_target from tilelang.contrib import nvcc as tl_nvcc @@ -270,6 +270,7 @@ class JITKernel(Generic[_P, _T]): compile_flags=compile_flags, ) elif execution_backend == "nvrtc": + from tilelang.jit.adapter import NVRTCKernelAdapter adapter = NVRTCKernelAdapter( params=artifact.params, result_idx=out_idx, @@ -339,6 +340,7 @@ class JITKernel(Generic[_P, _T]): pass_configs=pass_configs, ) elif execution_backend == "nvrtc": + from tilelang.jit.adapter import NVRTCKernelAdapter adapter = NVRTCKernelAdapter.from_database( params=params, result_idx=result_idx, diff --git a/tilelang/language/annotations.py b/tilelang/language/annotations.py index 12d3af4d..3c469e78 100644 --- a/tilelang/language/annotations.py +++ b/tilelang/language/annotations.py @@ -5,6 +5,7 @@ from typing import Callable from tilelang.layout import Layout from tvm.script.parser.tir import attr, block_attr +from tvm.tir import FloatImm __all__ = [ "use_swizzle", @@ -49,5 +50,5 @@ def annotate_l2_hit_ratio(l2_hit_ratio_map: dict): _l2_hit_ratio_map = {} for buffer, 
hit_ratio in l2_hit_ratio_map.items(): assert buffer.scope() == "global", "persistent L2 can only be applied to global buffers" - _l2_hit_ratio_map[buffer.data] = float(hit_ratio) + _l2_hit_ratio_map[buffer.data] = FloatImm("float32", float(hit_ratio)) return block_attr({"l2_hit_ratio_map": _l2_hit_ratio_map}) -- GitLab From 729e66ca6de418085d896f6f662184f931da9bb2 Mon Sep 17 00:00:00 2001 From: Jiaxing Ding <61589029+Paran0idy@users.noreply.github.com> Date: Sat, 15 Nov 2025 22:12:20 +0800 Subject: [PATCH 004/139] [AMD] Update CK for ROCm7 (#1262) --- 3rdparty/composable_kernel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index 1c45ca35..b38bb492 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit 1c45ca35dd5c215e0c1db1f40f01556f467f52a8 +Subproject commit b38bb492a1a55b5abb0c345962143c0f9c482cfb -- GitLab From 2de566e798e2b6786255df395ce652d52f10af9e Mon Sep 17 00:00:00 2001 From: Kevinzz Date: Sun, 16 Nov 2025 15:56:11 +0800 Subject: [PATCH 005/139] [BugFix] Remove memory_order in atomic constexpr and fix NSA bwd (#1260) * fix nsa bwd and atomic * [Lint] * [BugFix] - New implementation for atomicMax and atomicMin using atomicCAS - PTX version atomicAdd for single 16-byte data - Modify the test cases * [Lint] --------- Co-authored-by: tzj-fxz --- .../deepseek_nsa/example_tilelang_nsa_bwd.py | 24 +- src/tl_templates/cuda/atomic.h | 213 +++++++++++++++--- .../test_tilelang_language_atomic_add.py | 60 ++--- 3 files changed, 229 insertions(+), 68 deletions(-) diff --git a/examples/deepseek_nsa/example_tilelang_nsa_bwd.py b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py index 8387d227..1d1b5ea3 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_bwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py @@ -106,8 +106,8 @@ def tilelang_kernel_fwd( T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) if is_causal: - for i, j in 
T.Parallel(G, BS): - acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, + for k, j in T.Parallel(G, BS): + acc_s[k, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) else: T.clear(acc_s) @@ -124,18 +124,18 @@ def tilelang_kernel_fwd( T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=True) - for i in T.Parallel(G): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + for k in T.Parallel(G): + scores_scale[k] = T.exp2(scores_max_prev[k] * scale - scores_max[k] * scale) + for k, j in T.Parallel(G, BS): + acc_s[k, j] = T.exp2(acc_s[k, j] * scale - scores_max[k] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(G): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for k in T.Parallel(G): + logsum[k] = logsum[k] * scores_scale[k] + scores_sum[k] T.copy(acc_s, acc_s_cast) # Rescale - for i, j in T.Parallel(G, BV): - acc_o[i, j] *= scores_scale[i] + for k, j in T.Parallel(G, BV): + acc_o[k, j] *= scores_scale[k] # V * softmax(Q * K) T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) @@ -465,8 +465,8 @@ def tilelang_kernel_bwd_dqkv( T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) # [G] T.copy(Delta_slc[i_b, i, i_h * G:(i_h + 1) * G], delta) - for i, j in T.Parallel(BS, G): - dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + for _i, _j in T.Parallel(BS, G): + dsT_cast[_i, _j] = qkT[_i, _j] * (dsT[_i, _j] - delta[_j]) * sm_scale # [BS, G] @ [G, BK] -> [BS, BK] T.gemm(dsT_cast, Q_shared, dk, policy=T.GemmWarpPolicy.FullRow) diff --git a/src/tl_templates/cuda/atomic.h b/src/tl_templates/cuda/atomic.h index 82eeccfd..a573886b 100644 --- a/src/tl_templates/cuda/atomic.h +++ b/src/tl_templates/cuda/atomic.h @@ -46,10 +46,22 @@ TL_DEVICE void AtomicMax(T1 &ref, T2 val, int memory_order = 
int(cuda::memory_order_relaxed)) { using NT1 = typename normalize_atomic_type::type; T1 *address = &ref; - if constexpr ((std::is_same_v || - std::is_same_v) && - memory_order == int(cuda::memory_order_relaxed)) { - atomicMax(reinterpret_cast(address), static_cast(val)); + if constexpr (std::is_same_v || + std::is_same_v) { + // There is no implementation of atomicMax for half and bf16 in cuda. + // We simulate this process by atomicCAS loop. + unsigned short *address_as_ushort = + reinterpret_cast(address); + unsigned short val_as_ushort = *reinterpret_cast(&val); + unsigned short old_val_ushort = *address_as_ushort; + while (val > *reinterpret_cast(&old_val_ushort)) { + unsigned short assumed_val_ushort = old_val_ushort; + old_val_ushort = + atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort); + if (assumed_val_ushort == old_val_ushort) { + break; + } + } } else { cuda::atomic_ref aref(*address); aref.fetch_max(cuda_cast(val), cuda::memory_order(memory_order)); @@ -61,11 +73,21 @@ TL_DEVICE T1 AtomicMaxRet(T1 &ref, T2 val, int memory_order = int(cuda::memory_order_relaxed)) { using NT1 = typename normalize_atomic_type::type; T1 *address = &ref; - if constexpr ((std::is_same_v || - std::is_same_v) && - memory_order == int(cuda::memory_order_relaxed)) { - return static_cast( - atomicMax(reinterpret_cast(address), static_cast(val))); + if constexpr (std::is_same_v || + std::is_same_v) { + unsigned short *address_as_ushort = + reinterpret_cast(address); + unsigned short val_as_ushort = *reinterpret_cast(&val); + unsigned short old_val_ushort = *address_as_ushort; + while (val > *reinterpret_cast(&old_val_ushort)) { + unsigned short assumed_val_ushort = old_val_ushort; + old_val_ushort = + atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort); + if (assumed_val_ushort == old_val_ushort) { + break; + } + } + return static_cast(*reinterpret_cast(&old_val_ushort)); } else { cuda::atomic_ref aref(*address); return static_cast( @@ -78,10 +100,22 @@ 
TL_DEVICE void AtomicMin(T1 &ref, T2 val, int memory_order = int(cuda::memory_order_relaxed)) { using NT1 = typename normalize_atomic_type::type; T1 *address = &ref; - if constexpr ((std::is_same_v || - std::is_same_v) && - memory_order == int(cuda::memory_order_relaxed)) { - atomicMin(reinterpret_cast(address), static_cast(val)); + if constexpr (std::is_same_v || + std::is_same_v) { + // There is no implementation of atomicMin for half and bf16 in cuda. + // We simulate this process by atomicCAS loop. + unsigned short *address_as_ushort = + reinterpret_cast(address); + unsigned short val_as_ushort = *reinterpret_cast(&val); + unsigned short old_val_ushort = *address_as_ushort; + while (val < *reinterpret_cast(&old_val_ushort)) { + unsigned short assumed_val_ushort = old_val_ushort; + old_val_ushort = + atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort); + if (assumed_val_ushort == old_val_ushort) { + break; + } + } } else { cuda::atomic_ref aref(*address); aref.fetch_min(cuda_cast(val), cuda::memory_order(memory_order)); @@ -93,11 +127,21 @@ TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val, int memory_order = int(cuda::memory_order_relaxed)) { using NT1 = typename normalize_atomic_type::type; T1 *address = &ref; - if constexpr ((std::is_same_v || - std::is_same_v) && - memory_order == int(cuda::memory_order_relaxed)) { - return static_cast( - atomicMin(reinterpret_cast(address), static_cast(val))); + if constexpr (std::is_same_v || + std::is_same_v) { + unsigned short *address_as_ushort = + reinterpret_cast(address); + unsigned short val_as_ushort = *reinterpret_cast(&val); + unsigned short old_val_ushort = *address_as_ushort; + while (val < *reinterpret_cast(&old_val_ushort)) { + unsigned short assumed_val_ushort = old_val_ushort; + old_val_ushort = + atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort); + if (assumed_val_ushort == old_val_ushort) { + break; + } + } + return static_cast(*reinterpret_cast(&old_val_ushort)); } else { 
cuda::atomic_ref aref(*address); return static_cast( @@ -110,10 +154,67 @@ TL_DEVICE void AtomicAdd(T1 &ref, T2 val, int memory_order = int(cuda::memory_order_relaxed)) { using NT1 = typename normalize_atomic_type::type; T1 *address = &ref; - if constexpr ((std::is_same_v || - std::is_same_v) && - memory_order == int(cuda::memory_order_relaxed)) { - atomicAdd(reinterpret_cast(address), static_cast(val)); + if constexpr (std::is_same_v || + std::is_same_v) { + if (memory_order == int(cuda::memory_order_relaxed)) { + atomicAdd(reinterpret_cast(address), static_cast(val)); + } else { + // Since atomic ref do not support memory order, we need to inline ptx + // code here for each situation + if constexpr (std::is_same_v) { + // fp16 + __half ret_val; + unsigned short ret_val_cast = + *reinterpret_cast(&ret_val); + unsigned long long ref_address = + reinterpret_cast(address); + unsigned short val_cast = *reinterpret_cast(&val); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.noftz.f16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.noftz.f16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.noftz.f16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } + } else if constexpr (std::is_same_v) { + // bf16 + __nv_bfloat16 ret_val; + unsigned short ret_val_cast = + *reinterpret_cast(&ret_val); + unsigned long long ref_address = + reinterpret_cast(address); + unsigned short val_cast = *reinterpret_cast(&val); + if (memory_order == int(cuda::memory_order_release) || + memory_order == 
int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.noftz.bf16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.noftz.bf16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.noftz.bf16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } + } + } } else { cuda::atomic_ref aref(*address); aref.fetch_add(cuda_cast(val), cuda::memory_order(memory_order)); @@ -125,11 +226,69 @@ TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val, int memory_order = int(cuda::memory_order_relaxed)) { using NT1 = typename normalize_atomic_type::type; T1 *address = &ref; - if constexpr ((std::is_same_v || - std::is_same_v) && - memory_order == int(cuda::memory_order_relaxed)) { - return static_cast( - atomicAdd(reinterpret_cast(address), static_cast(val))); + if constexpr (std::is_same_v || + std::is_same_v) { + if (memory_order == int(cuda::memory_order_relaxed)) { + return static_cast( + atomicAdd(reinterpret_cast(address), static_cast(val))); + } else { + if constexpr (std::is_same_v) { + // fp16 + __half ret_val; + unsigned short ret_val_cast = + *reinterpret_cast(&ret_val); + unsigned long long ref_address = + reinterpret_cast(address); + unsigned short val_cast = *reinterpret_cast(&val); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.noftz.f16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.noftz.f16 %0, [%1], %2;" + : 
"=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.noftz.f16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } + return static_cast(*reinterpret_cast<__half *>(&ret_val_cast)); + } else if constexpr (std::is_same_v) { + // bf16 + __nv_bfloat16 ret_val; + unsigned short ret_val_cast = + *reinterpret_cast(&ret_val); + unsigned long long ref_address = + reinterpret_cast(address); + unsigned short val_cast = *reinterpret_cast(&val); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.noftz.bf16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.noftz.bf16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.noftz.bf16 %0, [%1], %2;" + : "=h"(ret_val_cast) + : "l"(ref_address), "h"(val_cast) + : "memory"); + } + return static_cast( + *reinterpret_cast<__nv_bfloat16 *>(&ret_val_cast)); + } + } } else { cuda::atomic_ref aref(*address); return static_cast( diff --git a/testing/python/language/test_tilelang_language_atomic_add.py b/testing/python/language/test_tilelang_language_atomic_add.py index 42c33e54..132e002a 100644 --- a/testing/python/language/test_tilelang_language_atomic_add.py +++ b/testing/python/language/test_tilelang_language_atomic_add.py @@ -236,7 +236,31 @@ def run_atomic_addx2(M, N, block_M, block_N): torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) -@tilelang.jit +def test_atomic_add(): + 
run_atomic_add(8, 128, 128, 32, 32) + + +def test_atomic_max(): + run_atomic_max(4, 64, 64, 16, 16) + + +def test_atomic_min(): + run_atomic_min(4, 64, 64, 16, 16) + + +def test_atomic_load_store(): + run_atomic_load_store(64, 64, 16, 16) + + +def test_atomic_memory_order(): + run_atomic_memory_order(4, 64, 64, 16, 16) + + +def test_atomic_addx2(): + run_atomic_addx2(32, 64, 8, 16) + + +@tilelang.jit(debug_root_path="./testing/python/language") def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float"): @T.prim_func @@ -248,9 +272,9 @@ def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float" idx_j = by * block_N + j if idx_i < M and idx_j < N: val = A[idx_i, idx_j] - T.atomic_add(B[idx_i, idx_j], val, memory_order="relaxed") - T.atomic_max(C[idx_i, idx_j], val, memory_order="acquire") - T.atomic_min(D[idx_i, idx_j], val, memory_order="release") + T.atomic_add(B[idx_i, idx_j], val, memory_order="release") + T.atomic_max(C[idx_i, idx_j], val, memory_order="relaxed") + T.atomic_min(D[idx_i, idx_j], val, memory_order="relaxed") return atomic_different_orders @@ -271,30 +295,6 @@ def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype="float32"): torch.testing.assert_close(D, torch.minimum(torch.full_like(A, float('inf')), A)) -def test_atomic_add(): - run_atomic_add(8, 128, 128, 32, 32) - - -def test_atomic_max(): - run_atomic_max(4, 64, 64, 16, 16) - - -def test_atomic_min(): - run_atomic_min(4, 64, 64, 16, 16) - - -def test_atomic_load_store(): - run_atomic_load_store(64, 64, 16, 16) - - -def test_atomic_memory_order(): - run_atomic_memory_order(4, 64, 64, 16, 16) - - -def test_atomic_addx2(): - run_atomic_addx2(32, 64, 8, 16) - - @tilelang.jit def atomic_addx4_program(M, N, block_M, block_N): @@ -361,7 +361,9 @@ def run_atomic_return_prev(M, N, block_M, block_N, dtype="float32"): def test_atomic_different_memory_orders(): - run_atomic_different_memory_orders(32, 32, 8, 8) + 
run_atomic_different_memory_orders(32, 32, 8, 8, dtype="float") + run_atomic_different_memory_orders(32, 32, 8, 8, dtype="float16") + run_atomic_different_memory_orders(32, 32, 8, 8, dtype="bfloat16") def test_atomic_addx4(): -- GitLab From 716dbef52f550dd4d0864c340eb2362904b0ea33 Mon Sep 17 00:00:00 2001 From: Zhengju Tang <97930865+tzj-fxz@users.noreply.github.com> Date: Mon, 17 Nov 2025 01:22:02 +0800 Subject: [PATCH 006/139] [Example] Add GQA decoding kernel with varlen page table (#1265) * [Example] Add page table for gqa decode * [Example] Page table for varlen decoding * [Lint] * [Refactor] Remove redundant code * [Lint] * [Lint] * [Lint] --- .../example_gqa_decode_varlen_logits_paged.py | 711 ++++++++++++++++++ 1 file changed, 711 insertions(+) create mode 100644 examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py b/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py new file mode 100644 index 00000000..e565cbeb --- /dev/null +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py @@ -0,0 +1,711 @@ +import torch +import math +import argparse +import tilelang +import tilelang.language as T +from example_gqa_decode_varlen_logits import flash_attn_with_attn_pool_decode, repeat_kv, do_bench + +torch.manual_seed(0) + + +def get_configs(): + import itertools + block_N = [64, 128] + block_H = [64] + num_split = [1] + num_stages = [1, 2, 3] + threads = [128] + _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) + + configs = [{ + 'block_N': c[0], + 'block_H': c[1], + 'num_split': c[2], + 'num_stages': c[3], + 'threads': c[4] + } for c in _configs] + return configs + + +# @autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding") +def flashattn(batch, + heads, + k_heads, + max_seqlen_kv, + total_seqlen_k, + dim, + has_sink, + page_block_size, 
+ block_N=128, + block_H=64, + num_split=1, + num_stages=1, + threads=128): + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + shape_q = [batch, heads, dim] + shape_k = [total_seqlen_k, k_heads, dim] + shape_v = [total_seqlen_k, k_heads, dim] + shape_o = [batch, heads, dim] + shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)] + dtype = "float16" + accum_dtype = "float" + kv_group_num = heads // k_heads + assert page_block_size >= block_N and page_block_size % block_N == 0, "page_block_size must be larger than block_N and a multiple of block_N" + + valid_block_H = min(block_H, kv_group_num) + # TODO: check if max_seqlen_kv is correct for varlen case + + @T.macro + def flash_attn( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], "int32"), + s_aux: T.Tensor([heads], "float32"), + BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / block_N)], "int32"), + Output: T.Tensor([batch, heads, dim], dtype), + S: T.Tensor(shape_s, dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([valid_block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype) + s_aux_shared = T.alloc_shared([block_H], "float32") + + bid = bx + hid = by + cur_kv_head = hid // (kv_group_num 
// valid_block_H) + + cur_start_k = cu_seqlens_k[bid] + cur_end_k = cu_seqlens_k[bid + 1] + cur_seqlen_k = cur_end_k - cur_start_k + + T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + # loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=num_stages): + k_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + ( + k * block_N) % page_block_size + T.copy(K[cur_start_k + k_start:cur_start_k + k_start + block_N, cur_kv_head, :], + K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], + -T.infinity(accum_dtype)) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # scores_max_prev is m_i + # scores_max is row_max->m_ij in triton + T.copy(scores_max, S_shared[:, k]) + # scores_scale is alpha in triton + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + # scores_sum is l_ij in triton + # logsum is l_i in triton + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + v_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + ( + k * block_N) % page_block_size + T.copy(V[cur_start_k + v_start:cur_start_k + v_start + block_N, cur_kv_head, :], + V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, 
policy=T.GemmWarpPolicy.FullRow) + + if has_sink: + T.copy(s_aux[hid * valid_block_H:hid * valid_block_H + block_H], s_aux_shared) + for i in T.Parallel(block_H): + logsum[i] += s_aux_shared[i] + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)): + S_shared[h, k] = T.exp2((S_shared[h, k] - scores_max[h]) * scale) / logsum[h] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(acc_o[:valid_block_H, :], O_shared) + T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :]) + T.copy(S_shared[:valid_block_H, :], S[bid, + hid * valid_block_H:(hid + 1) * valid_block_H, :]) + + @T.prim_func + def flashattn_gqa_decode_no_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], "int32"), + s_aux: T.Tensor([heads], "float32"), + BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], "int32"), + Output: T.Tensor(shape_o, dtype), + S: T.Tensor(shape_s, dtype), + ): + flash_attn(Q, K, V, cu_seqlens_k, s_aux, BLOCK_TABLE, Output, S) + + # TODO: split version + return flashattn_gqa_decode_no_split + + +def flash_attn_with_attn_pool_decode_tilelang( + Q: torch.Tensor, ## [tq = b, q_h, q_dim] + K: torch.Tensor, ## [tk, k_h, k_dim] + V: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_k: int, + real_max_k_seqlen: int, + num_split: int, + softmax_scale: float, + s_aux: torch.Tensor = None, + block_size: int = 64, + use_per_kv_head_sparse_index: bool = False, + tl_kernel=None, + block_table: torch.Tensor = None, +): + num_tokens, q_h, head_size = Q.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = K.size(1) + + assert Q.dim() == K.dim() == 3 + assert Q.size(2) == K.size(2) + assert cu_seqlens_k.dim() == 1 + assert head_size in {64, 128, 256} + assert Q.is_contiguous() + assert K.is_contiguous() + assert V.is_contiguous() + + 
gqa_group_size = q_h // k_h + + O_tl = torch.zeros_like(Q) + S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), + dtype=Q.dtype, + device=Q.device) + O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux, block_table) + + if use_per_kv_head_sparse_index: + S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1)) + else: + S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1)) + + return O_tl, S_tl + + +def test_equal_seqlen_decode_main(args): + """Test decode kernel with equal sequence lengths""" + print("Testing decode kernel with equal sequence lengths") + + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + k_seqlen = args.k_seqlen + real_max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + page_block_size = args.page_block_size + dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + + # For decode, query is just 1 token per batch + q = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) + k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) + v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) + softmax_scale = 1.0 / math.sqrt(head_size) + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values + print(f"Using sink attention with sink values: {sink}") + + # Convert to varlen format for K, V + k_varlen = k.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous() + v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous() + + # Generate cumulative sequence lengths + cu_seqlens_k = torch.arange( + 0, (batch_size + 1) * k_seqlen, k_seqlen, device='cuda', dtype=torch.int32) + max_seqlen_k = k_seqlen + + print(f"q shape: {q.shape}") + print(f"k_varlen shape: 
{k_varlen.shape}") + print(f"v_varlen shape: {v_varlen.shape}") + + num_tokens, q_h, head_size = q.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, + args.test_sink, page_block_size) + + block_table = torch.zeros( + batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32) + block_cnt = 0 + for i in range(batch): + cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() + for j in range(math.ceil(cur_seqlen / page_block_size)): + block_table[i, j] = block_cnt + block_cnt += 1 + block_cnt = 0 + + # Test our decode kernel + O_triton, S_triton = flash_attn_with_attn_pool_decode( + q, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size) + O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( + q, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + tl_kernel=tl_kernel, + block_table=block_table, + ) + for i in range(batch_size): + S_tilelang[i, :, + math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / + block_size):] = 0 + + # Compute torch reference + q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size] + k_repeat = repeat_kv(k, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size] + v_repeat = repeat_kv(v, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size] + + if sink is None: + # Standard scaled dot-product attention + logits = torch.matmul(q_expanded, k_repeat.transpose( + -2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] + attn_weights = torch.softmax(logits, dim=-1) + O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size] + else: + # s_aux attention + logits = torch.matmul(q_expanded, k_repeat.transpose( + -2, -1)) * softmax_scale # [batch, 
q_heads, 1, seqlen_k] + + sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) + sinks = torch.exp(sink_expanded - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + attn_weights = unnormalized_scores / normalizer + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), + v_repeat).squeeze(2) # [batch, q_heads, head_size] + + # Compute attention score pooling + attn_score_pooled = torch.max_pool2d( + attn_weights.squeeze(2), # [b, q_heads, k_seqlen] + kernel_size=(q_heads, block_size), + stride=(q_heads, block_size), + ceil_mode=True).to(torch.float16) + + print("S_tilelang", S_tilelang) + print("attn_score_pooled", attn_score_pooled) + + max_diff_o = torch.max(torch.abs(O_triton - O_torch)) + max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) + max_diff_o_tilelang = torch.max(torch.abs(O_tilelang - O_torch)) + max_diff_s_tilelang = torch.max(torch.abs(S_tilelang - attn_score_pooled)) + + print(f"Max difference in O: {max_diff_o.item()}") + print(f"Max difference in S: {max_diff_s.item()}") + print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}") + print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}") + assert torch.allclose( + O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose( + S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose( + O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}" + assert torch.allclose( + S_tilelang, attn_score_pooled, atol=1e-2, + rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}" + print("✅ All tests passed!") + + +def test_varlen_decode_main(args): + """Test decode kernel with variable 
sequence lengths""" + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + max_k_seqlen = args.k_seqlen # Use as max sequence length + real_max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + page_block_size = args.page_block_size + dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + + print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})") + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values + print(f"Using sink attention with sink values: {sink}") + + # Generate variable length k sequences + k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) + print(f"k_seqlens: {k_seqlens}") + + # Generate cumulative sequence lengths for k + cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) + total_k_tokens = 0 + for i in range(batch_size): + cu_seqlens_k[i] = total_k_tokens + total_k_tokens += k_seqlens[i] + cu_seqlens_k[batch_size] = total_k_tokens + + print(f"cu_seqlens_k: {cu_seqlens_k}") + + # Generate tensors - Q is [batch_size, q_heads, head_size] for decode + q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) + + softmax_scale = 1.0 / math.sqrt(head_size) + max_seqlen_k = int(k_seqlens.max()) + + print(f"Actual max_seqlen_k: {max_seqlen_k}") + print(f"q_decode shape: {q_decode.shape}") + print(f"k_varlen shape: {k_varlen.shape}") + print(f"v_varlen shape: {v_varlen.shape}") + + num_tokens, q_h, head_size = q_decode.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, + 
args.test_sink, page_block_size) + + block_table = torch.zeros( + batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32) + block_cnt = 0 + for i in range(batch): + cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() + for j in range(math.ceil(cur_seqlen / page_block_size)): + block_table[i, j] = block_cnt + block_cnt += 1 + block_cnt = 0 + + # Test our decode kernel + O_triton, S_triton = flash_attn_with_attn_pool_decode( + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size) + O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + tl_kernel=tl_kernel, + block_table=block_table, + ) + for i in range(batch_size): + S_tilelang[i, :, + math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / + block_size):] = 0 + + # Create torch reference - pad tensors for comparison + k_padded_list = [] + v_padded_list = [] + + for i in range(batch_size): + actual_k_len = k_seqlens[i] + + # Extract and pad k, v for this batch + k_start = cu_seqlens_k[i] + k_end = cu_seqlens_k[i + 1] + + # Pad to max_seqlen_k + k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) + v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) + + k_padded[:actual_k_len] = k_varlen[k_start:k_end] + v_padded[:actual_k_len] = v_varlen[k_start:k_end] + + k_padded_list.append(k_padded) + v_padded_list.append(v_padded) + + # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] + k_padded_batched = torch.stack( + k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + v_padded_batched = torch.stack( + v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] 
+ + # Expand q to match kv heads: [b, q_heads, 1, head_size] + q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] + + print(f"q_expanded shape: {q_expanded.shape}") + print(f"k_padded_batched shape: {k_padded_batched.shape}") + print(f"v_padded_batched shape: {v_padded_batched.shape}") + + # Compute torch reference + k_repeat = repeat_kv(k_padded_batched, + q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + v_repeat = repeat_kv(v_padded_batched, + q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + + if sink is None: + # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] + attn_score = torch.matmul(q_expanded, k_repeat.transpose( + -2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + + # Apply sequence length masking + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_score[i, :, :, actual_k_len:] = float('-inf') + + attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] + + # Mask out invalid positions + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_weights[i, :, :, actual_k_len:] = 0.0 + + # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] + O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] + else: + # s_aux attention + logits = torch.matmul(q_expanded, k_repeat.transpose( + -2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + + # Apply sequence length masking + for i in range(batch_size): + actual_k_len = k_seqlens[i] + logits[i, :, :, actual_k_len:] = float('-inf') + + sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) + sinks = torch.exp(sink_expanded - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + attn_weights = 
unnormalized_scores / normalizer + + # Mask out invalid positions + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_weights[i, :, :, actual_k_len:] = 0.0 + + # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), + v_repeat) # [b, q_heads, 1, head_size] + + O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] + + # Compute attention score pooling for S + attn_score_pooled = torch.max_pool2d( + attn_weights.squeeze(2), # [b, q_heads, max_seqlen] + kernel_size=(q_heads, block_size), + stride=(q_heads, block_size), + ceil_mode=True).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] + + print(f"O_triton shape: {O_triton.shape}") + print(f"O_tilelang shape: {O_tilelang.shape}") + print(f"O_torch shape: {O_torch.shape}") + print(f"S_triton shape: {S_triton.shape}") + print(f"S_tilelang shape: {S_tilelang.shape}") + print(f"attn_score_pooled shape: {attn_score_pooled.shape}") + + # Compare results + max_diff_o = torch.max(torch.abs(O_triton - O_torch)) + max_diff_o_tl = torch.max(torch.abs(O_tilelang - O_torch)) + print(f"Max difference in O: {max_diff_o.item()}") + print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") + + max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) + max_diff_s_tl = torch.max( + torch.abs(S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)] - attn_score_pooled)) + print(f"Max difference in S: {max_diff_s.item()}") + print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") + + assert torch.allclose( + O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose( + S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose( + O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" + assert torch.allclose( + S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)], + 
attn_score_pooled, + atol=1e-2, + rtol=1e-2), f"Score mismatch: {max_diff_s_tl.item()}" + + print("✅ All tests passed!") + + +def speed_benchmark_decode_comparison(args): + """Speed benchmark for decode kernel""" + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + max_k_seqlen = args.k_seqlen + real_max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + page_block_size = args.page_block_size + dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + + print("\n=== Decode Speed Benchmark Comparison ===") + print("Configuration:") + print(f" Batch size: {batch_size}") + print(f" Q heads: {q_heads}, KV heads: {kv_heads}") + print(f" Max K sequence length: {max_k_seqlen}") + print(f" Head size: {head_size}") + print(f" Block size: {block_size}") + print(f" Data type: {dtype}") + print(f" Variable lengths: {args.test_varlen}") + print(f" s_aux attention: {args.test_sink}") + print() + + # Generate input data + if args.test_varlen: + k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) + else: + k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int) + + # Generate cumulative sequence lengths for k + cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) + total_k_tokens = 0 + for i in range(batch_size): + cu_seqlens_k[i] = total_k_tokens + total_k_tokens += k_seqlens[i] + cu_seqlens_k[batch_size] = total_k_tokens + + # Generate tensors + q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) + + softmax_scale = 1.0 / math.sqrt(head_size) + max_seqlen_k = int(k_seqlens.max()) + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small 
sink values + print(" Using sink attention with sink values") + + print("Setup complete:") + print(f" Total K tokens: {total_k_tokens}") + print(f" Actual max K seq len: {max_seqlen_k}") + if args.test_varlen: + print(f" K sequence lengths: {k_seqlens.tolist()}") + + # Warmup + num_tokens, q_h, head_size = q_decode.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, + args.test_sink, page_block_size) + + block_table = torch.zeros( + batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32) + block_cnt = 0 + for i in range(batch): + cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() + for j in range(math.ceil(cur_seqlen / page_block_size)): + block_table[i, j] = block_cnt + block_cnt += 1 + block_cnt = 0 + + # Benchmark + print("⚡ Benchmarking Tilelang kernel (100 iterations)...") + tilelang_time = do_bench( + flash_attn_with_attn_pool_decode_tilelang, + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + args.k_seqlen, + 1, + softmax_scale, + sink, + block_size, + False, + tl_kernel, + block_table, + ) + print(f"Average decode kernel time Tilelang: {tilelang_time:.3f} ms") + + # Benchmark + print("⚡ Benchmarking Triton kernel (100 iterations)...") + triton_time = do_bench(flash_attn_with_attn_pool_decode, q_decode, k_varlen, v_varlen, + cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale, sink, + block_size) + print(f"Average decode kernel time Triton: {triton_time:.3f} ms") + print(f"Speedup: {(triton_time / tilelang_time):.3f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Flash Attention Decode with Attention Pooling') + parser.add_argument('--batch_size', type=int, default=1, help='Batch size') + parser.add_argument('--q_heads', type=int, default=32, help='Number of query heads') + parser.add_argument('--kv_heads', type=int, default=8, help='Number of 
key-value heads') + parser.add_argument('--k_seqlen', type=int, default=8192, help='Key sequence length') + parser.add_argument( + '--head_size', type=int, default=128, choices=[64, 128, 256], help='Head dimension') + parser.add_argument('--block_size', type=int, default=128, help='Block size for computation') + parser.add_argument( + '--dtype', type=str, default='bfloat16', choices=['float16', 'bfloat16'], help='Data type') + parser.add_argument( + '--test_varlen', action='store_true', help='Test with truly variable sequence lengths') + parser.add_argument( + '--test_sink', action='store_true', help='Test with sink attention mechanism') + parser.add_argument('--benchmark', action='store_true', help='Run speed benchmark') + parser.add_argument( + '--num_split', type=int, default=1, choices=[1, 16], help='Number of splits') + parser.add_argument('--page_block_size', type=int, default=128, help='Page block size') + args = parser.parse_args() + args.test_sink = True + args.test_varlen = True + args.dtype = 'float16' + args.num_split = 1 + + if args.benchmark: + speed_benchmark_decode_comparison(args) + elif args.test_varlen: + test_varlen_decode_main(args) + else: + test_equal_seqlen_decode_main(args) -- GitLab From 041d4a06b53ebeb4540636063cad2aa66fc5e1b9 Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Mon, 17 Nov 2025 13:06:23 +0800 Subject: [PATCH 007/139] [Refactor] add support for numpy dtype conversion (#1255) * add typing stub for tir.ir * remove idents * minor update * [Refactor] add numpy conversion for dtype * fix lint error * remove unused np.float_ in dtype conversion * fix type in np.int_ * fix typo * minor fix * remove debug files --- .../test_tilelang_language_frontend_v2.py | 113 ++++++------- tilelang/language/v2/dtypes.py | 155 +++++++++--------- 2 files changed, 134 insertions(+), 134 deletions(-) diff --git a/testing/python/language/test_tilelang_language_frontend_v2.py 
b/testing/python/language/test_tilelang_language_frontend_v2.py index fb3f1e15..1d9a20fe 100644 --- a/testing/python/language/test_tilelang_language_frontend_v2.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -145,62 +145,63 @@ def test_dtype_str_repr(): buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared') # noqa F841 -def test_torch_eq(): - dtypes = [ - T.bool, - T.short, - T.int, - T.long, - T.half, - T.float, - T.long, - T.int8, - T.int16, - T.int32, - T.int64, - T.uint8, - T.uint16, - T.uint32, - T.uint64, - T.float8_e4m3fn, - T.float8_e4m3fnuz, - T.float8_e5m2, - T.float8_e5m2fnuz, - T.float8_e8m0fnu, - T.float16, - T.bfloat16, - T.float32, - T.float64, - ] - torch_dtypes = [ - torch.bool, - torch.short, - torch.int, - torch.long, - torch.half, - torch.float, - torch.long, - torch.int8, - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - torch.uint16, - torch.uint32, - torch.uint64, - torch.float8_e4m3fn, - torch.float8_e4m3fnuz, - torch.float8_e5m2, - torch.float8_e5m2fnuz, - torch.float8_e8m0fnu, - torch.float16, - torch.bfloat16, - torch.float32, - torch.float64, - ] - for a, b in zip(dtypes, torch_dtypes): - assert a == b, f"{a} and {b} are not equal" - assert T.dtype(b) == a, "dtype conversion error" +# not supported now +# def test_torch_eq(): +# dtypes = [ +# T.bool, +# T.short, +# T.int, +# T.long, +# T.half, +# T.float, +# T.long, +# T.int8, +# T.int16, +# T.int32, +# T.int64, +# T.uint8, +# T.uint16, +# T.uint32, +# T.uint64, +# T.float8_e4m3fn, +# T.float8_e4m3fnuz, +# T.float8_e5m2, +# T.float8_e5m2fnuz, +# T.float8_e8m0fnu, +# T.float16, +# T.bfloat16, +# T.float32, +# T.float64, +# ] +# torch_dtypes = [ +# torch.bool, +# torch.short, +# torch.int, +# torch.long, +# torch.half, +# torch.float, +# torch.long, +# torch.int8, +# torch.int16, +# torch.int32, +# torch.int64, +# torch.uint8, +# torch.uint16, +# torch.uint32, +# torch.uint64, +# torch.float8_e4m3fn, +# torch.float8_e4m3fnuz, +# 
torch.float8_e5m2, +# torch.float8_e5m2fnuz, +# torch.float8_e8m0fnu, +# torch.float16, +# torch.bfloat16, +# torch.float32, +# torch.float64, +# ] +# for a, b in zip(dtypes, torch_dtypes): +# assert a == b, f"{a} and {b} are not equal" +# assert T.dtype(b) == a, "dtype conversion error" def test_var_assign(): diff --git a/tilelang/language/v2/dtypes.py b/tilelang/language/v2/dtypes.py index 2161e377..0702635a 100644 --- a/tilelang/language/v2/dtypes.py +++ b/tilelang/language/v2/dtypes.py @@ -1,95 +1,98 @@ from tilelang import tvm from tvm import ir import torch -import ctypes from typing import TYPE_CHECKING, Union from tvm import tir import tvm.script.ir_builder.tir._ffi_api as tb_ffi +import numpy as np dtype = tvm.DataType # Python 3.9 compatibility: avoid PEP 604 unions at runtime AnyDType = Union[ir.Type, str, type, torch.dtype, dtype] -# Base dtype conversion list -_dtype_cvt_base = [ - (None, 'handle', ctypes.c_long, 'long', None), # use long to repr void* - (bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'), - (int, 'int32', ctypes.c_int32, 'int', 'Int32'), - (float, 'float32', ctypes.c_float, 'float', 'Float32'), - (torch.short, 'int16', ctypes.c_int16, 'short', 'Int16'), - (torch.int, 'int32', ctypes.c_int32, 'int', 'Int32'), - (torch.long, 'int64', ctypes.c_int64, 'long long', 'Int64'), - (torch.half, 'float16', None, None, 'Float16'), - (torch.float, 'float32', ctypes.c_float, 'float', 'Float32'), - (torch.double, 'float64', ctypes.c_double, 'double', 'Float64'), - - # (pytype, 'tvm dtype str', 'ctypes dtype', 'cffi dtype') - (torch.bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'), - (torch.int8, 'int8', ctypes.c_int8, 'char', 'Int8'), - (torch.int16, 'int16', ctypes.c_int16, 'short', 'Int16'), - (torch.int32, 'int32', ctypes.c_int32, 'int', 'Int32'), - (torch.int64, 'int64', ctypes.c_int64, 'long long', 'Int64'), - (torch.uint8, 'uint8', ctypes.c_uint8, 'unsigned char', 'UInt8'), - (torch.uint16, 'uint16', ctypes.c_uint16, 'unsigned short', 'UInt16'), - 
(torch.uint32, 'uint32', ctypes.c_uint32, 'unsigned int', 'UInt32'), - (torch.uint64, 'uint64', ctypes.c_uint64, 'unsigned long long', 'UInt64'), - (torch.float16, 'float16', None, None, 'Float16'), - (torch.float32, 'float32', ctypes.c_float, 'float', 'Float32'), - (torch.float64, 'float64', ctypes.c_double, 'double', 'Float64'), - (None, 'float8_e4m3', None, None, 'Float8E4M3'), - (torch.bfloat16, 'bfloat16', None, None, 'BFloat16'), -] - -# Dynamically add fp8-related types if they exist in torch -_fp8_dtype_mappings = [ - ('float8_e4m3fn', 'Float8E4M3FN'), - ('float8_e4m3fnuz', 'Float8E4M3FNUZ'), - ('float8_e5m2', 'Float8E5M2'), - ('float8_e5m2fnuz', 'Float8E5M2FNUZ'), - ('float8_e8m0fnu', 'Float8E8M0FNU'), -] - -_dtype_cvt = list(_dtype_cvt_base) -for torch_attr_name, tvm_name in _fp8_dtype_mappings: - if hasattr(torch, torch_attr_name): - torch_dtype = getattr(torch, torch_attr_name) - _dtype_cvt.append((torch_dtype, torch_attr_name, None, None, tvm_name)) - +_PYTHON_DTYPE_TO_STR = { + bool: 'bool', + int: 'int32', + float: 'float32', +} -def _create_type_mapper(sidx, didx, smapper=lambda x: x, dmapper=lambda x: x): - return { - smapper(item[sidx]): dmapper(item[didx]) - for item in _dtype_cvt - if item[didx] is not None and item[sidx] is not None - } +_NUMPY_DTYPE_TO_STR = { + np.bool_: 'bool', + np.short: 'int16', + np.int_: 'int64', + np.longlong: 'int64', + np.half: 'float16', + np.double: 'float64', + np.int8: 'int8', + np.int16: 'int16', + np.int32: 'int32', + np.int64: 'int64', + np.uint8: 'uint8', + np.uint16: 'uint16', + np.uint32: 'uint32', + np.uint64: 'uint64', + np.float16: 'float16', + np.float32: 'float32', + np.float64: 'float64', +} +_NUMPY_DTYPE_TO_STR.update({np.dtype(k): v for k, v in _NUMPY_DTYPE_TO_STR.items()}) -_dtype_py2tvmstr = _create_type_mapper(0, 1) -_dtype_tvmstr2fficall = _create_type_mapper(1, 4, dmapper=lambda x: getattr(tb_ffi, x)) -_dtype_tvm2py = _create_type_mapper(1, 0, lambda x: dtype(x)) -_dtype_tvm2ctype = 
_create_type_mapper(1, 2, lambda x: dtype(x)) -_dtype_tvm2cffi = _create_type_mapper(1, 3, lambda x: dtype(x)) +_TORCH_DTYPE_TO_STR = { + torch.bool: 'bool', + torch.short: 'int16', + torch.int: 'int32', + torch.long: 'int64', + torch.half: 'float16', + torch.float: 'float32', + torch.double: 'float64', + torch.int8: 'int8', + torch.int16: 'int16', + torch.int32: 'int32', + torch.int64: 'int64', + torch.uint8: 'uint8', + torch.uint16: 'uint16', + torch.uint32: 'uint32', + torch.uint64: 'uint64', + torch.float16: 'float16', + torch.float32: 'float32', + torch.float64: 'float64', + torch.bfloat16: 'bfloat16', +} +# _STR_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_STR.items()} -def __dtype_eq__(self: dtype, other: AnyDType): - if isinstance(other, str): - return str.__eq__(self, other) - if other in _dtype_py2tvmstr: - return str.__eq__(self, _dtype_py2tvmstr[other]) - return NotImplemented +# _STR_TO_NUMPY_DTYPE = {v: k for k, v in _NUMPY_DTYPE_TO_STR.items()} +_DTYPE_TO_STR = {**_PYTHON_DTYPE_TO_STR, **_NUMPY_DTYPE_TO_STR, **_TORCH_DTYPE_TO_STR} -def __dtype_ne__(self: dtype, other: AnyDType): - if isinstance(other, str): - return str.__ne__(self, other) - if other in _dtype_py2tvmstr: - return str.__ne__(self, _dtype_py2tvmstr[other]) - return NotImplemented +_STR_TO_TVM_DTYPE_CALL = { + 'bool': 'Boolean', + 'int8': 'Int8', + 'int32': 'Int32', + 'int64': 'Int64', + 'uint8': 'UInt8', + 'uint16': 'UInt16', + 'uint32': 'UInt32', + 'uint64': 'UInt64', + 'float16': 'Float16', + 'float32': 'Float32', + 'float64': 'Float64', + 'bfloat16': 'BFloat16', + 'float8_e4m3': 'Float8E4M3', + 'float8_e4m3fn': 'Float8E4M3FN', + 'float8_e4m3fnuz': 'Float8E4M3FNUZ', + 'float8_e5m2': 'Float8E5M2', + 'float8_e5m2fnuz': 'Float8E5M2FNUZ', + 'float8_e8m0fnu': 'Float8E8M0FNU' +} def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var: - if self in _dtype_tvmstr2fficall: - return _dtype_tvmstr2fficall[self](expr, is_size_var) + if self in 
_STR_TO_TVM_DTYPE_CALL: + attr = _STR_TO_TVM_DTYPE_CALL[self] + call = getattr(tb_ffi, attr, None) + return call(expr, is_size_var) # try to construct the ffi call if self.startswith('uint'): val = 'UInt' + self[4:] @@ -117,17 +120,13 @@ __orig_dtype_new = dtype.__new__ def __dtype_new__(cls, value: AnyDType) -> dtype: if isinstance(value, str): return __orig_dtype_new(cls, value) - elif value in _dtype_py2tvmstr: - return __orig_dtype_new(cls, _dtype_py2tvmstr[value]) + elif value in _DTYPE_TO_STR: + return __orig_dtype_new(cls, _DTYPE_TO_STR[value]) else: - expected = set(list(_dtype_py2tvmstr.keys()) + list(_dtype_tvmstr2fficall.values())) + expected = set(list(_DTYPE_TO_STR.keys()) + list(_DTYPE_TO_STR.values())) raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}") -dtype.__eq__ = __dtype_eq__ -dtype.__req__ = __dtype_eq__ -dtype.__ne__ = __dtype_ne__ -dtype.__rne__ = __dtype_ne__ dtype.__call__ = __dtype_call__ dtype.__new__ = __dtype_new__ -- GitLab From a2a278149f56bc6ffb8f99a10fde737d2d2ae677 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Mon, 17 Nov 2025 06:07:30 +0000 Subject: [PATCH 008/139] [EXAMPLE] In the flash attention example keep the max of all blocks seen in scores_max numerical stability (#1148) * Keep the max of all blocks seen in scores_max for stability * ruff formatting --- examples/flash_attention/example_mha_fwd_bhsd.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/flash_attention/example_mha_fwd_bhsd.py b/examples/flash_attention/example_mha_fwd_bhsd.py index e936cee3..e0e0bca2 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd.py +++ b/examples/flash_attention/example_mha_fwd_bhsd.py @@ -86,6 +86,10 @@ def flashattn(batch, T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we 
need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # in the first ceil_div(kBlockM, kBlockN) steps. -- GitLab From b3d6f03cea2710497a8704c083148813ee0826f3 Mon Sep 17 00:00:00 2001 From: Chaofan Lin Date: Mon, 17 Nov 2025 19:42:32 +0800 Subject: [PATCH 009/139] [Docs] Improve Installation Guide (#1270) * [Docs] Improve installation guide * address comments --- docs/get_started/Installation.md | 134 ++++++++++--------------------- 1 file changed, 42 insertions(+), 92 deletions(-) diff --git a/docs/get_started/Installation.md b/docs/get_started/Installation.md index 3d5c6db9..be0d794e 100644 --- a/docs/get_started/Installation.md +++ b/docs/get_started/Installation.md @@ -8,25 +8,25 @@ - **Python Version**: >= 3.8 - **CUDA Version**: 12.0 <= CUDA < 13 -The easiest way to install **tile-lang** is directly from PyPI using pip. To install the latest version, run the following command in your terminal: +The easiest way to install tilelang is directly from PyPI using pip. 
To install the latest version, run the following command in your terminal: ```bash pip install tilelang ``` -Alternatively, you may choose to install **tile-lang** using prebuilt packages available on the Release Page: +Alternatively, you may choose to install tilelang using prebuilt packages available on the Release Page: ```bash pip install tilelang-0.0.0.dev0+ubuntu.20.4.cu120-py3-none-any.whl ``` -To install the latest version of **tile-lang** from the GitHub repository, you can run the following command: +To install the latest version of tilelang from the GitHub repository, you can run the following command: ```bash pip install git+https://github.com/tile-ai/tilelang.git ``` -After installing **tile-lang**, you can verify the installation by running: +After installing tilelang, you can verify the installation by running: ```bash python -c "import tilelang; print(tilelang.__version__)" @@ -40,18 +40,18 @@ python -c "import tilelang; print(tilelang.__version__)" - **Python Version**: >= 3.8 - **CUDA Version**: >= 10.0 -```bash -docker run -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.01-py3 -``` +If you prefer Docker, please skip to the [Install Using Docker](#install-using-docker) section. This section focuses on building from source on a native Linux environment. -To build and install **tile-lang** directly from source, follow these steps. This process requires certain pre-requisites from Apache TVM, which can be installed on Ubuntu/Debian-based systems using the following commands: +First, install the OS-level prerequisites on Ubuntu/Debian-based systems using the following commands: ```bash apt-get update apt-get install -y python3 python3-dev python3-setuptools gcc zlib1g-dev build-essential cmake libedit-dev ``` -After installing the prerequisites, you can clone the **tile-lang** repository and install it using pip: +Then, clone the tilelang repository and install it using pip. The `-v` flag enables verbose output during the build process. 
+ +> **Note**: Use the `--recursive` flag to include necessary submodules. Tilelang currently depends on a customized version of TVM, which is included as a submodule. If you prefer [Building with Existing TVM Installation](#using-existing-tvm), you can skip cloning the TVM submodule (but still need other dependencies). ```bash git clone --recursive https://github.com/tile-ai/tilelang.git @@ -59,12 +59,18 @@ cd tilelang pip install . -v ``` -If you want to install **tile-lang** in development mode, you can run the following command: +If you want to install tilelang in development mode, you can use the `-e` flag so that any changes to the Python files will be reflected immediately without reinstallation. ```bash pip install -e . -v ``` +> **Note**: changes to C++ files require rebuilding the tilelang C++ library. See [Faster Rebuild for Developers](#faster-rebuild-for-developers) below. A default `build` directory will be created if you use `pip install`, so you can also directly run `make` in the `build` directory to rebuild it as [Working from Source via PYTHONPATH](#working-from-source-via-pythonpath) suggested below. + +(working-from-source-via-pythonpath)= + +### Working from Source via `PYTHONPATH` + If you prefer to work directly from the source tree via `PYTHONPATH`, make sure the native extension is built first: ```bash @@ -85,17 +91,21 @@ Some useful CMake options you can toggle while configuring: - `-DUSE_ROCM=ON` selects ROCm support when building on AMD GPUs. - `-DNO_VERSION_LABEL=ON` disables the backend/git suffix in `tilelang.__version__`. -We currently provide four methods to install **tile-lang**: +(using-existing-tvm)= -1. [Install Using Docker](#install-method-1) (Recommended) -2. [Install from Source (using the bundled TVM submodule)](#install-method-2) -3. 
[Install from Source (using your own TVM installation)](#install-method-3) +### Building with Existing TVM Installation -(install-method-1)= +If you already have a compatible TVM installation, use the `TVM_ROOT` environment variable to specify the location of your existing TVM repository when building tilelang: -### Method 1: Install Using Docker (Recommended) +```bash +TVM_ROOT= pip install . -v +``` + +(install-using-docker)= -For users who prefer a containerized environment with all dependencies pre-configured, **tile-lang** provides Docker images for different CUDA versions. This method is particularly useful for ensuring consistent environments across different systems and is the **recommended approach** for most users. +## Install Using Docker + +For users who prefer a containerized environment with all dependencies pre-configured, tilelang provides Docker images for different CUDA versions. This method is particularly useful for ensuring consistent environments across different systems. **Prerequisites:** - Docker installed on your system @@ -142,82 +152,17 @@ docker run -itd \ - `--name tilelang_b200`: Assigns a name to the container for easy management - `/bin/zsh`: Uses zsh as the default shell -4. **Access the Container**: +4. **Access the Container and Verify Installation**: ```bash docker exec -it tilelang_b200 /bin/zsh -``` - -5. **Verify Installation**: - -Once inside the container, verify that **tile-lang** is working correctly: - -```bash +# Inside the container: python -c "import tilelang; print(tilelang.__version__)" ``` -You can now run TileLang examples and develop your applications within the containerized environment. The Docker image comes with all necessary dependencies pre-installed, including CUDA toolkit, TVM, and TileLang itself. 
- -**Example Usage:** - -After accessing the container, you can run TileLang examples: - -```bash -cd /home/tilelang/examples -python elementwise/test_example_elementwise.py -``` - -This Docker-based installation method provides a complete, isolated environment that works seamlessly on systems with compatible NVIDIA GPUs like the B200, ensuring optimal performance for TileLang applications. - -(install-method-2)= - -### Method 2: Install from Source (Using the Bundled TVM Submodule) - -If you already have a compatible TVM installation, follow these steps: - -1. **Clone the Repository**: - -```bash -git clone --recursive https://github.com/tile-ai/tilelang -cd tilelang -``` - -**Note**: Use the `--recursive` flag to include necessary submodules. - -2. **Configure Build Options**: - -Create a build directory and specify your existing TVM path: - -```bash -pip install . -v -``` - -(install-method-3)= - -### Method 3: Install from Source (Using Your Own TVM Installation) - -If you prefer to use the built-in TVM version, follow these instructions: - -1. **Clone the Repository**: - -```bash -git clone --recursive https://github.com/tile-ai/tilelang -cd tilelang -``` - -**Note**: Ensure the `--recursive` flag is included to fetch submodules. - -2. **Configure Build Options**: - -Copy the configuration file and enable the desired backends (e.g., LLVM and CUDA): - -```bash -TVM_ROOT= pip install . -v -``` - ## Install with Nightly Version -For users who want access to the latest features and improvements before official releases, we provide nightly builds of **tile-lang**. +For users who want access to the latest features and improvements before official releases, we provide nightly builds of tilelang. ```bash pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu121/ @@ -253,23 +198,28 @@ Set `NO_TOOLCHAIN_VERSION=ON` to disable this. 
### Run-time environment variables +TODO + +## Other Tips -## IDE Configs +### IDE Configs -Building tilelang locally will automatically `compile_commands.json` file in `build` dir. +Building tilelang locally will automatically generate a `compile_commands.json` file in `build` dir. VSCode with clangd and [clangd extension](https://marketplace.visualstudio.com/items?itemName=llvm-vs-code-extensions.vscode-clangd) should be able to index that without extra configuration. -## Compile cache +### Compile Cache -`ccache` will be automatically used if found. +The default path of the compile cache is `~/.tilelang/cache`. `ccache` will be automatically used if found. -## Repairing wheels +### Repairing Wheels If you plan to use your wheel in other environment, -it's recommend to use auditwheel (on Linux) or delocate (on Darwin) +it's recommended to use auditwheel (on Linux) or delocate (on Darwin) to repair them. -## Faster rebuild for developers +(faster-rebuild-for-developers)= + +### Faster Rebuild for Developers `pip install` introduces extra [un]packaging and takes ~30 sec to complete, even if no source change. 
-- GitLab From 3ab93cd76b77978f416359bc9998e225ac276dcd Mon Sep 17 00:00:00 2001 From: Tong WU <109033598+Rachmanino@users.noreply.github.com> Date: Mon, 17 Nov 2025 21:53:19 +0800 Subject: [PATCH 010/139] [Enhancement] Keep max score attention across blocks in FlashAttention for better numerical stability (#1269) * Implement max score retention across blocks in FlashAttention for improved stability * fix manual pipeline parameters * Update examples/flash_attention/example_gqa_fwd_varlen.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * fix typo * more * fix a previous typo --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- .../benchmark_tilelang_block_sparse_fmha.py | 2 ++ examples/amd/example_amd_flash_attn_bwd.py | 2 ++ examples/amd/example_amd_flash_attn_fwd.py | 2 ++ examples/attention_sink/example_gqa_sink_bwd_bhsd.py | 2 ++ .../example_gqa_sink_fwd_bhsd_wgmma_pipelined.py | 4 +++- examples/attention_sink/example_mha_sink_bwd_bhsd.py | 2 ++ examples/attention_sink/example_mha_sink_fwd_bhsd.py | 2 ++ .../example_mha_sink_fwd_bhsd_wgmma_pipelined.py | 4 +++- .../example_tilelang_sparse_gqa_decode_paged.py | 3 +-- ...example_tilelang_sparse_gqa_decode_varlen_indice.py | 3 +-- .../example_tilelang_sparse_gqa_decode_varlen_mask.py | 1 + .../amd/benchmark_mla_decode_amd_tilelang.py | 4 ++++ examples/deepseek_mla/example_mla_decode.py | 4 ++++ examples/deepseek_mla/example_mla_decode_paged.py | 4 ++++ examples/deepseek_mla/example_mla_decode_persistent.py | 2 ++ examples/deepseek_mla/example_mla_decode_ws.py | 10 +++++++++- .../experimental/example_mla_decode_kv_fp8.py | 2 ++ examples/deepseek_v32/sparse_mla_fwd.py | 2 ++ examples/deepseek_v32/sparse_mla_fwd_pipelined.py | 4 ++++ examples/flash_attention/README.md | 4 +++- examples/flash_attention/example_gqa_bwd.py | 2 ++ examples/flash_attention/example_gqa_bwd_tma_reduce.py | 2 ++ 
.../example_gqa_bwd_tma_reduce_varlen.py | 2 ++ .../flash_attention/example_gqa_bwd_wgmma_pipelined.py | 2 ++ examples/flash_attention/example_gqa_fwd_bshd.py | 2 ++ .../example_gqa_fwd_bshd_wgmma_pipelined.py | 4 +++- examples/flash_attention/example_gqa_fwd_varlen.py | 1 - examples/flash_attention/example_mha_bwd_bhsd.py | 2 ++ examples/flash_attention/example_mha_bwd_bshd.py | 4 +++- .../example_mha_bwd_bshd_wgmma_pipelined.py | 2 ++ .../example_mha_fwd_bhsd_wgmma_pipelined.py | 4 +++- examples/flash_attention/example_mha_fwd_bshd.py | 2 ++ .../example_mha_fwd_bshd_wgmma_pipelined.py | 4 +++- examples/flash_attention/example_mha_fwd_varlen.py | 2 ++ examples/flash_decoding/example_gqa_decode.py | 4 ++++ examples/flash_decoding/example_mha_inference.py | 2 ++ .../minference/example_vertical_slash_sparse_attn.py | 4 ++++ examples/seer_attention/block_sparse_attn_tilelang.py | 2 ++ .../test_tilelang_transform_config_index_bitwidth.py | 2 ++ 39 files changed, 99 insertions(+), 13 deletions(-) diff --git a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py index aefe4d42..7c9edb59 100644 --- a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py @@ -95,6 +95,8 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # in the first ceil_div(kBlockM, kBlockN) steps. 
diff --git a/examples/amd/example_amd_flash_attn_bwd.py b/examples/amd/example_amd_flash_attn_bwd.py index d47866e1..d5c52f9c 100644 --- a/examples/amd/example_amd_flash_attn_bwd.py +++ b/examples/amd/example_amd_flash_attn_bwd.py @@ -178,6 +178,8 @@ def fast_flashattn( T.copy(m_i, m_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for i in T.Parallel(block_M): + m_i[i] = T.max(m_i[i], m_prev[i]) for i in T.Parallel(block_M): if m_prev[i] == -T.infinity(accum_dtype): diff --git a/examples/amd/example_amd_flash_attn_fwd.py b/examples/amd/example_amd_flash_attn_fwd.py index 6ec5db1e..3c422c28 100644 --- a/examples/amd/example_amd_flash_attn_fwd.py +++ b/examples/amd/example_amd_flash_attn_fwd.py @@ -171,6 +171,8 @@ def fast_flashattn( T.copy(m_i, m_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for i in T.Parallel(block_M): + m_i[i] = T.max(m_i[i], m_prev[i]) for i in T.Parallel(block_M): sf = T.exp(m_prev[i] * scale - m_i[i] * scale) diff --git a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py index eec43db9..b442505f 100644 --- a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py @@ -99,6 +99,8 @@ def flashattn_fwd( T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # NOTE(wt): check_inf is necessary for sliding window attention. 
diff --git a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py index 7765603a..8d181726 100644 --- a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py +++ b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py @@ -105,6 +105,8 @@ def flashattn( T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # NOTE(wt): check_inf is necessary for sliding window attention. @@ -181,7 +183,7 @@ def flashattn( num_stages=num_stages, order=[-1, 0, 3, 1, -1, 2], stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) diff --git a/examples/attention_sink/example_mha_sink_bwd_bhsd.py b/examples/attention_sink/example_mha_sink_bwd_bhsd.py index 866668e4..b9fa0fd9 100644 --- a/examples/attention_sink/example_mha_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_bwd_bhsd.py @@ -96,6 +96,8 @@ def flashattn_fwd( T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # NOTE(wt): check_inf is necessary for sliding window attention. 
diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd.py b/examples/attention_sink/example_mha_sink_fwd_bhsd.py index 2449b090..0ccb6958 100644 --- a/examples/attention_sink/example_mha_sink_fwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd.py @@ -95,6 +95,8 @@ def flashattn( T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # NOTE(wt): check_inf is necessary for sliding window attention. diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py index 35284407..64d6ec69 100644 --- a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py @@ -98,6 +98,8 @@ def flashattn( T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # NOTE(wt): check_inf is necessary for sliding window attention. 
@@ -174,7 +176,7 @@ def flashattn( num_stages=num_stages, order=[-1, 0, 3, 1, -1, 2], stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py index e2998216..1c4b847d 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py @@ -105,8 +105,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - scores_max[i] = T.if_then_else(scores_max[i] > scores_max_prev[i], - scores_max[i], scores_max_prev[i]) + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py index ae300426..b3087522 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py @@ -95,8 +95,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - scores_max[i] = T.if_then_else(scores_max[i] > scores_max_prev[i], - scores_max[i], scores_max_prev[i]) + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) 
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py index ad62817d..3417bd7f 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py @@ -92,6 +92,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py index db460437..61c3b63c 100644 --- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py @@ -91,6 +91,8 @@ def flashmla_decode(batch, T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): @@ -157,6 +159,8 @@ def flashmla_decode(batch, T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in 
T.Parallel(block_H, block_N): diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py index 417e319f..3932d112 100644 --- a/examples/deepseek_mla/example_mla_decode.py +++ b/examples/deepseek_mla/example_mla_decode.py @@ -74,6 +74,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): @@ -148,6 +150,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): diff --git a/examples/deepseek_mla/example_mla_decode_paged.py b/examples/deepseek_mla/example_mla_decode_paged.py index fe50d4d4..d23ff00c 100644 --- a/examples/deepseek_mla/example_mla_decode_paged.py +++ b/examples/deepseek_mla/example_mla_decode_paged.py @@ -93,6 +93,8 @@ def mla_decode_tilelang(batch, acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j]) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): @@ -176,6 +178,8 @@ def mla_decode_tilelang(batch, acc_s[i, j] = 
T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j]) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): diff --git a/examples/deepseek_mla/example_mla_decode_persistent.py b/examples/deepseek_mla/example_mla_decode_persistent.py index 3f57ea05..2f896f26 100644 --- a/examples/deepseek_mla/example_mla_decode_persistent.py +++ b/examples/deepseek_mla/example_mla_decode_persistent.py @@ -98,6 +98,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) diff --git a/examples/deepseek_mla/example_mla_decode_ws.py b/examples/deepseek_mla/example_mla_decode_ws.py index 6554d57d..fcd427ef 100644 --- a/examples/deepseek_mla/example_mla_decode_ws.py +++ b/examples/deepseek_mla/example_mla_decode_ws.py @@ -104,7 +104,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1) T.copy(m_i, m_i_prev) - T.reduce_max(acc_s, m_i, dim=1, clear=False) + T.reduce_max(acc_s, out=m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(block_H): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(block_H, block_N): @@ -137,6 +139,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, 
dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(block_H): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(block_H, block_N): @@ -324,6 +328,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(block_H): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(block_H, block_N): @@ -356,6 +362,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(block_H): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(block_H, block_N): diff --git a/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py b/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py index 1b1447e8..b141822f 100644 --- a/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py +++ b/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py @@ -74,6 +74,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): diff --git a/examples/deepseek_v32/sparse_mla_fwd.py b/examples/deepseek_v32/sparse_mla_fwd.py index a39c72c4..e65b8901 100644 --- a/examples/deepseek_v32/sparse_mla_fwd.py +++ 
b/examples/deepseek_v32/sparse_mla_fwd.py @@ -147,6 +147,8 @@ def sparse_mla_fwd( ) T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(H_per_block): alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(H_per_block, BI): diff --git a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py index 96dda7df..1621d85b 100644 --- a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py +++ b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py @@ -164,6 +164,8 @@ def sparse_mla_fwd( T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(H_per_block): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(H_per_block, BI): @@ -198,6 +200,8 @@ def sparse_mla_fwd( T.copy(m_i, m_i_prev) T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(H_per_block): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(H_per_block, BI): diff --git a/examples/flash_attention/README.md b/examples/flash_attention/README.md index be11a8dc..633727ec 100644 --- a/examples/flash_attention/README.md +++ b/examples/flash_attention/README.md @@ -77,7 +77,9 @@ def flash_attention( # Compute the maximum value per row on dimension 1 (block_N) T.reduce_max(acc_s, scores_max, dim=1, clear=False) - + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # Compute the factor by which we need to rescale previous partial sums for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i]) diff --git a/examples/flash_attention/example_gqa_bwd.py 
b/examples/flash_attention/example_gqa_bwd.py index dd9c8f7c..968d1de3 100644 --- a/examples/flash_attention/example_gqa_bwd.py +++ b/examples/flash_attention/example_gqa_bwd.py @@ -61,6 +61,8 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim_v): diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce.py b/examples/flash_attention/example_gqa_bwd_tma_reduce.py index 2af06e4b..c427908a 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce.py @@ -66,6 +66,8 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim_v): diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py index 88f2d81e..a9604f4d 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py @@ -119,6 +119,8 @@ def flashattn_fwd(batch, V_shared[i, d] = 0.0 T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], 
scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim_v): diff --git a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py index 02421249..e916812f 100644 --- a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py @@ -61,6 +61,8 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim_v): diff --git a/examples/flash_attention/example_gqa_fwd_bshd.py b/examples/flash_attention/example_gqa_fwd_bshd.py index 3d4bfe45..a6d3b5f2 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd.py +++ b/examples/flash_attention/example_gqa_fwd_bshd.py @@ -127,6 +127,8 @@ def flashattn(batch, T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # in the first ceil_div(kBlockM, kBlockN) steps. 
diff --git a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py index 21f5e9a9..03ad15e9 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py @@ -94,6 +94,8 @@ def flashattn( T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # in the first ceil_div(kBlockM, kBlockN) steps. @@ -154,7 +156,7 @@ def flashattn( num_stages=num_stages, order=[-1, 0, 3, 1, -1, 2], stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) diff --git a/examples/flash_attention/example_gqa_fwd_varlen.py b/examples/flash_attention/example_gqa_fwd_varlen.py index db16e158..ccc50e41 100644 --- a/examples/flash_attention/example_gqa_fwd_varlen.py +++ b/examples/flash_attention/example_gqa_fwd_varlen.py @@ -155,7 +155,6 @@ def flashattn(batch_size, T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_M): scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) diff --git a/examples/flash_attention/example_mha_bwd_bhsd.py b/examples/flash_attention/example_mha_bwd_bhsd.py index 8247b265..d91d1770 100644 --- a/examples/flash_attention/example_mha_bwd_bhsd.py +++ b/examples/flash_attention/example_mha_bwd_bhsd.py @@ -63,6 +63,8 @@ def 
flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): diff --git a/examples/flash_attention/example_mha_bwd_bshd.py b/examples/flash_attention/example_mha_bwd_bshd.py index 414061ff..7c85f982 100644 --- a/examples/flash_attention/example_mha_bwd_bshd.py +++ b/examples/flash_attention/example_mha_bwd_bshd.py @@ -59,6 +59,8 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): @@ -344,7 +346,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=8, help='Batch size') parser.add_argument('--h', type=int, default=32, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=1048, help='Context size') + parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') parser.add_argument('--d_head', type=int, default=64, help='Head dimension') parser.add_argument('--causal', type=bool, default=False, help='Causal flag') args = parser.parse_args() diff --git a/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py index e10ef581..e8ee5d97 100644 --- 
a/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py @@ -60,6 +60,8 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, dim): diff --git a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py index e1d0130a..b797bbcc 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py @@ -86,6 +86,8 @@ def flashattn(batch, T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # in the first ceil_div(kBlockM, kBlockN) steps. 
@@ -149,7 +151,7 @@ def flashattn(batch, num_stages=num_stages, order=[-1, 0, 3, 1, -1, 2], stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) diff --git a/examples/flash_attention/example_mha_fwd_bshd.py b/examples/flash_attention/example_mha_fwd_bshd.py index a9268019..b5b72828 100644 --- a/examples/flash_attention/example_mha_fwd_bshd.py +++ b/examples/flash_attention/example_mha_fwd_bshd.py @@ -81,6 +81,8 @@ def flashattn(batch, T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # in the first ceil_div(kBlockM, kBlockN) steps. diff --git a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py index d7023a20..02d8baef 100644 --- a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py @@ -81,6 +81,8 @@ def flashattn(batch, T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # in the first ceil_div(kBlockM, kBlockN) steps. 
@@ -141,7 +143,7 @@ def flashattn(batch, num_stages=num_stages, order=[-1, 0, 3, 1, -1, 2], stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) diff --git a/examples/flash_attention/example_mha_fwd_varlen.py b/examples/flash_attention/example_mha_fwd_varlen.py index f381e900..bbb4546c 100644 --- a/examples/flash_attention/example_mha_fwd_varlen.py +++ b/examples/flash_attention/example_mha_fwd_varlen.py @@ -167,6 +167,8 @@ def flashattn(batch_size, T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # in the first ceil_div(kBlockM, kBlockN) steps. 
diff --git a/examples/flash_decoding/example_gqa_decode.py b/examples/flash_decoding/example_gqa_decode.py index 9ec3a026..46d9beea 100644 --- a/examples/flash_decoding/example_gqa_decode.py +++ b/examples/flash_decoding/example_gqa_decode.py @@ -115,6 +115,8 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): @@ -188,6 +190,8 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): diff --git a/examples/flash_decoding/example_mha_inference.py b/examples/flash_decoding/example_mha_inference.py index 3eabc9a7..0360b3e2 100644 --- a/examples/flash_decoding/example_mha_inference.py +++ b/examples/flash_decoding/example_mha_inference.py @@ -70,6 +70,8 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # in the first ceil_div(kBlockM, kBlockN) steps. 
diff --git a/examples/minference/example_vertical_slash_sparse_attn.py b/examples/minference/example_vertical_slash_sparse_attn.py index ebf8513a..48df3e09 100644 --- a/examples/minference/example_vertical_slash_sparse_attn.py +++ b/examples/minference/example_vertical_slash_sparse_attn.py @@ -87,6 +87,8 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) @@ -194,6 +196,8 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - diff --git a/examples/seer_attention/block_sparse_attn_tilelang.py b/examples/seer_attention/block_sparse_attn_tilelang.py index dcd581c6..219d3ee3 100644 --- a/examples/seer_attention/block_sparse_attn_tilelang.py +++ b/examples/seer_attention/block_sparse_attn_tilelang.py @@ -62,6 +62,8 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # in the first ceil_div(kBlockM, kBlockN) steps. 
diff --git a/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py b/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py index f051f028..1ef1589a 100644 --- a/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py +++ b/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py @@ -71,6 +71,8 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) # To do causal softmax, we need to set the scores_max to 0 if it is -inf # This process is called Check_inf in FlashAttention3 code, and it only need to be done # in the first ceil_div(kBlockM, kBlockN) steps. -- GitLab From 220c32362ef5e152621082f310fb89202b92323c Mon Sep 17 00:00:00 2001 From: Yu Cheng <54519279+chengyupku@users.noreply.github.com> Date: Tue, 18 Nov 2025 01:26:51 +0800 Subject: [PATCH 011/139] [Bugfix] Fix multiple cg defination when using T.sync_grid (#1272) --- src/target/codegen_cuda.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 6b5f5063..dda96925 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1645,10 +1645,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { } else if (op->op.same_as(tl::sync_grid())) { this->need_cooperative_groups_ = true; this->PrintIndent(); - this->stream << "cooperative_groups::grid_group grid = " - "cooperative_groups::this_grid();\n"; - this->PrintIndent(); - this->stream << "grid.sync();\n"; + this->stream << "cooperative_groups::this_grid().sync();\n"; } else if (op->op.same_as(tl::loop_break())) { this->PrintIndent(); this->stream << "break;\n"; -- GitLab From b1922518ce3238a3982c61e909e8fc74ab4e37cc Mon 
Sep 17 00:00:00 2001 From: Yichen Yan Date: Tue, 18 Nov 2025 11:36:32 +0800 Subject: [PATCH 012/139] [Minor] Remove from __future__ import annotations for python 3.8 (#1273) --- tilelang/carver/arch/arch_base.py | 3 --- tilelang/carver/common_schedules.py | 1 - tilelang/carver/roller/hint.py | 3 +-- tilelang/carver/roller/policy/common.py | 1 - tilelang/carver/roller/rasterization.py | 1 - tilelang/carver/roller/shape_inference/common.py | 1 - tilelang/carver/roller/shape_inference/tir.py | 1 - tilelang/carver/template/base.py | 7 +++---- tilelang/carver/template/conv.py | 1 - tilelang/carver/template/elementwise.py | 1 - tilelang/carver/template/flashattention.py | 1 - tilelang/carver/template/gemv.py | 1 - tilelang/carver/template/matmul.py | 1 - tilelang/contrib/cc.py | 1 - tilelang/contrib/nvcc.py | 1 - tilelang/intrinsics/mma_sm70_layout.py | 3 --- tilelang/jit/adapter/ctypes/adapter.py | 1 - tilelang/jit/adapter/cython/adapter.py | 1 - tilelang/jit/adapter/dlpack.py | 2 -- tilelang/language/allocate.py | 2 +- tilelang/language/annotations.py | 2 -- tilelang/language/copy.py | 1 - tilelang/language/customize.py | 1 - tilelang/language/experimental/gemm_sp.py | 1 - tilelang/language/fill.py | 1 - tilelang/language/frame.py | 1 - tilelang/language/gemm.py | 1 - tilelang/language/kernel.py | 1 - tilelang/language/loop.py | 1 - tilelang/language/overrides/parser.py | 2 -- tilelang/language/parser/operation.py | 2 -- tilelang/language/proxy.py | 2 +- tilelang/language/reduce.py | 1 - tilelang/language/tir/ir.py | 1 - tilelang/language/utils.py | 1 - tilelang/language/v2/builder.py | 1 - tilelang/language/warpgroup.py | 2 -- tilelang/layout/fragment.py | 10 ++++------ tilelang/layout/gemm_sp.py | 1 - tilelang/layout/layout.py | 6 ++---- tilelang/layout/swizzle.py | 2 +- tilelang/primitives/gemm/__init__.py | 1 - tilelang/profiler/__init__.py | 1 - tilelang/quantize/lop3.py | 1 - tilelang/quantize/mxfp.py | 1 - tilelang/transform/add_bufstore_wrapper.py | 1 - 
tilelang/utils/tensor.py | 1 - 47 files changed, 13 insertions(+), 68 deletions(-) diff --git a/tilelang/carver/arch/arch_base.py b/tilelang/carver/arch/arch_base.py index a10fa434..4c8825e8 100644 --- a/tilelang/carver/arch/arch_base.py +++ b/tilelang/carver/arch/arch_base.py @@ -1,6 +1,3 @@ -from __future__ import annotations - - class TileDevice: """ Represents the architecture of a computing device, capturing various hardware specifications. diff --git a/tilelang/carver/common_schedules.py b/tilelang/carver/common_schedules.py index 2766a15e..199f0158 100644 --- a/tilelang/carver/common_schedules.py +++ b/tilelang/carver/common_schedules.py @@ -19,7 +19,6 @@ # Modifications Copyright (c) Microsoft. # The code below is mostly copied from apache/tvm common_schedules.py in dlight. """Common schedule strategies for TIR.""" -from __future__ import annotations from typing import Callable from tvm import tir diff --git a/tilelang/carver/roller/hint.py b/tilelang/carver/roller/hint.py index 20d62f68..17c69dae 100644 --- a/tilelang/carver/roller/hint.py +++ b/tilelang/carver/roller/hint.py @@ -1,5 +1,4 @@ """Hint definition for schedule""" -from __future__ import annotations from tvm import DataType from . 
import PrimFuncNode import numpy as np @@ -218,7 +217,7 @@ class Hint: return dic @classmethod - def from_dict(cls, dic: dict) -> Hint: + def from_dict(cls, dic: dict) -> 'Hint': hint = cls() for k, v in dic.items(): setattr(hint, k, v) diff --git a/tilelang/carver/roller/policy/common.py b/tilelang/carver/roller/policy/common.py index 747dddbb..fb33eefd 100644 --- a/tilelang/carver/roller/policy/common.py +++ b/tilelang/carver/roller/policy/common.py @@ -1,4 +1,3 @@ -from __future__ import annotations import numpy as np diff --git a/tilelang/carver/roller/rasterization.py b/tilelang/carver/roller/rasterization.py index 39c603b6..ebd1319a 100644 --- a/tilelang/carver/roller/rasterization.py +++ b/tilelang/carver/roller/rasterization.py @@ -1,5 +1,4 @@ """Rasteration Plan For L2 Cache Locality""" -from __future__ import annotations class Rasterization: diff --git a/tilelang/carver/roller/shape_inference/common.py b/tilelang/carver/roller/shape_inference/common.py index aaf59aed..c52a170e 100644 --- a/tilelang/carver/roller/shape_inference/common.py +++ b/tilelang/carver/roller/shape_inference/common.py @@ -1,4 +1,3 @@ -from __future__ import annotations from collections import OrderedDict from tvm import arith diff --git a/tilelang/carver/roller/shape_inference/tir.py b/tilelang/carver/roller/shape_inference/tir.py index 675298c6..618cf9b3 100644 --- a/tilelang/carver/roller/shape_inference/tir.py +++ b/tilelang/carver/roller/shape_inference/tir.py @@ -1,4 +1,3 @@ -from __future__ import annotations from collections.abc import Mapping from tvm.tir.schedule.schedule import BlockRV from tvm.ir import structural_equal diff --git a/tilelang/carver/template/base.py b/tilelang/carver/template/base.py index 5aa5074c..a119c16a 100644 --- a/tilelang/carver/template/base.py +++ b/tilelang/carver/template/base.py @@ -1,5 +1,4 @@ # Import necessary modules and classes -from __future__ import annotations from abc import ABC, abstractmethod # For defining abstract base classes 
from dataclasses import dataclass, field # For defining data classes from ..arch import ( # Import architecture-related utilities and classes @@ -42,7 +41,7 @@ class BaseTemplate(ABC): """ pass - def with_arch(self, arch: TileDevice) -> BaseTemplate: + def with_arch(self, arch: TileDevice) -> 'BaseTemplate': """ Sets the architecture for this template and returns itself. @@ -110,7 +109,7 @@ class BaseTemplate(ABC): """ raise NotImplementedError("initialize_function is not implemented") - def set_function(self, func: PrimFunc) -> BaseTemplate: + def set_function(self, func: PrimFunc) -> 'BaseTemplate': """ Sets the function for this template and returns itself. @@ -123,7 +122,7 @@ class BaseTemplate(ABC): self._func = func return self - def set_output_nodes(self, output_nodes: list[OutputNode]) -> BaseTemplate: + def set_output_nodes(self, output_nodes: list[OutputNode]) -> 'BaseTemplate': """ Sets the output nodes for this template and returns itself. diff --git a/tilelang/carver/template/conv.py b/tilelang/carver/template/conv.py index f180084d..9ea89202 100644 --- a/tilelang/carver/template/conv.py +++ b/tilelang/carver/template/conv.py @@ -1,4 +1,3 @@ -from __future__ import annotations from dataclasses import dataclass from .base import BaseTemplate from tvm import te, tir diff --git a/tilelang/carver/template/elementwise.py b/tilelang/carver/template/elementwise.py index 26d53152..8cd30619 100644 --- a/tilelang/carver/template/elementwise.py +++ b/tilelang/carver/template/elementwise.py @@ -1,5 +1,4 @@ # Import necessary modules -from __future__ import annotations from dataclasses import dataclass # Used for defining data classes from .base import BaseTemplate # Importing the base class for templates from tvm import te # Importing TVM's tensor expression module diff --git a/tilelang/carver/template/flashattention.py b/tilelang/carver/template/flashattention.py index 760b1981..ae1a2540 100644 --- a/tilelang/carver/template/flashattention.py +++ 
b/tilelang/carver/template/flashattention.py @@ -1,4 +1,3 @@ -from __future__ import annotations from dataclasses import dataclass from .base import BaseTemplate from tvm import te diff --git a/tilelang/carver/template/gemv.py b/tilelang/carver/template/gemv.py index 7195a0b8..cdcc78d0 100644 --- a/tilelang/carver/template/gemv.py +++ b/tilelang/carver/template/gemv.py @@ -1,4 +1,3 @@ -from __future__ import annotations from dataclasses import dataclass from .base import BaseTemplate from tvm import te diff --git a/tilelang/carver/template/matmul.py b/tilelang/carver/template/matmul.py index 4847cdb2..653ddab3 100644 --- a/tilelang/carver/template/matmul.py +++ b/tilelang/carver/template/matmul.py @@ -1,4 +1,3 @@ -from __future__ import annotations from dataclasses import dataclass from .base import BaseTemplate from tvm import te diff --git a/tilelang/contrib/cc.py b/tilelang/contrib/cc.py index 0807c255..87d943ab 100644 --- a/tilelang/contrib/cc.py +++ b/tilelang/contrib/cc.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
"""Util to invoke C/C++ compilers in the system.""" -from __future__ import annotations import functools import os import shutil diff --git a/tilelang/contrib/nvcc.py b/tilelang/contrib/nvcc.py index 2903b15d..202e0f3b 100644 --- a/tilelang/contrib/nvcc.py +++ b/tilelang/contrib/nvcc.py @@ -1,7 +1,6 @@ # pylint: disable=invalid-name # modified from apache tvm python/tvm/contrib/nvcc.py """Utility to invoke nvcc compiler in the system""" -from __future__ import absolute_import as _abs from __future__ import annotations import os diff --git a/tilelang/intrinsics/mma_sm70_layout.py b/tilelang/intrinsics/mma_sm70_layout.py index d6491c2b..e7a57da7 100644 --- a/tilelang/intrinsics/mma_sm70_layout.py +++ b/tilelang/intrinsics/mma_sm70_layout.py @@ -1,6 +1,3 @@ -from __future__ import annotations - - def shared_16x4_to_mma_a_32x4_layout(row, col, rep): tid = (row % 4) + 16 * ((row // 4) % 2) + 4 * (row // 8) + 8 * rep local_id = col diff --git a/tilelang/jit/adapter/ctypes/adapter.py b/tilelang/jit/adapter/ctypes/adapter.py index 648c66c1..bf0aef51 100644 --- a/tilelang/jit/adapter/ctypes/adapter.py +++ b/tilelang/jit/adapter/ctypes/adapter.py @@ -1,6 +1,5 @@ """The profiler and convert to torch utils""" from __future__ import annotations - import torch from ..base import BaseKernelAdapter import ctypes diff --git a/tilelang/jit/adapter/cython/adapter.py b/tilelang/jit/adapter/cython/adapter.py index 7857872c..bc43533b 100644 --- a/tilelang/jit/adapter/cython/adapter.py +++ b/tilelang/jit/adapter/cython/adapter.py @@ -1,6 +1,5 @@ """The profiler and convert to torch utils""" from __future__ import annotations - import ctypes import logging import torch diff --git a/tilelang/jit/adapter/dlpack.py b/tilelang/jit/adapter/dlpack.py index 9fa767f0..402dfb2f 100644 --- a/tilelang/jit/adapter/dlpack.py +++ b/tilelang/jit/adapter/dlpack.py @@ -1,6 +1,4 @@ """The profiler and convert to torch utils""" -from __future__ import annotations - import torch from tilelang.contrib.dlpack 
import to_pytorch_func from .base import BaseKernelAdapter diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index d70355ad..f0784e86 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -13,8 +13,8 @@ Available allocation functions: Each function takes shape and dtype parameters and returns a TVM buffer object with the appropriate memory scope. """ - from __future__ import annotations + from typing import overload, Literal from tilelang import tvm as tvm from tvm.script import tir as T diff --git a/tilelang/language/annotations.py b/tilelang/language/annotations.py index 3c469e78..2ce71cb9 100644 --- a/tilelang/language/annotations.py +++ b/tilelang/language/annotations.py @@ -1,6 +1,4 @@ """Annotation helpers exposed on the TileLang language surface.""" -from __future__ import annotations - from typing import Callable from tilelang.layout import Layout diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index 4ad857b5..62de13d0 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -1,6 +1,5 @@ """The language interface for tl programs.""" from __future__ import annotations - from typing import Literal from tilelang import language as T from tilelang.utils.language import ( diff --git a/tilelang/language/customize.py b/tilelang/language/customize.py index 0830c22d..9175bdb8 100644 --- a/tilelang/language/customize.py +++ b/tilelang/language/customize.py @@ -1,6 +1,5 @@ """The language interface for tl programs.""" from __future__ import annotations - import tilelang.language as T from tvm.tir import PrimExpr, Buffer, op from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401 diff --git a/tilelang/language/experimental/gemm_sp.py b/tilelang/language/experimental/gemm_sp.py index fc511c00..e966e7d6 100644 --- a/tilelang/language/experimental/gemm_sp.py +++ b/tilelang/language/experimental/gemm_sp.py @@ -1,6 +1,5 
@@ """The language interface for tl programs.""" from __future__ import annotations - from tilelang.primitives.gemm.base import GemmWarpPolicy import tilelang.language as T from tvm import tir diff --git a/tilelang/language/fill.py b/tilelang/language/fill.py index 74aeb264..ad74720f 100644 --- a/tilelang/language/fill.py +++ b/tilelang/language/fill.py @@ -1,6 +1,5 @@ """The language interface for tl programs.""" from __future__ import annotations - from tvm import tir from tilelang.language import has_let_value, get_let_value from tilelang.utils.language import get_buffer_region_from_load diff --git a/tilelang/language/frame.py b/tilelang/language/frame.py index 8e6d5926..db649952 100644 --- a/tilelang/language/frame.py +++ b/tilelang/language/frame.py @@ -1,6 +1,5 @@ """Override the LetFrame to print a message when entering the frame.""" from __future__ import annotations - from tvm.ffi import register_object as _register_object from tvm.tir import Var, PrimExpr, BufferLoad, BufferRegion from tvm.ir import Range diff --git a/tilelang/language/gemm.py b/tilelang/language/gemm.py index 0f01582f..0f2e82d7 100644 --- a/tilelang/language/gemm.py +++ b/tilelang/language/gemm.py @@ -1,6 +1,5 @@ """The language interface for tl programs.""" from __future__ import annotations - from tilelang.primitives.gemm.base import GemmWarpPolicy import tilelang.language as T from tvm import tir diff --git a/tilelang/language/kernel.py b/tilelang/language/kernel.py index 54b78d3d..5e819da7 100644 --- a/tilelang/language/kernel.py +++ b/tilelang/language/kernel.py @@ -1,6 +1,5 @@ """The language interface for tl programs.""" from __future__ import annotations - from collections import deque from tvm import tir from tvm.tir import Var diff --git a/tilelang/language/loop.py b/tilelang/language/loop.py index 85f2acd8..4f8d5c30 100644 --- a/tilelang/language/loop.py +++ b/tilelang/language/loop.py @@ -1,6 +1,5 @@ """The language interface for tl programs.""" from __future__ import 
annotations - from typing import Any from tvm import tir from tvm.tir import IntImm diff --git a/tilelang/language/overrides/parser.py b/tilelang/language/overrides/parser.py index 01d59b60..af42098a 100644 --- a/tilelang/language/overrides/parser.py +++ b/tilelang/language/overrides/parser.py @@ -1,6 +1,4 @@ """TVMScript parser overrides tailored for TileLang.""" -from __future__ import annotations - from functools import partial from tvm.script.ir_builder import tir as T diff --git a/tilelang/language/parser/operation.py b/tilelang/language/parser/operation.py index 43774947..b2138acf 100644 --- a/tilelang/language/parser/operation.py +++ b/tilelang/language/parser/operation.py @@ -17,8 +17,6 @@ # This file is modified from the original version, # which is part of the TVM project (https://tvm.apache.org/). """The tir expression operation registration""" -from __future__ import annotations - from tvm import tir from tvm.ffi.runtime_ctypes import DataType, DataTypeCode from tvm.tir import IntImm diff --git a/tilelang/language/proxy.py b/tilelang/language/proxy.py index 2c5a372f..e2f65e83 100644 --- a/tilelang/language/proxy.py +++ b/tilelang/language/proxy.py @@ -1,6 +1,6 @@ """The language interface for tl programs.""" - from __future__ import annotations + from typing import Any, SupportsIndex, TYPE_CHECKING from collections.abc import Sequence from typing_extensions import Self diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py index 5b895c41..09289559 100644 --- a/tilelang/language/reduce.py +++ b/tilelang/language/reduce.py @@ -1,6 +1,5 @@ """The language interface for tl programs.""" from __future__ import annotations - from tvm import tir from tilelang.language import copy, macro, alloc_shared, alloc_fragment from tilelang.language.utils import buffer_to_tile_region diff --git a/tilelang/language/tir/ir.py b/tilelang/language/tir/ir.py index fc5491ce..74cb32f7 100644 --- a/tilelang/language/tir/ir.py +++ b/tilelang/language/tir/ir.py @@ 
-1,4 +1,3 @@ -from __future__ import annotations import tvm.script.ir_builder.tir.ir as _ir from tvm.script.ir_builder.tir import frame from tvm.tir import PrimExpr diff --git a/tilelang/language/utils.py b/tilelang/language/utils.py index 8a918c3f..ad8b83dd 100644 --- a/tilelang/language/utils.py +++ b/tilelang/language/utils.py @@ -1,4 +1,3 @@ -from __future__ import annotations from tilelang import tvm as tvm from tvm import tir from tvm.tir import PrimExpr, Buffer, BufferLoad, op diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 90c8a8e9..684880b7 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -1,5 +1,4 @@ from __future__ import annotations - from contextlib import contextmanager, AbstractContextManager from dataclasses import dataclass import inspect diff --git a/tilelang/language/warpgroup.py b/tilelang/language/warpgroup.py index 872d3001..bec76809 100644 --- a/tilelang/language/warpgroup.py +++ b/tilelang/language/warpgroup.py @@ -1,6 +1,4 @@ """The language interface for tl programs.""" -from __future__ import annotations - from tvm.script.ir_builder.tir.frame import TIRFrame from tvm.ffi import register_object from tilelang import _ffi_api diff --git a/tilelang/layout/fragment.py b/tilelang/layout/fragment.py index 06fc7a98..b9a56d8e 100644 --- a/tilelang/layout/fragment.py +++ b/tilelang/layout/fragment.py @@ -1,7 +1,5 @@ """Wrapping Layouts.""" # pylint: disable=invalid-name, unsupported-binary-operation -from __future__ import annotations - import tvm import tvm_ffi from tvm.ir import Range @@ -124,7 +122,7 @@ class Fragment(Layout): def repeat(self, repeats, repeat_on_thread: bool = False, - lower_dim_first: bool = True) -> Fragment: + lower_dim_first: bool = True) -> 'Fragment': """ Returns a new Fragment that repeats the iteration space a given number of times. 
@@ -144,7 +142,7 @@ class Fragment(Layout): """ return _ffi_api.Fragment_repeat(self, repeats, repeat_on_thread, lower_dim_first) - def replicate(self, replicate: int) -> Fragment: + def replicate(self, replicate: int) -> 'Fragment': """ Replicate the Fragment across a new thread dimension. @@ -160,7 +158,7 @@ class Fragment(Layout): """ return _ffi_api.Fragment_replicate(self, replicate) - def condense_rep_var(self) -> Fragment: + def condense_rep_var(self) -> 'Fragment': """ Condense or fold the replicate variable into the existing iteration space. This operation may be used to reduce dimensionality if the replicate variable @@ -207,7 +205,7 @@ class Fragment(Layout): """ return f"Fragment<{self.get_input_shape()}->{self.get_output_shape()}, thread={self.thread}, index={self.index}>" - def is_equal(self, other: Fragment) -> bool: + def is_equal(self, other: 'Fragment') -> bool: """ Check if the current fragment is equal to another fragment. """ diff --git a/tilelang/layout/gemm_sp.py b/tilelang/layout/gemm_sp.py index 2fd58cd2..eaaa178f 100644 --- a/tilelang/layout/gemm_sp.py +++ b/tilelang/layout/gemm_sp.py @@ -1,7 +1,6 @@ """Wrapping Layouts.""" # pylint: disable=invalid-name, unsupported-binary-operation from __future__ import annotations - import tvm import tilelang.language as T import warnings diff --git a/tilelang/layout/layout.py b/tilelang/layout/layout.py index 14db1222..10e0357e 100644 --- a/tilelang/layout/layout.py +++ b/tilelang/layout/layout.py @@ -1,7 +1,5 @@ """Wrapping Layouts.""" # pylint: disable=invalid-name, unsupported-binary-operation -from __future__ import annotations - import tvm_ffi from tvm.ir import Node, Range from tvm.tir import IterVar, Var, PrimExpr, IndexMap @@ -122,7 +120,7 @@ class Layout(Node): # Map the provided indices using the constructed index mapping return index_map.map_indices(indices) - def inverse(self) -> Layout: + def inverse(self) -> 'Layout': """ Compute the inverse of the current layout transformation. 
@@ -133,7 +131,7 @@ class Layout(Node): """ return _ffi_api.Layout_inverse(self) - def is_equal(self, other: Layout) -> bool: + def is_equal(self, other: 'Layout') -> bool: """ Check if the current layout is equal to another layout. diff --git a/tilelang/layout/swizzle.py b/tilelang/layout/swizzle.py index f63c954a..3a219c67 100644 --- a/tilelang/layout/swizzle.py +++ b/tilelang/layout/swizzle.py @@ -1,7 +1,7 @@ """Wrapping Layouts.""" # pylint: disable=invalid-name, unsupported-binary-operation - from __future__ import annotations + import tvm from tvm.tir import Buffer, BufferLoad, BufferRegion from tilelang import _ffi_api diff --git a/tilelang/primitives/gemm/__init__.py b/tilelang/primitives/gemm/__init__.py index ee9436d1..24843740 100644 --- a/tilelang/primitives/gemm/__init__.py +++ b/tilelang/primitives/gemm/__init__.py @@ -1,5 +1,4 @@ from __future__ import annotations - from tvm import tir from tilelang.utils import is_local, is_fragment, is_shared from tilelang.primitives.gemm.base import GemmWarpPolicy diff --git a/tilelang/profiler/__init__.py b/tilelang/profiler/__init__.py index c681ee97..3ff2baab 100644 --- a/tilelang/profiler/__init__.py +++ b/tilelang/profiler/__init__.py @@ -1,6 +1,5 @@ """The profiler and convert to torch utils""" from __future__ import annotations - from typing import Callable, Any, Literal from functools import partial import torch diff --git a/tilelang/quantize/lop3.py b/tilelang/quantize/lop3.py index 47d91f05..e4e7f7ee 100644 --- a/tilelang/quantize/lop3.py +++ b/tilelang/quantize/lop3.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from __future__ import annotations from typing import Literal decode_i4_to_f16 = """ diff --git a/tilelang/quantize/mxfp.py b/tilelang/quantize/mxfp.py index 0425c549..80f3e061 100644 --- a/tilelang/quantize/mxfp.py +++ b/tilelang/quantize/mxfp.py @@ -1,4 +1,3 @@ -from __future__ import annotations from typing import Literal # Implementation asm for fp4 to bf16, using twiddling diff --git a/tilelang/transform/add_bufstore_wrapper.py b/tilelang/transform/add_bufstore_wrapper.py index 7ccab470..d8457f99 100644 --- a/tilelang/transform/add_bufstore_wrapper.py +++ b/tilelang/transform/add_bufstore_wrapper.py @@ -1,4 +1,3 @@ -from __future__ import annotations from tvm.tir import (BufferStore, For, AttrStmt, ForKind, Var, PrimFunc, BufferLoad, Buffer, IntImm) from tvm.tir.stmt_functor import ir_transform, post_order_visit from tvm.tir.transform import prim_func_pass diff --git a/tilelang/utils/tensor.py b/tilelang/utils/tensor.py index 51f63db4..79947750 100644 --- a/tilelang/utils/tensor.py +++ b/tilelang/utils/tensor.py @@ -1,4 +1,3 @@ -from __future__ import annotations """The profiler and convert to torch utils""" from enum import Enum import torch -- GitLab From e805f8e5a96a0c63342bdf0420941737dcbdc469 Mon Sep 17 00:00:00 2001 From: Chaofan Lin Date: Tue, 18 Nov 2025 14:06:31 +0800 Subject: [PATCH 013/139] [BugFix] Adding extra parameters into autotune hashkey (#1274) * [BugFix] Adding extra parameters into autotune hashkey * lint * None check * check serializable --- tilelang/autotuner/tuner.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py index 4027c619..7138f4c1 100644 --- a/tilelang/autotuner/tuner.py +++ b/tilelang/autotuner/tuner.py @@ -235,7 +235,8 @@ class AutoTuner: self._kernel_parameters = k_parameters self._function_parameters = f_parameters - def generate_cache_key(self, parameters: dict[str, Any]) -> AutotuneResult | None: + def 
generate_cache_key(self, parameters: dict[str, Any], + extra_parameters: dict[str, Any]) -> AutotuneResult | None: """Generate a cache key for the auto-tuning process. """ @@ -261,6 +262,7 @@ class AutoTuner: key_data = { "version": __version__, "op_parameters": tuple(op_parameters), + "extra_parameters": extra_parameters, "func_source": func_source, "configs": self.configs, "compile_args": hash(self.compile_args), @@ -293,10 +295,28 @@ class AutoTuner: sig = inspect.signature(self.fn) parameters = sig.parameters + # NOTE(chaofan): We need to extract some parameters from the closure. + # Consider the case: + # def gemm(M, N, K): + # def kernel(...) + # If we only extract source, M/N/K will be symbolic and there will be cache problem. + extra_parameters: dict[str, Any] = {} + cells = self.fn.__closure__ + var_names = self.fn.__code__.co_freevars + if cells is not None: + assert len(var_names) == len(cells), "Number of free variables does not match" + for var_name, cell in zip(var_names, cells): + if var_name in parameters: + continue + # Cell content must be serializable + assert isinstance(cell.cell_contents, (int, float, str, bool, type(None))), \ + f"Cell contents {cell.cell_contents} is not serializable: {type(cell.cell_contents)}" + extra_parameters[var_name] = cell.cell_contents + if isinstance(self.configs, Callable): self.configs = self.configs(*self._kernel_parameters) - key = self.generate_cache_key(parameters) + key = self.generate_cache_key(parameters, extra_parameters) with self._lock: if env.is_cache_enabled(): -- GitLab From 49c857154efdf9edf509c8ab1fb0c967724470b8 Mon Sep 17 00:00:00 2001 From: Elevator14B Date: Tue, 18 Nov 2025 15:28:23 +0800 Subject: [PATCH 014/139] Fix various issues under `int64_t` static and dynamic shape. (#1218) * Fix various issues under int64_t static and dynamic shape. * Resolve reviewed issues. * Add unit test. 
* fix --------- Co-authored-by: LeiWang1999 --- src/transform/inject_assumes.cc | 4 +- .../language/test_tilelang_language_int64.py | 66 +++++++++++++++++++ .../jit/adapter/cython/cython_wrapper.pyx | 4 +- tilelang/jit/adapter/nvrtc/wrapper.py | 4 +- tilelang/jit/adapter/wrapper.py | 28 ++++---- 5 files changed, 88 insertions(+), 18 deletions(-) create mode 100644 testing/python/language/test_tilelang_language_int64.py diff --git a/src/transform/inject_assumes.cc b/src/transform/inject_assumes.cc index 485e270c..3c3bf923 100644 --- a/src/transform/inject_assumes.cc +++ b/src/transform/inject_assumes.cc @@ -6,6 +6,7 @@ #include "tvm/node/structural_hash.h" #include "tvm/tir/builtin.h" #include "tvm/tir/expr.h" +#include "tvm/tir/op.h" #include "tvm/tir/stmt.h" #include "tvm/tir/stmt_functor.h" #include "tvm/tir/transform.h" @@ -62,7 +63,8 @@ private: Stmt build(Stmt body) { auto analyzer = arith::Analyzer{}; for (const auto &e : items) { - auto simplified = analyzer.Simplify(GT(e.expr, 0)); + auto simplified = + analyzer.Simplify(GT(e.expr, make_zero(e.expr->dtype))); std::stringstream ss; ss << "Buffer shape should be greater than 0: shape `" << e.expr << "` from buffer "; diff --git a/testing/python/language/test_tilelang_language_int64.py b/testing/python/language/test_tilelang_language_int64.py new file mode 100644 index 00000000..28fa2211 --- /dev/null +++ b/testing/python/language/test_tilelang_language_int64.py @@ -0,0 +1,66 @@ +import tilelang +import tilelang.language as T + + +@tilelang.jit +def fill_symbolic(value: float, dtype="bfloat16"): + n = T.symbolic("n", "int64") + block_n = 512 + + @T.prim_func + def main(x: T.Tensor[n, dtype]): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(n, block_n), threads=128) as bx: + # Doesn't yet work with int64-shaped global tensor + # T.fill(x[bx * block_n : (bx + 1) * block_n], value) + for i in T.Parallel(block_n): + x[bx * block_n + i] = value + + return main + + +def run_fill_symbolic(n: int): + import 
torch + + x = torch.zeros(n, dtype=torch.bfloat16, device="cuda") + fill_symbolic(1.0)(x) + assert x.min() == 1.0 and x.max() == 1.0 + + +def test_fill_symbolic(): + # Requires 8GB VRAM + run_fill_symbolic(2**32) + + +@tilelang.jit +def fill_static(n: int, value: float, dtype="bfloat16"): + block_n = 512 + + @T.prim_func + def main(x: T.Tensor[n, dtype]): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(n, block_n), threads=128) as bx: + # Doesn't yet work with int64-shaped global tensor + # T.fill(x[bx * block_n : (bx + 1) * block_n], value) + for i in T.Parallel(block_n): + x[bx * block_n + i] = value + + return main + + +def run_fill_static(n: int): + import torch + + x = torch.zeros(n, dtype=torch.bfloat16, device="cuda") + fill_static(n, 1.0)(x) + assert x.min() == 1.0 and x.max() == 1.0 + + +def test_fill_static(): + # Requires 8GB VRAM + run_fill_static(2**32) + + +if __name__ == "__main__": + test_fill_symbolic() + test_fill_static() diff --git a/tilelang/jit/adapter/cython/cython_wrapper.pyx b/tilelang/jit/adapter/cython/cython_wrapper.pyx index f17bfffc..873e5507 100644 --- a/tilelang/jit/adapter/cython/cython_wrapper.pyx +++ b/tilelang/jit/adapter/cython/cython_wrapper.pyx @@ -267,9 +267,9 @@ cdef class CythonKernelWrapper: # Add dynamic dimension values to kernel arguments for _, (ref_id, buffer_idx, shape_idx) in self.dynamic_symbolic_map.items(): if ref_id == 0: - call_args.append(tensor_list[buffer_idx].shape[shape_idx]) + call_args.append(ctypes.c_int64(tensor_list[buffer_idx].shape[shape_idx])) else: - call_args.append(tensor_list[buffer_idx].stride(shape_idx)) + call_args.append(ctypes.c_int64(tensor_list[buffer_idx].stride(shape_idx))) # Add CUDA stream to kernel arguments call_args.append(ctypes.c_void_p(stream)) diff --git a/tilelang/jit/adapter/nvrtc/wrapper.py b/tilelang/jit/adapter/nvrtc/wrapper.py index 1a29adef..7e00050c 100644 --- a/tilelang/jit/adapter/nvrtc/wrapper.py +++ b/tilelang/jit/adapter/nvrtc/wrapper.py @@ -313,9 +313,9 
@@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): raise ValueError( f"Parameter {param} is not in the buffer map of the primary function.") # Add dynamic symbols as integer arguments - for dyn_sym in dynamic_symbolic_set: + for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set: if dyn_sym not in [arg["name"] for arg in function_args]: - function_args.append({"name": dyn_sym, "type": "ctypes.c_int"}) + function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)}) function_args.append(self.get_stream_type()) diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py index 7819890d..48b8e908 100644 --- a/tilelang/jit/adapter/wrapper.py +++ b/tilelang/jit/adapter/wrapper.py @@ -220,9 +220,9 @@ class TLCUDASourceWrapper: raise ValueError( f"Parameter {param} is not in the buffer map of the primary function.") # Add dynamic symbols as integer arguments - for dyn_sym in dynamic_symbolic_set: + for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set: if dyn_sym not in [arg["name"] for arg in function_args]: - function_args.append({"name": dyn_sym, "type": "int"}) + function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)}) function_args.append(self.get_stream_type()) @@ -405,18 +405,20 @@ class TLCUDASourceWrapper: def get_dynamic_symbolic_set(self, prim_func): # Determine the set of dynamic symbols used in the function - dynamic_symbolic_set: list[str] = [] + dynamic_symbolic_set: dict[str, str] = {} - def unique_push_back(name: str): + def unique_push_back(name: str, dtype: str): if name not in dynamic_symbolic_set: - dynamic_symbolic_set.append(name) + dynamic_symbolic_set[name] = dtype + else: + assert dtype == dynamic_symbolic_set[name] for param in prim_func.params: if param in prim_func.buffer_map: buffer = prim_func.buffer_map[param] for dim in buffer.shape: if isinstance(dim, tvm.tir.Var): - unique_push_back(dim.name) + unique_push_back(dim.name, str(dim.dtype)) # Note: In buffer definitions, any dynamic 
symbols appearing in strides are listed after those in the shape. for param in prim_func.params: @@ -424,9 +426,9 @@ class TLCUDASourceWrapper: buffer = prim_func.buffer_map[param] for stride in buffer.strides: if isinstance(stride, tvm.tir.Var): - unique_push_back(stride.name) + unique_push_back(stride.name, str(stride.dtype)) - return dynamic_symbolic_set + return list(dynamic_symbolic_set.items()) def get_init_func(self): # Initialize an empty string for the CUDA function call @@ -665,8 +667,8 @@ class TLCPUSourceWrapper: raise ValueError( f"Parameter {param} is not in the buffer map of the primary function.") # Add dynamic symbols as integer arguments - for dyn_sym in dynamic_symbolic_set: - function_args.append({"name": dyn_sym, "type": "int"}) + for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set: + function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)}) # Format the function arguments for declaration def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) @@ -715,14 +717,14 @@ class TLCPUSourceWrapper: def get_dynamic_symbolic_set(self, prim_func): # Determine the set of dynamic symbols used in the function - dynamic_symbolic_set: list[str] = [] + dynamic_symbolic_set: dict[str, str] = {} for param in prim_func.params: if param in prim_func.buffer_map: buffer = prim_func.buffer_map[param] for dim in buffer.shape: if isinstance(dim, tvm.tir.Var) and (dim.name not in dynamic_symbolic_set): - dynamic_symbolic_set.append(dim.name) - return dynamic_symbolic_set + dynamic_symbolic_set[dim.name] = str(dim.dtype) + return list(dynamic_symbolic_set.items()) def get_cpu_init_func(self): # Provide init() and get_last_error() for CPU backend -- GitLab From 0f980f15c575bf35db73a70fc04a8a53c005b2c8 Mon Sep 17 00:00:00 2001 From: Jay Zhuang <80731350+learning-chip@users.noreply.github.com> Date: Tue, 18 Nov 2025 14:35:18 +0100 Subject: [PATCH 015/139] Bug fix for Gated Delta Net benchmark script (#1267) * fix argument order 
for fla chunk_gated_delta_rule_fwd_h * explicit import assert_similar from utils * rename utils module to avoid name clash * set store_final_state and save_new_value to True * fix --------- Co-authored-by: LeiWang1999 --- examples/gdn/example_chunk_delta_bwd.py | 2 +- examples/gdn/example_chunk_delta_h.py | 30 +++++++++++++++++------ examples/gdn/example_chunk_o_bwd.py | 2 +- examples/gdn/example_wy_fast_bwd_split.py | 2 +- examples/gdn/{utils.py => test_utils.py} | 0 5 files changed, 25 insertions(+), 11 deletions(-) rename examples/gdn/{utils.py => test_utils.py} (100%) diff --git a/examples/gdn/example_chunk_delta_bwd.py b/examples/gdn/example_chunk_delta_bwd.py index 518b0ee2..d9ccc256 100644 --- a/examples/gdn/example_chunk_delta_bwd.py +++ b/examples/gdn/example_chunk_delta_bwd.py @@ -24,7 +24,7 @@ import torch.nn.functional as F torch.random.manual_seed(0) # torch.set_printoptions(profile="full") -from utils import * +from test_utils import assert_similar def prepare_input( diff --git a/examples/gdn/example_chunk_delta_h.py b/examples/gdn/example_chunk_delta_h.py index 61c2abd3..cc384ade 100644 --- a/examples/gdn/example_chunk_delta_h.py +++ b/examples/gdn/example_chunk_delta_h.py @@ -20,7 +20,7 @@ import torch import torch.nn.functional as F from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401 -from utils import * +from test_utils import assert_similar # (zhengju) We can slightly modify the generated cuda code from tilelang lowering # in the debug folder to make the performance better. 
To enable this callback, @@ -292,9 +292,15 @@ def run_test( getattr(torch, state_dtype)) # fla ref - h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(K, W, U, G, initial_state, - store_final_state, chunk_size, - save_new_value) + h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h( + k=K, + w=W, + u=U, + g=G, + initial_state=initial_state, + output_final_state=store_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value) # tilelang kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype, @@ -305,8 +311,16 @@ def run_test( # (zhengju) If you want to print the generated cuda code, you can uncomment the following line # print("CUDA Code:\n", kernel.get_kernel_source()) - fla_time = do_bench(chunk_gated_delta_rule_fwd_h, K, W, U, G, initial_state, store_final_state, - chunk_size, save_new_value) + fla_time = do_bench( + chunk_gated_delta_rule_fwd_h, + k=K, + w=W, + u=U, + g=G, + initial_state=initial_state, + output_final_state=store_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value) tilelang_time = do_bench(kernel, K, W, U, G, initial_state) # check correctness @@ -371,8 +385,8 @@ def main(): chunk_size=64, use_g=True, use_initial_state=False, - store_final_state=False, - save_new_value=False, + store_final_state=True, + save_new_value=True, block_DK=32, block_DV=32, threads=128, diff --git a/examples/gdn/example_chunk_o_bwd.py b/examples/gdn/example_chunk_o_bwd.py index 7e87a2c4..ff4d3f7a 100644 --- a/examples/gdn/example_chunk_o_bwd.py +++ b/examples/gdn/example_chunk_o_bwd.py @@ -19,7 +19,7 @@ except ImportError: fla = None import torch -from utils import * +from test_utils import assert_similar torch.random.manual_seed(0) # torch.set_printoptions(profile="full") diff --git a/examples/gdn/example_wy_fast_bwd_split.py b/examples/gdn/example_wy_fast_bwd_split.py index 618a82b4..42a0040d 100644 --- a/examples/gdn/example_wy_fast_bwd_split.py +++ 
b/examples/gdn/example_wy_fast_bwd_split.py @@ -501,7 +501,7 @@ def run_test( dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum( dim=-1) - from utils import assert_similar + from test_utils import assert_similar assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False) assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False) assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False) diff --git a/examples/gdn/utils.py b/examples/gdn/test_utils.py similarity index 100% rename from examples/gdn/utils.py rename to examples/gdn/test_utils.py -- GitLab From 1b0efb650fd0dfd05d0b643bf5eaa8e9781239ee Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Tue, 18 Nov 2025 21:37:01 +0800 Subject: [PATCH 016/139] [Bugfix] Minor fix for some cases (#1278) --- .../gemm_v2/correctness_evaluation_tcgen05.py | 25 ++++++++----------- .../intrinsics/tcgen05_macro_generator.py | 5 ++-- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/maint/gemm_v2/correctness_evaluation_tcgen05.py b/maint/gemm_v2/correctness_evaluation_tcgen05.py index f5d76589..1831ac8a 100644 --- a/maint/gemm_v2/correctness_evaluation_tcgen05.py +++ b/maint/gemm_v2/correctness_evaluation_tcgen05.py @@ -191,7 +191,7 @@ def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): if __name__ == "__main__": - # tilelang.testing.main() + tilelang.testing.main() # # Test Pass # for m in [32, 64, 128, 256]: @@ -203,6 +203,16 @@ if __name__ == "__main__": # run_gemm(m, n, k * 3, False, True, "float16", "float", "float", m, n, k, 2, 128) # print(f"Test {m} {n} {k} Pass") + # # Test Pass + # for m in [32, 64, 128, 256]: + # for n in [32, 64, 128]: + # for k in [16, 32, 64, 128]: + # if m in [32, 64] and (n not in [64, 128, 256]): + # continue + # print(f"======================= Test {m} {n} {k} False True =============================") + # run_gemm(m, n, k 
* 3, False, True, "float16", "float", "float", m, n, k, 2, 256) + # print(f"Test {m} {n} {k} Pass") + # # Test Pass # for m in [32, 64, 128, 256]: # for n in [16, 32, 64, 128]: @@ -211,16 +221,3 @@ if __name__ == "__main__": # continue # print(f"======================= Test {m} {n} {k} False True =============================") # run_gemm(m, n, k * 3, False, True, "float8_e5m2", "float", "float", m, n, k, 2, 128) - # print(f"Test {m} {n} {k} Pass") - - tilelang.disable_cache() - run_gemm(32, 512, 16, False, True, "float16", "float32", "float32", 32, 512, 16, 0, 128) - run_gemm(32, 512, 32, False, True, "float16", "float32", "float32", 32, 512, 32, 0, 128) - run_gemm(32, 512, 64, False, True, "float16", "float32", "float32", 32, 512, 64, 0, 128) - run_gemm(64, 512, 16, False, True, "float16", "float32", "float32", 64, 512, 16, 0, 128) - run_gemm(64, 512, 16, False, True, "float16", "float32", "float32", 32, 512, 16, 0, 128) - run_gemm(128, 512, 16, False, True, "float16", "float32", "float32", 128, 512, 16, 0, 128) - - # run_gemm(64, 512, 32, False, True, "float16", "float32", "float32", 64, 512, 32, 0, 128) - # run_gemm(64, 512, 64, False, True, "float16", "float32", "float32", 64, 512, 64, 0, 128) - # run_gemm(128, 512, 16, False, True, "float16", "float32", "float32", 128, 512, 16, 0, 128) diff --git a/tilelang/intrinsics/tcgen05_macro_generator.py b/tilelang/intrinsics/tcgen05_macro_generator.py index 814d28b6..e53ff7cb 100644 --- a/tilelang/intrinsics/tcgen05_macro_generator.py +++ b/tilelang/intrinsics/tcgen05_macro_generator.py @@ -247,8 +247,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): mask_zero = T.Cast("int32", 0) mask0 = mask1 = mask2 = mask3 = mask_zero - num_inst_m = 4 * self.warp_row_tiles // atom_m - num_inst_n = self.warp_col_tiles // atom_n + # TCGEN05 only has one warp group + num_inst_m = self.block_row_warps * self.warp_row_tiles // atom_m + num_inst_n = self.block_col_warps * self.warp_col_tiles // atom_n # Helper to allow 
BufferRegion/BufferLoad as inputs def access_ptr_from(buffer_or_load_or_region, access_type: str = "r"): -- GitLab From 921b96a31bb10e7aff84dece6e7501cf1fb96c63 Mon Sep 17 00:00:00 2001 From: Chaofan Lin Date: Tue, 18 Nov 2025 23:17:49 +0800 Subject: [PATCH 017/139] [Language] Add shape check in `T.view/reshape` (#1277) * [Language] Add shape check in T.view/reshape * address comments --- .../test_tilelang_language_reshape.py | 21 +++++++++++++ .../language/test_tilelang_language_view.py | 31 +++++++++++++++++++ tilelang/language/customize.py | 12 ++++--- tilelang/utils/language.py | 13 +++++++- 4 files changed, 72 insertions(+), 5 deletions(-) diff --git a/testing/python/language/test_tilelang_language_reshape.py b/testing/python/language/test_tilelang_language_reshape.py index c510bdd3..60588b4a 100644 --- a/testing/python/language/test_tilelang_language_reshape.py +++ b/testing/python/language/test_tilelang_language_reshape.py @@ -2,6 +2,7 @@ from tilelang import tvm as tvm import tilelang.testing import tilelang as tl import torch +import pytest def reshape_test(N, M, dtype): @@ -262,5 +263,25 @@ def test_reduce_after_reshape(): run_reduce_after_reshape(2048, 64, "float16") +def reshape_shape_mismatch_test(N, M, dtype): + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor((N,), dtype), + B: T.Tensor((N // M, M), dtype), + ): + with T.Kernel(1) as _: + A_reshaped = T.reshape(A, [N // M, M + 1]) + T.copy(A_reshaped, B) + + return main + + +def test_reshape_shape_mismatch(): + with pytest.raises(AssertionError): + reshape_shape_mismatch_test(1024, 32, "float32") + + if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_view.py b/testing/python/language/test_tilelang_language_view.py index c16c5185..a79d428b 100644 --- a/testing/python/language/test_tilelang_language_view.py +++ b/testing/python/language/test_tilelang_language_view.py @@ -1,6 +1,7 @@ from tilelang import tvm as tvm 
import tilelang.testing import tilelang as tl +import pytest def view_test(N, M, dtype, new_dtype=None): @@ -54,5 +55,35 @@ def test_reshape_view(): run_view(2048, 64, "float16", "float32") +def view_shape_mismatch_test(N, M, dtype, new_dtype=None): + import tilelang.language as T + + new_shape = [N // M, M + 1] + if new_dtype: + from tvm import DataType + dtype_src = DataType(dtype) + dtype_dst = DataType(new_dtype) + src_bits = dtype_src.bits + dst_bits = dtype_dst.bits + scale = src_bits / dst_bits + new_shape[-1] = int(M * scale) + + @T.prim_func + def main( + A: T.Tensor((N,), dtype), + B: T.Tensor(new_shape, new_dtype if new_dtype else dtype), + ): + with T.Kernel(1) as _: + A_viewed = T.view(A, new_shape, dtype=new_dtype) + T.copy(A_viewed, B) + + return main + + +def test_view_shape_mismatch(): + with pytest.raises(AssertionError): + view_shape_mismatch_test(1024, 32, "float32") + + if __name__ == "__main__": tilelang.testing.main() diff --git a/tilelang/language/customize.py b/tilelang/language/customize.py index 9175bdb8..3d40ce47 100644 --- a/tilelang/language/customize.py +++ b/tilelang/language/customize.py @@ -2,6 +2,7 @@ from __future__ import annotations import tilelang.language as T from tvm.tir import PrimExpr, Buffer, op +from tilelang.utils.language import (bits_product, prim_expr_equal) from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401 @@ -45,19 +46,22 @@ def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer: Returns: Buffer: A new buffer view with the specified shape """ + assert prim_expr_equal(bits_product(shape, src.dtype), + bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed." return T.Tensor(shape, src.dtype, src.data) def view(src: Buffer, shape: list[PrimExpr] | None = None, dtype: str | None = None) -> Buffer: - """ - Return a Tensor view of the input buffer with an optional new shape and dtype. 
+ """Return a Tensor view of the input buffer with an optional new shape and dtype. - If `shape` is None the source buffer's shape is used; if `dtype` is None the source buffer's dtype is used. The returned buffer shares the same underlying data as `src` (no copy). - """ + If `shape` is None the source buffer's shape is used; if `dtype` is None the source buffer's dtype is used. The returned buffer shares the same underlying data as `src` (no copy). + """ if shape is None: shape = src.shape if dtype is None: dtype = src.dtype + assert prim_expr_equal(bits_product(shape, dtype), + bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed." return T.Tensor(shape, dtype, src.data) diff --git a/tilelang/utils/language.py b/tilelang/utils/language.py index de180745..e9fe13da 100644 --- a/tilelang/utils/language.py +++ b/tilelang/utils/language.py @@ -1,7 +1,7 @@ from __future__ import annotations from tvm.tir import Buffer, BufferLoad, BufferRegion, PrimExpr from functools import reduce -from tvm import IRModule +from tvm import IRModule, DataType from tvm.tir import PrimFunc from tvm import ir, tir @@ -349,6 +349,17 @@ def retrieve_offset(obj: Buffer | BufferRegion | BufferLoad) -> list: raise ValueError(f"Unsupported retrieve_offset argument type: {type(obj)} for object {obj}") +def bits_product(shape: list[PrimExpr], dtype: str) -> PrimExpr: + """ + Compute the number of bits in a Buffer (shape with dtype).""" + if len(shape) == 0: + return tir.IntImm("int32", 1) + result = shape[0] + for i in range(1, len(shape)): + result = result * shape[i] + return result * DataType(dtype).bits + + def prim_expr_equal(lhs, rhs) -> bool: """ Robust equality for PrimExpr shapes/extents. 
-- GitLab From 74da369695068da9ddef76dc807792abcea0f6fa Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Tue, 18 Nov 2025 23:50:57 +0800 Subject: [PATCH 018/139] [FFI] Use tvm ffi as the default execution backend (#1259) * [Refactor] Update FFI type handling and simplify argument management * Refactored FFI type definitions in runtime and code generation files to use `TVMFFIAny` instead of `TVMValue`, enhancing type clarity. * Updated function registration in `runtime.cc` to utilize canonical names for better consistency. * Simplified argument handling in the `simplify` transformation, ensuring unused buffer parameters are removed only when simplification is enabled. * Adjusted autotuner and profiler parameters to standardize the execution backend to `tvm_ffi`, improving clarity in backend selection. * Removed obsolete `adapt_torch2tvm` function from tensor utilities to streamline the codebase and reduce complexity. * [Update] Sync TVM submodule and enhance kernel source handling * Updated the TVM submodule to commit cdc2aced, ensuring compatibility with recent changes. * Added functionality to print kernel source in `example_blocksparse_gemm.py` for better debugging. * Commented out the main execution call in test files to prevent unintended execution during testing. * Introduced `tilelang.disable_cache()` in various test files to streamline testing and avoid cache-related issues. * Refactored kernel source retrieval methods to improve clarity and consistency across different execution backends. * [Refactor] Clean up imports and improve code formatting * Removed unused import of `tilelang.testing` in `test_example_blocksparse_gemm.py` to streamline the code. * Reformatted several lines in `arg_binder.cc`, `make_packed_api.cc`, `tvm_ffi.py`, and `adapter.py` for improved readability and consistency. * Updated comments and spacing in `tvm_ffi.py` to enhance clarity without altering functionality. 
* Update execution backend options and improve resolution logic - Changed default execution backend from "cython" to "auto" in multiple locations to allow automatic selection based on the target. - Expanded the list of supported execution backends to include "torch" and "nvrtc" across various classes and functions. - Enhanced backend resolution logic in `KernelCache` and `AutoTuner` to ensure appropriate backend selection based on the target. - Updated documentation to reflect changes in execution backend options and their defaults. * lint fix * fix * Enhance argument handling in CUDA and HIP runtime modules - Updated `ExtractFuncInfo` in `rt_mod_cuda.cc` and `rt_mod_hip.cc` to map boolean argument types to int32, ensuring compatibility with device runtime. - Refactored `BindDLTensor` in `arg_binder.cc` to improve null handling and validation checks for DLTensor parameters, utilizing expression-level guards to prevent dereferencing null pointers. - Enhanced error checking for buffer shape, strides, and data fields, ensuring robust handling of optional inputs and maintaining consistency across various checks. * lint fix * lint fix * lint fix * lint fix * minor fix * fix * recover check * Refactor argument binding and validation in `arg_binder.cc` - Improved null handling and validation checks in `BindDLTensor`, ensuring safe dereferencing of pointers. - Enhanced consistency checks for buffer shape, strides, and data fields, utilizing expression-level guards. - Updated `MakePackedAPI` to maintain code clarity and consistency in argument handling. - Minor adjustments in test files to streamline kernel execution and improve readability. * lint fix * stride fix * minor fix * fix * lint fix * lint fix * Add CUDA stream access policy window helpers and integrate with L2 persistent cache management - Introduced functions to set and reset the CUDA stream access policy window, allowing for better control over L2 cache usage. 
- Updated runtime files to include new FFI packed functions for managing stream attributes. - Modified lower_hopper_intrin to incorporate prologue and epilogue statements for L2 cache setup and teardown. - Enhanced tests to verify the inclusion of new FFI calls in the generated kernel source. * check with symbolic * support null ptr * Update CMakeLists and lower.py for code generation and subproject status - Added `codegen_c_host.cc` to the list of source files in CMakeLists.txt for improved code generation support. - Updated the function call in `lower.py` to use `target.build.tilelang_c` for C target host code generation, enhancing compatibility. - Marked the TVM subproject as dirty to indicate local modifications. * lint fix * Update comments for clarity in quickstart.py --- 3rdparty/tvm | 2 +- CMakeLists.txt | 1 + .../example_blocksparse_gemm.py | 1 - examples/gdn/example_chunk_o_bwd.py | 1 - examples/gdn/test_example_gdn_compilation.py | 1 + examples/quickstart.py | 5 +- pyproject.toml | 1 + src/runtime/runtime.cc | 172 ++++- src/runtime/runtime.h | 8 +- src/target/codegen_c_host.cc | 556 +++++++++++++++++ src/target/codegen_c_host.h | 124 ++++ src/target/codegen_cpp.cc | 8 +- src/target/rt_mod_cuda.cc | 6 +- src/target/rt_mod_hip.cc | 6 +- src/transform/arg_binder.cc | 384 +++++++++--- src/transform/arg_binder.h | 4 + src/transform/lower_hopper_intrin.cc | 64 +- src/transform/make_packed_api.cc | 293 ++++----- src/transform/simplify.cc | 57 +- .../python/debug/test_tilelang_debug_print.py | 2 +- .../dynamic/test_tilelang_dynamic_symbolic.py | 3 +- .../jit/test_tilelang_jit_gemm_ctypes.py | 411 ------------ .../python/jit/test_tilelang_jit_nullptr.py | 13 +- .../python/jit/test_tilelang_jit_tvm_ffi.py | 589 ++++++++++++++++++ .../language/test_tilelang_language_alloc.py | 4 +- tilelang/autotuner/param.py | 6 +- tilelang/autotuner/tuner.py | 21 +- tilelang/cache/__init__.py | 3 +- tilelang/cache/kernel_cache.py | 145 +++-- tilelang/contrib/dlpack.py | 20 - 
tilelang/engine/lower.py | 2 +- tilelang/jit/__init__.py | 45 +- tilelang/jit/adapter/__init__.py | 2 +- tilelang/jit/adapter/base.py | 48 +- tilelang/jit/adapter/ctypes/adapter.py | 25 +- tilelang/jit/adapter/cython/adapter.py | 26 +- tilelang/jit/adapter/dlpack.py | 40 -- tilelang/jit/adapter/nvrtc/adapter.py | 21 +- tilelang/jit/adapter/tvm_ffi.py | 321 ++++++++++ tilelang/jit/execution_backend.py | 100 +++ tilelang/jit/kernel.py | 85 ++- tilelang/profiler/__init__.py | 4 +- tilelang/utils/tensor.py | 19 - 43 files changed, 2721 insertions(+), 928 deletions(-) create mode 100644 src/target/codegen_c_host.cc create mode 100644 src/target/codegen_c_host.h delete mode 100644 testing/python/jit/test_tilelang_jit_gemm_ctypes.py create mode 100644 testing/python/jit/test_tilelang_jit_tvm_ffi.py delete mode 100644 tilelang/jit/adapter/dlpack.py create mode 100644 tilelang/jit/adapter/tvm_ffi.py create mode 100644 tilelang/jit/execution_backend.py diff --git a/3rdparty/tvm b/3rdparty/tvm index 093b2cdb..f4105f89 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 093b2cdb2187140b197336496d65d61ace89e8ff +Subproject commit f4105f89a646622acc9818584d1d91e2ca3f533d diff --git a/CMakeLists.txt b/CMakeLists.txt index 72e1d979..f784f11f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -138,6 +138,7 @@ file(GLOB TILE_LANG_SRCS src/transform/*.cc src/op/*.cc src/target/utils.cc + src/target/codegen_c_host.cc src/target/codegen_cpp.cc src/target/rt_mod_cpp.cc # intrin_rule doesn't have system dependency diff --git a/examples/blocksparse_gemm/example_blocksparse_gemm.py b/examples/blocksparse_gemm/example_blocksparse_gemm.py index 7b9cff7c..8cd3a821 100644 --- a/examples/blocksparse_gemm/example_blocksparse_gemm.py +++ b/examples/blocksparse_gemm/example_blocksparse_gemm.py @@ -166,7 +166,6 @@ def main(): enable_rasteration=DEFAULT_ENABLE_RASTERIZATION) block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K print(f"Using default 
kernel with block size ({block_M}, {block_N}, {block_K})") - # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) block_mask = torch.rand(mask_shape).cuda() > sparsity diff --git a/examples/gdn/example_chunk_o_bwd.py b/examples/gdn/example_chunk_o_bwd.py index ff4d3f7a..20aa8414 100644 --- a/examples/gdn/example_chunk_o_bwd.py +++ b/examples/gdn/example_chunk_o_bwd.py @@ -468,7 +468,6 @@ def run_test( kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype, chunk_size, scale, use_g, use_dw, block_DK, block_DV, threads, num_stages) - print(kernel.get_kernel_source()) dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) if use_g: diff --git a/examples/gdn/test_example_gdn_compilation.py b/examples/gdn/test_example_gdn_compilation.py index e184dbca..75a62171 100644 --- a/examples/gdn/test_example_gdn_compilation.py +++ b/examples/gdn/test_example_gdn_compilation.py @@ -117,6 +117,7 @@ def test_example_chunk_o_bwd_compilation(): kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype, chunk_size, 1.0, use_g, True, block_DK, block_DV, threads, num_stages) + dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) # noqa: F841 if use_g: diff --git a/examples/quickstart.py b/examples/quickstart.py index 42514ee3..46a39e0d 100644 --- a/examples/quickstart.py +++ b/examples/quickstart.py @@ -55,10 +55,9 @@ block_M = 128 block_N = 128 block_K = 32 -# 1. Define the kernel (matmul) and compile/lower it into an executable module +# Define the kernel (matmul) and compile/lower it into an executable module matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K) - -# 3. 
Test the kernel in Python with PyTorch data +# Test the kernel in Python with PyTorch data import torch # Create random input tensors on the GPU diff --git a/pyproject.toml b/pyproject.toml index 8c417d56..706cd529 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,6 +104,7 @@ tilelang = "tilelang" # TVM "tilelang/3rdparty/tvm/src" = "3rdparty/tvm/src" "tilelang/3rdparty/tvm/python" = "3rdparty/tvm/python" +"tilelang/3rdparty/tvm/include" = "3rdparty/tvm/include" "tilelang/3rdparty/tvm/version.py" = "3rdparty/tvm/version.py" # CUTLASS "tilelang/3rdparty/cutlass/include" = "3rdparty/cutlass/include" diff --git a/src/runtime/runtime.cc b/src/runtime/runtime.cc index a00786e2..b2a7127d 100644 --- a/src/runtime/runtime.cc +++ b/src/runtime/runtime.cc @@ -13,6 +13,12 @@ namespace tvm { namespace tl { +#if 1 +// Thread-local storage for restoring the L2 persisting cache limit +static thread_local size_t __tl_prev_persisting_l2_cache_size = 0; +static thread_local bool __tl_prev_persisting_l2_cache_saved = false; +#endif + #if (CUDA_MAJOR_VERSION >= 12) template static std::string ArrayToStr(const T *ptr, size_t n) { std::stringstream ss; @@ -91,19 +97,21 @@ struct TensorMapArgs { // set device api TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("tvm_tensormap_create_tiled", [](PackedArgs args, - Any *ret) { - TensorMapArgs T = TensorMapArgs::Extract(args); - CUresult result = cuTensorMapEncodeTiled( - T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, - T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave, T.swizzle, - T.l2Promotion, T.oobFill); - if (result != CUDA_SUCCESS) { - LOG_FATAL << "Failed to initialize the TMA descriptor " << result << '\n' - << T.ToDebugString(); - } - *ret = static_cast(result); - }); + // Register using the canonical names defined in runtime.h + refl::GlobalDef().def_packed( + tl::tvm_tensormap_create_tiled, [](PackedArgs args, Any *ret) { + TensorMapArgs T = 
TensorMapArgs::Extract(args); + CUresult result = cuTensorMapEncodeTiled( + T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, + T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave, + T.swizzle, T.l2Promotion, T.oobFill); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to initialize the TMA descriptor " << result + << '\n' + << T.ToDebugString(); + } + *ret = static_cast(result); + }); } struct TensorMapIm2ColArgs { @@ -183,7 +191,7 @@ struct TensorMapIm2ColArgs { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( - "tvm_tensormap_create_im2col", [](PackedArgs args, Any *ret) { + tl::tvm_tensormap_create_im2col, [](PackedArgs args, Any *ret) { TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args); CUresult result = cuTensorMapEncodeIm2col( T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, @@ -201,5 +209,141 @@ TVM_FFI_STATIC_INIT_BLOCK() { #endif // (CUDA_MAJOR_VERSION >= 12) +// +// CUDA L2 Persisting Cache Access Policy Window helpers. +// Exposed as TVM FFI packed functions similar to TMA initialization. 
+// +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + // Set stream access policy window and adjust persisting L2 cache size + // Args: + // [0]: void* base_ptr (required) + // [1]: int64 num_bytes (required) + // [2]: float hit_ratio (optional, default 0.8) + // [3]: void* stream (optional, default 0 => default stream) + // [4]: int64 l2_limit_bytes (optional, default = num_bytes) + refl::GlobalDef().def_packed( + tl::tvm_cuda_stream_set_access_policy_window, + [](PackedArgs args, Any *ret) { + ICHECK(args.size() >= 2) << "Expected at least base_ptr and num_bytes"; + + void *base_ptr = args[0].cast(); + size_t num_bytes = static_cast(args[1].cast()); + float hit_ratio = 0.8f; + if (args.size() >= 3) { + // Accept double/float + hit_ratio = static_cast(args[2].cast()); + } + CUstream stream = nullptr; + if (args.size() >= 4) { + stream = reinterpret_cast(args[3].cast()); + } + size_t l2_limit_bytes = num_bytes; + if (args.size() >= 5) { + l2_limit_bytes = static_cast(args[4].cast()); + } + + // Clamp requested limit to device capability + CUdevice device; + CUresult result = cuCtxGetDevice(&device); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to get current CUDA device: " << result; + } + int max_persisting = 0; + result = cuDeviceGetAttribute( + &max_persisting, CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE, + device); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to query MAX_PERSISTING_L2_CACHE_SIZE: " + << result; + } + if (max_persisting > 0 && + l2_limit_bytes > static_cast(max_persisting)) { + l2_limit_bytes = static_cast(max_persisting); + } + + // Save current limit to restore later + size_t init_persisting_l2_cache_size = 0; + result = cuCtxGetLimit(&init_persisting_l2_cache_size, + CU_LIMIT_PERSISTING_L2_CACHE_SIZE); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to get current persisting L2 cache size limit: " + << result; + } + __tl_prev_persisting_l2_cache_size = init_persisting_l2_cache_size; + 
__tl_prev_persisting_l2_cache_saved = true; + + // Set new limit + result = + cuCtxSetLimit(CU_LIMIT_PERSISTING_L2_CACHE_SIZE, l2_limit_bytes); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to set persisting L2 cache size limit: " + << result; + } + + // Apply access policy window to stream + CUstreamAttrValue stream_attribute; + memset(&stream_attribute, 0, sizeof(stream_attribute)); + stream_attribute.accessPolicyWindow.base_ptr = base_ptr; + stream_attribute.accessPolicyWindow.num_bytes = l2_limit_bytes; + stream_attribute.accessPolicyWindow.hitRatio = hit_ratio; + stream_attribute.accessPolicyWindow.hitProp = + CU_ACCESS_PROPERTY_PERSISTING; + stream_attribute.accessPolicyWindow.missProp = + CU_ACCESS_PROPERTY_STREAMING; + + result = cuStreamSetAttribute(stream, + CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW, + &stream_attribute); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to set stream access policy window: " << result; + } + + *ret = static_cast(result); + }); + + // Reset stream access policy window and restore the previous L2 cache size + // Args: + // [0]: void* stream (optional, default 0) + refl::GlobalDef().def_packed( + tl::tvm_cuda_stream_reset_access_policy_window, + [](PackedArgs args, Any *ret) { + CUstream stream = nullptr; + if (args.size() >= 1) { + stream = reinterpret_cast(args[0].cast()); + } + + CUstreamAttrValue stream_attribute; + memset(&stream_attribute, 0, sizeof(stream_attribute)); + // num_bytes = 0 disables the access policy window on the stream + stream_attribute.accessPolicyWindow.num_bytes = 0; + + CUresult result = cuStreamSetAttribute( + stream, CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW, + &stream_attribute); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to reset stream access policy window: " + << result; + } + + result = cuCtxResetPersistingL2Cache(); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to reset persisting L2 cache lines: " << result; + } + + if (__tl_prev_persisting_l2_cache_saved) 
{ + result = cuCtxSetLimit(CU_LIMIT_PERSISTING_L2_CACHE_SIZE, + __tl_prev_persisting_l2_cache_size); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to restore persisting L2 cache size limit: " + << result; + } + __tl_prev_persisting_l2_cache_saved = false; + } + + *ret = static_cast(result); + }); +} + } // namespace tl } // namespace tvm diff --git a/src/runtime/runtime.h b/src/runtime/runtime.h index fb9dfcfd..4b389fc0 100644 --- a/src/runtime/runtime.h +++ b/src/runtime/runtime.h @@ -16,7 +16,13 @@ constexpr const char *tvm_tensormap_create_tiled = constexpr const char *tvm_tensormap_create_im2col = "__tvm_tensormap_create_im2col"; #endif // (CUDA_MAJOR_VERSION >= 12) + +// CUDA stream access policy window helpers +constexpr const char *tvm_cuda_stream_set_access_policy_window = + "__tvm_cuda_stream_set_access_policy_window"; +constexpr const char *tvm_cuda_stream_reset_access_policy_window = + "__tvm_cuda_stream_reset_access_policy_window"; } // namespace tl } // namespace tvm -#endif // TVM_TL_RUNTIME_RUNTIME_H_ \ No newline at end of file +#endif // TVM_TL_RUNTIME_RUNTIME_H_ diff --git a/src/target/codegen_c_host.cc b/src/target/codegen_c_host.cc new file mode 100644 index 00000000..b5e74b0a --- /dev/null +++ b/src/target/codegen_c_host.cc @@ -0,0 +1,556 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file codegen_c_host.cc + */ +#include "codegen_c_host.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +// For escaping strings embedded into generated C sources +#include "support/str_escape.h" + +namespace tvm { +namespace tl { + +CodeGenCHost::CodeGenCHost() { + module_name_ = name_supply_->FreshName(tvm::ffi::symbol::tvm_ffi_library_ctx); +} + +void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, + bool emit_fwd_func_decl, std::string target_str, + const std::unordered_set &devices) { + emit_asserts_ = emit_asserts; + emit_fwd_func_decl_ = emit_fwd_func_decl; + declared_globals_.clear(); + decl_stream << "// tilelang target: " << target_str << "\n"; + decl_stream << "#define TVM_EXPORTS\n"; + decl_stream << "#include \"tvm/runtime/base.h\"\n"; + decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n"; + decl_stream << "#include \"tvm/ffi/c_api.h\"\n"; + decl_stream << "#include \n"; + // snprintf for richer assert messages with actual values + decl_stream << "#include \n"; + decl_stream << "#include \n"; + CodeGenCHost::InitGlobalContext(); + tvm::codegen::CodeGenC::Init(output_ssa); +} + +void CodeGenCHost::InitGlobalContext() { + decl_stream << "void* " << tvm::ffi::symbol::tvm_ffi_library_ctx + << " = NULL;\n"; +} + +void CodeGenCHost::DefineModuleName() { + decl_stream << "void* " << module_name_ << " = NULL;\n"; +} + +void CodeGenCHost::AddFunction(const tvm::GlobalVar &gvar, + const tvm::tir::PrimFunc &func) { + return AddFunction(gvar, func, /*emit_fwd_func_decl=*/false); +} + +void CodeGenCHost::AddFunction(const tvm::GlobalVar &gvar, + const tvm::tir::PrimFunc &func, + bool emit_fwd_func_decl) { + auto global_symbol = + func->GetAttr(tvm::attr::kGlobalSymbol); + if (global_symbol) { + function_names_.push_back(global_symbol.value()); + } + + emit_fwd_func_decl_ = 
emit_fwd_func_decl; + tvm::codegen::CodeGenC::AddFunction(gvar, func); + if (func->HasNonzeroAttr(tvm::tir::attr::kIsEntryFunc) && !has_main_func_) { + ICHECK(global_symbol.has_value()) + << "CodeGenCHost: The entry func must have the global_symbol " + "attribute, " + << "but function " << gvar << " only has attributes " << func->attrs; + function_names_.push_back(tvm::ffi::symbol::tvm_ffi_main); + stream << "// CodegenC: NOTE: Auto-generated entry function\n"; + PrintFuncPrefix(stream); + PrintType(func->ret_type, stream); + stream << " " << tvm::ffi::symbol::tvm_ffi_main + << "(void* self, void* args,int num_args, void* result) {\n"; + stream << " return " << static_cast(global_symbol.value()) + << "(self, args, num_args, result);\n"; + stream << "}\n"; + has_main_func_ = true; + } +} + +void CodeGenCHost::GenerateForwardFunctionDeclarations( + tvm::ffi::String global_symbol, const tvm::ffi::Array &arg_types, + const tvm::Type &ret_type) { + if (!emit_fwd_func_decl_) { + return; + } + for (auto &func_already_defined : GetFunctionNames()) { + if (global_symbol == func_already_defined) { + return; + } + } + this->PrintFuncPrefix(fwd_decl_stream); + this->PrintType(ret_type, fwd_decl_stream); + fwd_decl_stream << " " << global_symbol << "("; + for (size_t i = 0; i < arg_types.size(); ++i) { + if (i > 0) { + fwd_decl_stream << ", "; + } + tvm::codegen::CodeGenSourceBase::PrintType(arg_types[i], fwd_decl_stream); + } + fwd_decl_stream << ");\n"; +} + +void CodeGenCHost::PrintFuncPrefix(std::ostream &os) { // NOLINT(*) + os << "#ifdef __cplusplus\n" + << "extern \"C\"\n" + << "#endif\n"; +} + +void CodeGenCHost::PrintType(tvm::DataType t, std::ostream &os) { // NOLINT(*) + int lanes = t.lanes(); + if (t.is_handle()) { + ICHECK_EQ(lanes, 1) << "does not support vector types"; + os << "void*"; + return; + } + if (t.is_void()) { + os << "void"; + return; + } + if (t == tvm::DataType::Bool()) { + os << "bool"; + return; + } + bool fail = false; + if (t.is_float()) { + 
switch (t.bits()) { + case 16: + os << "half"; + break; + case 32: + os << "float"; + break; + case 64: + os << "double"; + break; + default: + fail = true; + break; + } + if (!fail && lanes == 1) + return; + if (!fail && (lanes >= 2 && lanes <= 16)) { + os << lanes; + return; + } + } + if (t.is_bfloat16()) { + os << "__bf16"; + return; + } + if (t.is_int() || t.is_uint()) { + if (t.is_uint()) { + os << 'u'; + } + switch (t.bits()) { + case 8: + os << "int8_t"; + break; + case 16: + os << "int16_t"; + break; + case 32: + os << "int32_t"; + break; + case 64: + os << "int64_t"; + break; + case 1: + os << "int32_t"; + break; + default: + fail = true; + break; + } + if (!fail && lanes == 1) + return; + if (!fail && (lanes >= 2 && lanes <= 16)) { + os << lanes; + return; + } + } + LOG(FATAL) << "Cannot convert type " << t << " to C type"; +} + +void CodeGenCHost::VisitExpr_(const tvm::tir::BroadcastNode *op, + std::ostream &os) { // NOLINT(*) + std::string v = PrintExpr(op->value); + int lanes = op->dtype.lanes(); + os << "(("; + PrintType(op->dtype, os); + os << ")("; + for (int i = 0; i < lanes; ++i) { + if (i != 0) + os << ", "; + os << v; + } + os << "))"; +} + +void CodeGenCHost::PrintGetFuncFromBackend( + const std::string &func_name, const std::string &packed_func_name) { + this->PrintIndent(); + this->stream << "if (" << packed_func_name << " == NULL) {\n"; + int packed_func_if_scope = this->BeginScope(); + this->PrintIndent(); + this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_ << ", \"" + << func_name << "\"" + << ", &" << packed_func_name << ") != 0) {\n"; + int get_func_env_scope = this->BeginScope(); + this->PrintIndent(); + this->stream << "return -1;\n"; + this->EndScope(get_func_env_scope); + this->PrintIndent(); + this->stream << "}\n"; + this->EndScope(packed_func_if_scope); + this->PrintIndent(); + this->stream << "}\n"; +} + +void CodeGenCHost::PrintCallPacked(const tvm::tir::CallNode *op) { + using namespace tvm::tir; + const 
StringImmNode *func_name = op->args[0].as(); + ICHECK(func_name != nullptr) + << "tvm_call_[c]packed_lowered expects first argument as function name"; + int64_t begin = op->args[2].as()->value; + int64_t end = op->args[3].as()->value; + int64_t num_args = end - begin; + ICHECK_GE(num_args, 0); + + std::string packed_func_name; + if (op->op.same_as(builtin::tvm_call_packed_lowered())) { + packed_func_name = GetPackedName(op); + this->PrintGetFuncFromBackend(func_name->value, packed_func_name); + } else { + // directly use the original symbol + ICHECK(op->op.same_as(builtin::tvm_call_cpacked_lowered())); + packed_func_name = + tvm::ffi::symbol::tvm_ffi_symbol_prefix + func_name->value; + } + + std::string args_stack = PrintExpr(op->args[1]); + this->PrintIndent(); + std::string result = name_supply_->FreshName("result"); + this->stream << "TVMFFIAny " << result << ";\n"; + this->PrintIndent(); + // must make sure type_index is set to none + this->stream << result << ".type_index = kTVMFFINone;\n"; + this->PrintIndent(); + this->stream << result << ".zero_padding = 0;\n"; + this->PrintIndent(); + this->stream << result << ".v_int64 = 0;\n"; + this->PrintIndent(); + if (op->op.same_as(builtin::tvm_call_packed_lowered())) { + this->stream << "if (TVMFFIFunctionCall(" << packed_func_name << ", "; + } else { + this->stream << "if (" << packed_func_name << "(NULL, "; + } + this->stream << "(TVMFFIAny*) " << args_stack << ", " << num_args << ", " + << "&" << result << ") != 0) {\n"; + int func_call_scope = this->BeginScope(); + this->PrintIndent(); + this->stream << "return -1;\n"; + this->EndScope(func_call_scope); + this->PrintIndent(); + this->stream << "}\n"; +} + +std::string CodeGenCHost::GetPackedName(const tvm::tir::CallNode *op) { + using namespace tvm::tir; + const StringImmNode *s = op->args[0].as(); + ICHECK(s != nullptr) + << "tvm_call_packed_lowered expects first argument as function name"; + std::string func_name = s->value; + std::string packed_func_name = 
func_name + "_packed"; + std::string unique_name; + auto it = declared_globals_.find(packed_func_name); + if (it != declared_globals_.end()) { + unique_name = it->second; + } else { + unique_name = name_supply_->FreshName(packed_func_name); + declared_globals_[packed_func_name] = unique_name; + decl_stream << "static void* " << unique_name << " = NULL;\n"; + } + return unique_name; +} + +void CodeGenCHost::VisitExpr_(const tvm::tir::CallNode *op, + std::ostream &os) { // NOLINT(*) + using namespace tvm::tir; + if (op->op.same_as(builtin::tvm_stack_alloca())) { + std::string stack_name = name_supply_->FreshName("stack"); + const std::string &type = op->args[0].as()->value; + const IntImmNode *num = op->args[1].as(); + ICHECK(num != nullptr); + static_assert(alignof(TVMFFIAny) % alignof(DLTensor) == 0, "invariant"); + size_t unit = sizeof(TVMFFIAny); + size_t size = 0; + if (type == "shape") { + size = (num->value * sizeof(ffi::Shape::index_type) + unit - 1) / unit; + } else if (type == "tvm_ffi_any") { + size = (num->value * sizeof(TVMFFIAny) + unit - 1) / unit; + } else if (type == "array") { + size = (num->value * sizeof(DLTensor) + unit - 1) / unit; + } else { + LOG(FATAL) << "Unknown stack alloca type " << type; + } + this->PrintIndent(); + this->stream << "TVMFFIAny " << stack_name << "[" << size << "];\n"; + os << stack_name; + } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) { + this->PrintCallPacked(op); + } else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) { + this->PrintCallPacked(op); + } else if (op->op.same_as(builtin::tvm_throw_last_error())) { + this->PrintIndent(); + this->stream << "return -1;\n"; + } else { + tvm::codegen::CodeGenC::VisitExpr_(op, os); + } +} + +void CodeGenCHost::VisitStmt_(const tvm::tir::AssertStmtNode *op) { // NOLINT(*) + using namespace tvm::tir; + if (emit_asserts_) { + std::string cond = PrintExpr(op->condition); + PrintIndent(); + stream << "if (!(" << cond << ")) {\n"; + int assert_if_scope = 
this->BeginScope(); + { + // Prepare the base error message + const auto *msg_node = op->message.as(); + ICHECK(msg_node != nullptr) << "Assert message expected to be StringImm"; + const std::string &raw_msg = msg_node->value; + const std::string esc_msg = tvm::support::StrEscape( + raw_msg.c_str(), raw_msg.length(), /*use_octal_escape=*/true, + /*escape_whitespace_special_chars=*/true); + + // If the assertion condition contains any equality checks anywhere + // in a composite boolean expression, append the actual LHS/RHS values + // Collect all EQ nodes within the condition (including inside And/Or/Not) + std::vector eq_nodes; + { + std::vector stk; + stk.push_back(op->condition); + while (!stk.empty()) { + PrimExpr cur = stk.back(); + stk.pop_back(); + if (const auto *eq = cur.as()) { + eq_nodes.push_back(eq); + continue; + } + if (const auto *an = cur.as()) { + stk.push_back(an->a); + stk.push_back(an->b); + continue; + } + if (const auto *on = cur.as()) { + stk.push_back(on->a); + stk.push_back(on->b); + continue; + } + if (const auto *nn = cur.as()) { + stk.push_back(nn->a); + continue; + } + } + } + + if (!eq_nodes.empty()) { + // Build a single detailed message that includes all LHS/RHS pairs + PrintIndent(); + stream << "char __tvm_assert_msg_buf[1024];\n"; + PrintIndent(); + stream << "int __tvm_assert_msg_len = snprintf(__tvm_assert_msg_buf, " + "sizeof(__tvm_assert_msg_buf), \"%s\", \"" + << esc_msg << "\");\n"; + + auto escape_for_printf_literal = [&](const std::string &s) { + std::string out; + out.reserve(s.size()); + for (char c : s) { + if (c == '%') { + out += "%%"; + } else if (c == '"') { + out += "\\\""; + } else if (c == '\\') { + out += "\\\\"; + } else { + out.push_back(c); + } + } + return out; + }; + + for (const auto *eq : eq_nodes) { + std::string lhs = PrintExpr(eq->a); + std::string rhs = PrintExpr(eq->b); + std::string lhs_disp = escape_for_printf_literal(lhs); + std::string rhs_disp = escape_for_printf_literal(rhs); + PrintIndent(); 
+ stream << "__tvm_assert_msg_len += snprintf(__tvm_assert_msg_buf + " + "__tvm_assert_msg_len, " + "sizeof(__tvm_assert_msg_buf) - __tvm_assert_msg_len, \"; (" + << lhs_disp << " == " << rhs_disp + << ") got: %lld, expected: %lld\", (long long)(" << lhs + << "), (long long)(" << rhs << "));\n"; + } + PrintIndent(); + stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", " + "__tvm_assert_msg_buf);\n"; + } else { + // Fallback: just emit the base message + PrintIndent(); + stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \"" << esc_msg + << "\");\n"; + } + } + PrintIndent(); + stream << "return -1;\n"; + this->EndScope(assert_if_scope); + PrintIndent(); + stream << "}\n"; + } + this->PrintStmt(op->body); +} + +void CodeGenCHost::VisitExpr_(const tvm::tir::MinNode *op, + std::ostream &os) { // NOLINT(*) + PrintTernaryCondExpr(op, "<", os); +} + +void CodeGenCHost::VisitExpr_(const tvm::tir::MaxNode *op, + std::ostream &os) { // NOLINT(*) + PrintTernaryCondExpr(op, ">", os); +} + +template +inline void CodeGenCHost::PrintTernaryCondExpr(const T *op, const char *compare, + std::ostream &os) { // NOLINT(*) + std::ostringstream temp_a; + VisitExpr(op->a, temp_a); + std::string a_id = SSAGetID(temp_a.str(), op->a.dtype()); + std::ostringstream temp_b; + VisitExpr(op->b, temp_b); + std::string b_id = SSAGetID(temp_b.str(), op->b.dtype()); + + os << "((" << a_id << ") " << compare << " (" << b_id << ") " + << "? (" << a_id << ") : (" << b_id << "))"; +} + +} // namespace tl +} // namespace tvm + +namespace tvm { +namespace tl { + +using tvm::codegen::CodeGenSourceBase; +using tvm::codegen::CSourceModuleCreate; +using tvm::ffi::Array; +using tvm::ffi::Map; +using tvm::ffi::Module; +using tvm::ffi::String; + +// Build function that mirrors TVM's host C codegen, registered under a +// TileLang-specific name. 
+::tvm::ffi::Module BuildTileLangCHost(::tvm::IRModule mod, + ::tvm::Target target) { + bool output_ssa = false; + bool emit_asserts = true; + bool emit_fwd_func_decl = true; + + std::unordered_set devices; + if (mod->GetAttr<::tvm::ffi::Map<::tvm::GlobalVar, ::tvm::ffi::String>>( + "device_contexts") != nullptr) { + ::tvm::ffi::Map<::tvm::GlobalVar, ::tvm::ffi::String> device_contexts = + mod->GetAttr<::tvm::ffi::Map<::tvm::GlobalVar, ::tvm::ffi::String>>( + "device_contexts") + .value(); + for (auto const &context : device_contexts) { + devices.insert(context.second.data()); + } + } + + CodeGenCHost cg; + cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices); + cg.SetConstantsByteAlignment( + target->GetAttr<::tvm::Integer>("constants-byte-alignment").value_or(16)); + + auto is_aot_executor_fn = [](::tvm::tir::PrimFunc const &func) -> bool { + return func->GetAttr<::tvm::Bool>("runner_function", ::tvm::Bool(false)) + .value(); + }; + + std::vector> funcs; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance<::tvm::tir::PrimFuncNode>()) + << "CodegenCHost: Can only take PrimFunc"; + auto prim_func = ::tvm::Downcast<::tvm::tir::PrimFunc>(base_func); + funcs.push_back({gvar, prim_func}); + } + + auto sort_key = [&is_aot_executor_fn](const auto &kv) { + return std::tuple{is_aot_executor_fn(kv.second), kv.first->name_hint}; + }; + std::sort(funcs.begin(), funcs.end(), + [&sort_key](const auto &kv_a, const auto &kv_b) { + return sort_key(kv_a) < sort_key(kv_b); + }); + + for (const auto &[gvar, prim_func] : funcs) { + cg.DeclareFunction(gvar, prim_func); + } + + for (const auto &[gvar, prim_func] : funcs) { + cg.AddFunction(gvar, prim_func, emit_fwd_func_decl); + } + + std::string code = cg.Finish(); + return ::tvm::codegen::CSourceModuleCreate(code, "c", cg.GetFunctionNames()); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("target.build.tilelang_c", 
BuildTileLangCHost); +} + +} // namespace tl +} // namespace tvm diff --git a/src/target/codegen_c_host.h b/src/target/codegen_c_host.h new file mode 100644 index 00000000..8d54cb4a --- /dev/null +++ b/src/target/codegen_c_host.h @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file codegen_c_host.h + * \brief Generate C host code (TileLang copy). + */ +#ifndef TL_TARGET_SOURCE_CODEGEN_C_HOST_H_ +#define TL_TARGET_SOURCE_CODEGEN_C_HOST_H_ + +#include +#include +#include +#include +#include + +#include "target/source/codegen_c.h" +#include "tvm/target/codegen.h" +#include "tvm/tir/expr.h" + +namespace tvm { +namespace tl { + +// TileLang copy of TVM's CodeGenCHost, under the tl namespace. +// Inherits from tvm::codegen::CodeGenC. +class CodeGenCHost : public tvm::codegen::CodeGenC { +public: + CodeGenCHost(); + void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl, + std::string target_str, + const std::unordered_set &devices); + + void InitGlobalContext(); + + void AddFunction(const tvm::GlobalVar &gvar, + const tvm::tir::PrimFunc &f) override; + void AddFunction(const tvm::GlobalVar &gvar, const tvm::tir::PrimFunc &f, + bool emit_fwd_func_decl); + /*! 
+   * \brief Add functions from the (unordered) range to the current module in a
+   * deterministic order. This helps with debugging.
+   *
+   * \param functions A vector of unordered range of current module.
+   */
+  void AddFunctionsOrdered(
+      std::vector<std::pair<tvm::GlobalVar, tvm::tir::PrimFunc>> functions);
+  void DefineModuleName();
+
+  using tvm::codegen::CodeGenC::PrintType;
+  void PrintType(tvm::DataType t, std::ostream &os) final; // NOLINT(*)
+  void PrintFuncPrefix(std::ostream &os) final;            // NOLINT(*)
+
+  // overload visitor functions
+  void VisitExpr_(const tvm::tir::BroadcastNode *op,
+                  std::ostream &os) final; // NOLINT(*)
+  void VisitExpr_(const tvm::tir::CallNode *op,
+                  std::ostream &os) override; // NOLINT(*)
+  // overload min and max to use the ternary operator, so we don't rely on the
+  // standard library implementations
+  void VisitExpr_(const tvm::tir::MinNode *op,
+                  std::ostream &os) final; // NOLINT(*)
+  void VisitExpr_(const tvm::tir::MaxNode *op,
+                  std::ostream &os) final; // NOLINT(*)
+
+  void VisitStmt_(const tvm::tir::AssertStmtNode *op) final; // NOLINT(*)
+
+  void GenerateForwardFunctionDeclarations(
+      tvm::ffi::String global_symbol,
+      const tvm::ffi::Array<tvm::Type> &arg_types,
+      const tvm::Type &ret_type) override;
+  tvm::ffi::Array<tvm::ffi::String> GetFunctionNames() {
+    return function_names_;
+  }
+
+private:
+  std::string module_name_;
+  /* \brief mapping global packed func to the unique name */
+  std::unordered_map<std::string, std::string> declared_globals_;
+  /* \brief names of the functions declared in this module */
+  tvm::ffi::Array<tvm::ffi::String> function_names_;
+  /*! \brief whether to emit asserts in the resulting C code */
+  bool emit_asserts_;
+  /*! \brief whether to emit forward function declarations in the resulting C
+   * code */
+  bool emit_fwd_func_decl_;
+  /*! 
\brief whether to generate the entry function if encountered */
+  bool has_main_func_ = false;
+
+  std::string GetPackedName(const tvm::tir::CallNode *op);
+  void PrintGetFuncFromBackend(const std::string &func_name,
+                               const std::string &packed_func_name);
+  void PrintCallPacked(const tvm::tir::CallNode *op);
+  /*!
+   * \brief Print ternary conditional operator implementing binary `op`
+   * Forces the operands to be in SSA form.
+   * \param op binary operator being expressed
+   * \param compare string representation of comparison operator
+   * \param os stream reference to print into
+   */
+  template <typename T>
+  inline void PrintTernaryCondExpr(const T *op, const char *compare,
+                                   std::ostream &os); // NOLINT(*)
+};
+
+} // namespace tl
+} // namespace tvm
+
+#endif // TL_TARGET_SOURCE_CODEGEN_C_HOST_H_
diff --git a/src/target/codegen_cpp.cc b/src/target/codegen_cpp.cc
index 9accf530..975f9a48 100644
--- a/src/target/codegen_cpp.cc
+++ b/src/target/codegen_cpp.cc
@@ -203,12 +203,12 @@ void CodeGenTileLangCPP::PrintFuncCall(const std::string &packed_func_name,
   this->PrintIndent();
   std::string ret_val = name_supply_->FreshName("ret_val");
   std::string ret_type_code = name_supply_->FreshName("ret_type_code");
-  this->stream << "TVMValue " << ret_val << ";\n";
+  this->stream << "TVMFFIAny " << ret_val << ";\n";
   this->PrintIndent();
   this->stream << "int " << ret_type_code << ";\n";
   this->PrintIndent();
   this->stream << "if (TVMFuncCall(" << packed_func_name << ", "
-               << "(TVMValue*) stack_value"
+               << "(TVMFFIAny*) stack_value"
                << ", "
                << "(int*) stack_tcode"
                << ", " << num_args << ", "
@@ -228,13 +228,13 @@ void CodeGenTileLangCPP::PrintFuncCallC(
   this->PrintIndent();
   std::string ret_val = name_supply_->FreshName("ret_val");
   std::string ret_type_code = name_supply_->FreshName("ret_type_code");
-  this->stream << "TVMValue " << ret_val << ";\n";
+  this->stream << "TVMFFIAny " << ret_val << ";\n";
   this->PrintIndent();
   this->stream << "int " << ret_type_code << ";\n";
   this->PrintIndent();
this->stream << "if (" << packed_func_name << "( " - << "(TVMValue*) stack_value " + << "(TVMFFIAny*) stack_value " << ", " << "(int*) stack_tcode" << ", " << num_args << ", " diff --git a/src/target/rt_mod_cuda.cc b/src/target/rt_mod_cuda.cc index bb69170f..cbef0e64 100644 --- a/src/target/rt_mod_cuda.cc +++ b/src/target/rt_mod_cuda.cc @@ -24,7 +24,11 @@ ExtractFuncInfo(const IRModule &mod) { continue; } } - info.arg_types.push_back(f->params[i].dtype()); + DataType dtype = f->params[i].dtype(); + // Device runtime cannot directly take bool arguments, map to int32. + if (dtype.is_bool()) + dtype = DataType::Int(32); + info.arg_types.push_back(dtype); } if (auto opt = f->GetAttr>( tir::attr::kKernelLaunchParams)) { diff --git a/src/target/rt_mod_hip.cc b/src/target/rt_mod_hip.cc index 50991d63..1e5c689c 100644 --- a/src/target/rt_mod_hip.cc +++ b/src/target/rt_mod_hip.cc @@ -35,7 +35,11 @@ ExtractFuncInfo(const IRModule &mod) { continue; } } - info.arg_types.push_back(f->params[i].dtype()); + DataType dtype = f->params[i].dtype(); + // Device runtime cannot directly take bool arguments, map to int32. + if (dtype.is_bool()) + dtype = DataType::Int(32); + info.arg_types.push_back(dtype); } if (auto opt = f->GetAttr>( tir::attr::kKernelLaunchParams)) { diff --git a/src/transform/arg_binder.cc b/src/transform/arg_binder.cc index 7df6d0cc..6a0909b8 100644 --- a/src/transform/arg_binder.cc +++ b/src/transform/arg_binder.cc @@ -51,6 +51,43 @@ void BinderAddAssert(arith::Analyzer *ana, PrimExpr cond, } } +bool ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value, + const std::string &arg_name, bool with_lets, + const PrimExpr &nullable_guard) { + // Currently only used in BindDLTensor, nullable_guard is already a defined + // bool, so use it directly. 
+ auto MakeGuarded = [&](PrimExpr basic) -> PrimExpr { + // is_null || basic + return Or(nullable_guard, basic); + }; + + ICHECK_EQ(arg.dtype(), value.dtype()) << "arg " << arg << " value " << value; + if (const VarNode *v = arg.as()) { + auto it = def_map_->find(v); + if (it == def_map_->end()) { + // First time binding: identical behavior as Bind_ + Var v_arg = Downcast(arg); + defs_.emplace_back(v_arg); + if (with_lets) { + (*def_map_)[v] = arg; + init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0))); + } else { + (*def_map_)[v] = value; + } + return true; + } else { + // Second or later binding: add is_null short-circuit + PrimExpr cond = MakeGuarded(it->second == value); + BinderAddAssert(&analyzer_, cond, arg_name, &asserts_); + } + } else { + // For non-Var expressions, also add is_null short-circuit + PrimExpr cond = MakeGuarded(arg == value); + BinderAddAssert(&analyzer_, cond, arg_name, &asserts_); + } + return false; +} + bool ArgBinder::Bind_(const PrimExpr &arg, const PrimExpr &value, const std::string &arg_name, bool with_lets) { ICHECK_EQ(arg.dtype(), value.dtype()) << "arg " << arg << " value " << value; @@ -96,8 +133,30 @@ void ArgBinder::BindBuffer(const Buffer &arg, const Buffer &value, const std::string &arg_name, bool fuzzy_match) { ICHECK_EQ(arg.scope(), value.scope()) << "Argument " << arg_name << " Buffer bind scope mismatch"; - ICHECK_EQ(arg->dtype, value->dtype) - << "Argument " << arg_name << " Buffer bind data type mismatch"; + // Relax dtype check to allow FP8 E4M3 variants to bind together. + auto dtype_compatible = [](DataType expected, DataType provided) -> bool { + if (expected == provided) + return true; + // If expected is float8_e4m3, allow float8_e4m3fn/float8_e4m3fnuz as well. + if (expected.is_float8_e4m3()) { + return provided.is_float8_e4m3() || provided.is_float8_e4m3fn() || + provided.is_float8_e4m3fnuz(); + } + // If expected is float8_e5m2, allow float8_e5m2fnuz as well. 
+ if (expected.is_float8_e5m2()) { + return provided.is_float8_e5m2() || provided.is_float8_e5m2fnuz(); + } + // If expected is bool, allow binding from int8/uint8 with same lanes. + if (expected.is_bool()) { + bool is_i8 = provided.is_int() && provided.bits() == 8; + bool is_u8 = provided.is_uint() && provided.bits() == 8; + return (is_i8 || is_u8) && expected.lanes() == provided.lanes(); + } + return false; + }; + ICHECK(dtype_compatible(arg->dtype, value->dtype)) + << "Argument " << arg_name << " Buffer bind data type mismatch: expected " + << arg->dtype << ", got " << value->dtype; if (value->data_alignment % arg->data_alignment != 0) { LOG(WARNING) << "Trying to bind buffer to another one with lower alignment " "requirement " @@ -167,10 +226,15 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, const DataType tvm_ndim_type = DataType::Int(32); const Stmt nop = Evaluate(0); - init_nest_.emplace_back(AssertStmt( - !Call(DataType::Bool(), builtin::isnullptr(), {handle}), - StringImm(arg_name + " is expected to have non-NULL DLTensor* pointer"), - nop)); + // Allow NULL DLTensor* for optional inputs. When the handle is NULL, + // avoid dereferencing it by using expression-level conditionals and + // short-circuiting guards in asserts. Cache the null check in a Let-bound + // boolean so codegen does not repeat `(handle == NULL)` everywhere. + Var is_null_var(arg_name + "_is_null", DataType::Bool()); + init_nest_.emplace_back( + LetStmt(is_null_var, + Call(DataType::Bool(), builtin::isnullptr(), {handle}), nop)); + const PrimExpr &is_null = is_null_var; // dimension checks PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim); @@ -193,25 +257,91 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast(buffer->shape.size())); std::ostringstream ndim_err_msg; + // Note: We cannot embed runtime values into the message string. 
+ // Keep message human-friendly without printing TIR exprs. ndim_err_msg << arg_name << ".ndim is expected to equal " - << buffer->shape.size(); + << buffer->shape.size() << ", but got mismatched ndim"; auto msg = StringImm(ndim_err_msg.str()); - init_nest_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); + // Only check ndim when handle is non-NULL (using short-circuit OR) + v_ndim = tvm::if_then_else(Not(is_null), v_ndim, make_zero(tvm_ndim_type)); + init_nest_.emplace_back(AssertStmt(Or(is_null, a_ndim == v_ndim), msg, nop)); // type checks std::ostringstream type_err_msg; - type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype; - PrimExpr cond = - (TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode) == - IntImm(DataType::UInt(8), buffer->dtype.code()) && - TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits) == - IntImm(DataType::UInt(8), buffer->dtype.bits()) && - TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) == - IntImm(DataType::UInt(16), buffer->dtype.lanes())); + // Avoid dumping TIR expressions in error text; just state mismatch. + // Include expected dtype triplet for clarity. 
+ type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype + << ", but got incompatible dtype"; + // Guard all dtype field loads by `is_null` using if_then_else + PrimExpr v_type_code = tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode), + IntImm(DataType::UInt(8), buffer->dtype.code())); + PrimExpr v_type_bits = tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits), + IntImm(DataType::UInt(8), buffer->dtype.bits())); + PrimExpr v_type_lanes = tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes), + IntImm(DataType::UInt(16), buffer->dtype.lanes())); + PrimExpr expect_code = IntImm(DataType::UInt(8), buffer->dtype.code()); + PrimExpr expect_bits = IntImm(DataType::UInt(8), buffer->dtype.bits()); + PrimExpr expect_lanes = IntImm(DataType::UInt(16), buffer->dtype.lanes()); + + PrimExpr cond = (v_type_code == expect_code && v_type_bits == expect_bits && + v_type_lanes == expect_lanes); + + // Allow float8_e4m3 to match float8_e4m3fn/float8_e4m3fnuz at runtime. + if (buffer->dtype.is_float8_e4m3()) { + PrimExpr code_e4m3 = IntImm(DataType::UInt(8), DataType::kFloat8_e4m3); + PrimExpr code_e4m3fn = IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fn); + PrimExpr code_e4m3fnuz = + IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fnuz); + PrimExpr code_match = + (v_type_code == code_e4m3 || v_type_code == code_e4m3fn || + v_type_code == code_e4m3fnuz); + cond = cond || (code_match && v_type_bits == expect_bits && + v_type_lanes == expect_lanes); + } + // Allow float8_e5m2 to match float8_e5m2fnuz at runtime. 
+ if (buffer->dtype.is_float8_e5m2()) { + PrimExpr code_e5m2 = IntImm(DataType::UInt(8), DataType::kFloat8_e5m2); + PrimExpr code_e5m2fnuz = + IntImm(DataType::UInt(8), DataType::kFloat8_e5m2fnuz); + PrimExpr code_match = + (v_type_code == code_e5m2 || v_type_code == code_e5m2fnuz); + cond = cond || (code_match && v_type_bits == expect_bits && + v_type_lanes == expect_lanes); + } + // Allow bool to match int8/uint8 at runtime, and also kDLBool(code=6). + if (buffer->dtype.is_bool()) { + PrimExpr code_int = IntImm(DataType::UInt(8), DataType::kInt); + PrimExpr code_uint = IntImm(DataType::UInt(8), DataType::kUInt); + PrimExpr code_kdlbool = IntImm(DataType::UInt(8), 6); + PrimExpr bits8 = IntImm(DataType::UInt(8), 8); + PrimExpr bits1 = IntImm(DataType::UInt(8), 1); + PrimExpr lanes_ok = (v_type_lanes == expect_lanes); + PrimExpr int8_ok = + (v_type_code == code_int && v_type_bits == bits8 && lanes_ok); + PrimExpr uint8_ok = + (v_type_code == code_uint && v_type_bits == bits8 && lanes_ok); + // Some frontends may tag bool tensors as kDLBool(code=6), commonly with + // bits=8 or bits=1. + PrimExpr kdlbool8_ok = + (v_type_code == code_kdlbool && v_type_bits == bits8 && lanes_ok); + PrimExpr kdlbool1_ok = + (v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok); + // Also accept any dtype whose bitwidth=1, regardless of code, to be + // defensive. 
+ PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok); + cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok; + } if (!(buffer->dtype == DataType::Int(1) || buffer->dtype == DataType::Int(4) || buffer->dtype == DataType::UInt(4))) { auto type_msg = StringImm(type_err_msg.str()); - asserts_.emplace_back(AssertStmt(cond, type_msg, nop)); + // Only check dtype when handle is non-NULL (short-circuit) + asserts_.emplace_back(AssertStmt(Or(is_null, cond), type_msg, nop)); } // shape field @@ -220,32 +350,70 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, tvm_shape_type, shape_handle_name()); Var v_shape(shape_handle_name(), DataType::Handle()); def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0)); - init_nest_.emplace_back(LetStmt( - buf_shape->data, - TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop)); + // Use if_then_else for NULL guard on the shape pointer itself, avoiding + // dereferencing TVMStructGet(handle, kArrShape) when handle is NULL. + init_nest_.emplace_back( + LetStmt(buf_shape->data, + tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), + make_zero(DataType::Handle())), + nop)); init_nest_.emplace_back(DeclBuffer(buf_shape, nop)); + for (size_t k = 0; k < buffer->shape.size(); ++k) { + // These packed-bit dtype shapes were not bound in the original + // implementation, so we just use them as is. 
if (buffer->dtype == DataType::Int(4) || buffer->dtype == DataType::UInt(4) || buffer->dtype == DataType::Int(1)) { break; } - Bind_(buffer->shape[k], - cast(buffer->shape[k].dtype(), - BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})), - shape_element_name(k), true); + + // The "real" runtime shape value read from DLTensor + PrimExpr raw_shape_val = + cast(buffer->shape[k].dtype(), + BufferLoad(buf_shape, + {IntImm(DataType::Int(32), static_cast(k))})); + + // Bind to the value of the symbolic dimension (e.g., m) in TIR, with an + // is_null guard: + // handle is NULL → use 0, placeholder but no dereference + // handle non-NULL → actually read from DLTensor's shape array + PrimExpr bound_shape_val = tvm::if_then_else( + is_null, make_zero(buffer->shape[k].dtype()), raw_shape_val); + + // When first encountering a Var (e.g., m), this will generate: + // Let(m, bound_shape_val, ...) + // Constant dimensions will only generate consistency assertions. + BindNullable(buffer->shape[k], bound_shape_val, shape_element_name(k), true, + is_null); + + // Keep an explicit "consistency check": when non-NULL, the symbolic + // dimension must equal the DLTensor's shape. 
+ Stmt shape_check = AssertStmt( + Or(is_null, buffer->shape[k] == raw_shape_val), + StringImm(shape_element_name(k) + " mismatch with DLTensor shape"), + Evaluate(0)); + asserts_.emplace_back(shape_check); } + // strides field Buffer buf_strides = decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())}, tvm_shape_type, arg_name + ".strides"); def_handle_dtype_.Set(buf_strides->data, tir::TypeAnnotation(tvm_shape_type)); - init_nest_.emplace_back(LetStmt( - buf_strides->data, - TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop)); + init_nest_.emplace_back( + LetStmt(buf_strides->data, + tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), + make_zero(DataType::Handle())), + nop)); init_nest_.emplace_back(DeclBuffer(buf_strides, nop)); PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data}); + if (buffer->strides.empty()) { // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); @@ -253,13 +421,16 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, ffi::Array conds; for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; - PrimExpr svalue = - cast(stype, BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); + PrimExpr svalue = cast( + stype, BufferLoad(buf_strides, + {IntImm(DataType::Int(32), static_cast(k))})); conds.push_back(buffer->shape[k] == 1 || expect_stride == svalue); expect_stride = expect_stride * buffer->shape[k]; } std::ostringstream stride_err_msg; - stride_err_msg << stride_handle_name() << ": expected to be compact array"; + stride_err_msg + << stride_handle_name() + << ": expected to be compact array, but got non-compact strides"; if (!conds.empty()) { auto stride_msg = StringImm(stride_err_msg.str()); Stmt check = @@ -267,6 +438,7 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, Span span) { return logical_and(a, b, span); }, 
const_true(1), conds), stride_msg, Evaluate(0)); + // Only check when strides array is actually present at runtime check = IfThenElse(Not(v_strides_is_null), check); asserts_.emplace_back(SeqStmt({check, Evaluate(0)})); } @@ -277,13 +449,27 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, DataType stride_dtype = buffer->strides[k].dtype(); PrimExpr explicit_stride = cast(stride_dtype, - BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); + BufferLoad(buf_strides, + {IntImm(DataType::Int(32), static_cast(k))})); PrimExpr stride_from_shape_cast = cast(stride_dtype, stride_from_shape); - PrimExpr value = tvm::if_then_else( + + PrimExpr core_value = tvm::if_then_else( v_strides_is_null, stride_from_shape_cast, explicit_stride); - value = tvm::if_then_else(buffer->shape[k] == 1, make_zero(stride_dtype), - value); - Bind_(buffer->strides[k], value, stride_element_name(k), true); + core_value = tvm::if_then_else(buffer->shape[k] == 1, + make_zero(stride_dtype), core_value); + + // Bind like shape: define var when needed, and only assert when non-NULL + PrimExpr bound_stride_val = + tvm::if_then_else(is_null, make_zero(stride_dtype), core_value); + BindNullable(buffer->strides[k], bound_stride_val, stride_element_name(k), + true, is_null); + + Stmt stride_check = AssertStmt( + Or(is_null, buffer->strides[k] == core_value), + StringImm(stride_element_name(k) + " mismatch with DLTensor strides"), + Evaluate(0)); + asserts_.emplace_back(stride_check); + PrimExpr shape_extent = cast(stride_dtype, buffer->shape[k]); stride_from_shape = analyzer_.Simplify(stride_from_shape_cast * shape_extent); @@ -291,7 +477,7 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, } else { PrimExpr stride_from_shape = make_const(buffer->DefaultIndexType(), 1); - for (int k = buffer->strides.size() - 1; k >= 0; k--) { + for (int k = static_cast(buffer->strides.size()) - 1; k >= 0; --k) { DataType stride_dtype = 
buffer->strides[k].dtype(); PrimExpr explicit_stride = cast(stride_dtype, @@ -300,75 +486,127 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, stride_dtype, BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})); PrimExpr stride_from_shape_cast = cast(stride_dtype, stride_from_shape); - Bind_(buffer->strides[k], - tvm::if_then_else(v_strides_is_null, stride_from_shape_cast, - explicit_stride), - stride_element_name(k), true); + PrimExpr core_value = tvm::if_then_else( + v_strides_is_null, stride_from_shape_cast, explicit_stride); + + PrimExpr bound_stride_val = + tvm::if_then_else(is_null, make_zero(stride_dtype), core_value); + BindNullable(buffer->strides[k], bound_stride_val, stride_element_name(k), + true, is_null); + + Stmt stride_check = AssertStmt( + Or(is_null, buffer->strides[k] == core_value), + StringImm(stride_element_name(k) + " mismatch with DLTensor strides"), + Evaluate(0)); + asserts_.emplace_back(stride_check); stride_from_shape = analyzer_.Simplify(stride_from_shape_cast * shape_stride); } } + // Byte_offset field. int data_bytes = GetVectorBytes(buffer->dtype); if (const auto *const_offset = buffer->elem_offset.as()) { - Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes), - TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset), - arg_name + ".byte_offset", true); + // Constant elem_offset: only need consistency check, no need for additional + // Var binding. 
+ PrimExpr actual_byte_offset = tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset), + make_const(DataType::UInt(64), 0)); + PrimExpr expect_byte_offset = + make_const(DataType::UInt(64), const_offset->value * data_bytes); + Stmt byte_off_check = + AssertStmt(Or(is_null, expect_byte_offset == actual_byte_offset), + StringImm(arg_name + ".byte_offset mismatch"), nop); + asserts_.emplace_back(byte_off_check); } else { - if (Bind_(buffer->elem_offset, - cast(buffer->elem_offset.dtype(), - (TVMArrayGet(DataType::UInt(64), handle, - builtin::kArrByteOffset) / - make_const(DataType::UInt(64), data_bytes))), - arg_name + ".elem_offset", true)) { - if (buffer->offset_factor > 1) { - PrimExpr offset = buffer->elem_offset; - PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor); - PrimExpr zero = make_zero(offset.dtype()); - BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, - arg_name + ".elem_offset", &asserts_); - } + PrimExpr actual_byte_offset = tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset), + make_const(DataType::UInt(64), 0)); + PrimExpr expect_elem_off = + cast(buffer->elem_offset.dtype(), + (actual_byte_offset / make_const(DataType::UInt(64), data_bytes))); + + // Like shape/stride, do NULL-safe binding for elem_offset: + // handle is NULL → 0 + // handle non-NULL → actual_byte_offset / data_bytes + PrimExpr bound_elem_off = tvm::if_then_else( + is_null, make_zero(buffer->elem_offset.dtype()), expect_elem_off); + BindNullable(buffer->elem_offset, bound_elem_off, arg_name + ".elem_offset", + true, is_null); + + // Strict consistency check for non-NULL case + Stmt elem_off_check = + AssertStmt(Or(is_null, buffer->elem_offset == expect_elem_off), + StringImm(arg_name + ".elem_offset mismatch"), nop); + asserts_.emplace_back(elem_off_check); + + if (buffer->offset_factor > 1) { + PrimExpr offset = buffer->elem_offset; + PrimExpr factor = 
make_const(offset.dtype(), buffer->offset_factor); + PrimExpr zero = make_zero(offset.dtype()); + Stmt off_factor_check = + AssertStmt(Or(is_null, truncmod(offset, factor) == zero), + StringImm(arg_name + ".elem_offset factor mismatch"), nop); + asserts_.emplace_back(off_factor_check); } } + // device info. - Bind_(device_type, - TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceType), - arg_name + ".device_type", true); - Bind_(device_id, - TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId), - arg_name + ".device_id", true); + // Define device_id from handle when available (so later passes can use it) + PrimExpr actual_dev_type = tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceType), + make_zero(DataType::Int(32))); + PrimExpr actual_dev_id = tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId), + make_zero(DataType::Int(32))); + // Bind device_id to a safe expression (0 when NULL handle) + BindNullable(device_id, actual_dev_id, arg_name + ".device_id", true, + is_null); + // Check device_type consistency (device_id equality is implicitly ensured by + // binding above) + init_nest_.emplace_back( + AssertStmt(Or(is_null, device_type == actual_dev_type), + StringImm(arg_name + ".device_type mismatch"), nop)); // Data field. Because the validation of the data field may depend // on a dynamic size defined by the other DLTensor* parameters, this // field must be generated last. - if (Bind_(buffer->data, - TVMArrayGet(DataType::Handle(), handle, builtin::kArrData), - arg_name + ".data", true)) { + // Bind data pointer using expression-level guard to avoid deref on NULL. + { Var vptr(buffer->data); + PrimExpr data_ptr = tvm::if_then_else( + Not(is_null), + TVMArrayGet(DataType::Handle(), handle, builtin::kArrData), + make_zero(DataType::Handle())); + BindNullable(buffer->data, data_ptr, arg_name + ".data", true, is_null); // Check if the data pointer is NULL. 
This check is skipped for - // size-0 arrays, since CUDA provides a NULL pointer for size-zero - // allocations. + // size-0 arrays and also skipped when handle itself is NULL. auto alloc_size = [&]() -> PrimExpr { PrimExpr product = IntImm(buffer->DefaultIndexType(), 1); - for (const auto &dim : buffer->shape) { + for (const auto &dim : buffer->shape) product *= dim; - } return product; }(); asserts_.emplace_back(AssertStmt( - alloc_size == 0 || - !Call(DataType::Bool(), builtin::isnullptr(), {vptr}), - StringImm(arg_name + " is expected to have non-NULL data pointer"), + Or(is_null, (alloc_size == 0) || + !Call(DataType::Bool(), builtin::isnullptr(), {vptr})), + StringImm(arg_name + + " is expected to have non-NULL data pointer, but got NULL"), nop)); - def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype)); // mark alignment of external bufs init_nest_.emplace_back( AttrStmt(vptr, tir::attr::storage_alignment, IntImm(DataType::Int(32), buffer->data_alignment), nop)); + + def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype)); } } diff --git a/src/transform/arg_binder.h b/src/transform/arg_binder.h index d04e7e9b..cf9f8466 100644 --- a/src/transform/arg_binder.h +++ b/src/transform/arg_binder.h @@ -154,6 +154,10 @@ public: return def_handle_dtype_; } + bool BindNullable(const PrimExpr &arg, const PrimExpr &value, + const std::string &arg_name, bool with_lets, + const PrimExpr &nullable_guard); + private: // Internal bind function bool Bind_(const PrimExpr &arg, const PrimExpr &value, diff --git a/src/transform/lower_hopper_intrin.cc b/src/transform/lower_hopper_intrin.cc index b082a574..e9c848ac 100644 --- a/src/transform/lower_hopper_intrin.cc +++ b/src/transform/lower_hopper_intrin.cc @@ -26,10 +26,13 @@ public: LowerHopperIntrin substituter(disable_shuffle_elect); fptr->body = substituter.VisitStmt(f->body); Map> init_desc_arg_map; + // Collect prologue/epilogue statements for host-side setup/teardown + Array prologue_stmts; + Array 
epilogue_stmts; for (const auto &[call, var] : substituter.desc_map_) { // Should allocate 128 bytes for TensorMap on stack Call alloc_desc = Call(DataType::Handle(), builtin::tvm_stack_alloca(), - {StringImm("arg_value"), 16}); + {StringImm("tvm_ffi_any"), 16}); Array init_desc_args; if (call->op.same_as(create_tma_descriptor())) { init_desc_args.push_back(StringImm(tvm_tensormap_create_tiled)); @@ -44,11 +47,66 @@ public: // add to function attribute Call init_desc = Call(DataType::Handle(), builtin::tvm_call_packed(), init_desc_args); - fptr->body = - LetStmt(var, alloc_desc, SeqStmt({Evaluate(init_desc), fptr->body})); + // Accumulate TMA descriptor init into prologue + prologue_stmts.push_back(LetStmt(var, alloc_desc, Evaluate(init_desc))); init_desc_arg_map.Set(var, init_desc_args); } f = WithAttr(std::move(f), "tma_descriptor_args", init_desc_arg_map); + + // Additionally, if L2 persistent cache annotations were lowered earlier, + // materialize TVM FFI calls to set the stream access policy window. 
+ if (f->attrs.defined() && f->attrs->dict.count("l2_persistent_map")) { + auto l2_map = + f->GetAttr>>("l2_persistent_map"); + if (l2_map.defined()) { + // Build a lookup from buffer name to Buffer object + std::unordered_map name2buf; + for (const auto &kv : f->buffer_map) { + name2buf.emplace(kv.second->name, kv.second); + } + for (const auto &kv : l2_map.value()) { + const std::string buf_name = kv.first; + const Array &args = kv.second; + if (name2buf.count(buf_name) == 0) { + continue; + } + const Buffer &buf = name2buf.at(buf_name); + // Build base pointer expression (read access) + PrimExpr base_ptr = buf.access_ptr(1); + // Args packed: func_name, base_ptr, num_bytes, hit_ratio + Array packed_args; + packed_args.push_back( + StringImm(tvm_cuda_stream_set_access_policy_window)); + packed_args.push_back(base_ptr); + // size_in_bytes (args[1]) then hit_ratio (args[0]) + ICHECK_GE(args.size(), 2); + packed_args.push_back(args[1]); + packed_args.push_back(args[0]); + prologue_stmts.push_back(Evaluate(Call( + DataType::Int(32), builtin::tvm_call_packed(), packed_args))); + } + // Add a single epilogue call to reset the access policy window and + // restore L2 limit + Array reset_args; + reset_args.push_back( + StringImm(tvm_cuda_stream_reset_access_policy_window)); + epilogue_stmts.push_back(Evaluate( + Call(DataType::Int(32), builtin::tvm_call_packed(), reset_args))); + } + } + + // Stitch prologue statements before the original body + if (!prologue_stmts.empty()) { + // Chain the Let/Evaluate statements sequentially + Stmt seq = prologue_stmts.size() == 1 ? prologue_stmts[0] + : SeqStmt(prologue_stmts); + fptr->body = SeqStmt({seq, fptr->body}); + } + if (!epilogue_stmts.empty()) { + Stmt seq_end = epilogue_stmts.size() == 1 ? 
epilogue_stmts[0] + : SeqStmt(epilogue_stmts); + fptr->body = SeqStmt({fptr->body, seq_end}); + } return f; } diff --git a/src/transform/make_packed_api.cc b/src/transform/make_packed_api.cc index 545d2403..187a75dc 100644 --- a/src/transform/make_packed_api.cc +++ b/src/transform/make_packed_api.cc @@ -20,6 +20,7 @@ /*! * \file make_packed_api.cc Lower PrimFunc to use the packed function API. */ +#include #include #include #include @@ -32,6 +33,7 @@ #include #include +#include #include #include @@ -43,13 +45,11 @@ namespace tvm { namespace tl { using namespace tir; using namespace ffi; -static constexpr const char *kDeviceContextVar = "device_api_context"; namespace { class ReturnRewriter : public StmtMutator { public: - explicit ReturnRewriter(Var ret_var, Var ret_tcode) - : ret_var_(std::move(ret_var)), ret_tcode_(std::move(ret_tcode)) {} + explicit ReturnRewriter(Var ret_var) : ret_var_(ret_var) {} Stmt VisitStmt_(const ForNode *node) override { if (node->kind == ForKind::kParallel) @@ -79,8 +79,6 @@ private: struct ConvertedInfo { int type_index{-1}; PrimExpr expr; - Buffer dummy_val_buffer; - Buffer dummy_tcode_buffer; }; ConvertedInfo ConvertForFFI(const PrimExpr &val) { @@ -88,7 +86,11 @@ private: // convert val's data type to FFI data type, return type code DataType dtype = val.dtype(); - if (dtype.is_int() || dtype.is_uint()) { + if (dtype.is_bool()) { + info.type_index = ffi::TypeIndex::kTVMFFIBool; + info.expr = Cast(DataType::Int(64), val); + + } else if (dtype.is_int() || dtype.is_uint()) { info.type_index = ffi::TypeIndex::kTVMFFIInt; info.expr = Cast(DataType::Int(64), val); } else if (dtype.is_float()) { @@ -101,56 +103,39 @@ private: LOG(FATAL) << "data type " << dtype << " not supported yet"; } - // If multiple return locations have the same data type, use the - // same dummy buffer declaration. 
- auto it = dummy_val_buffer_map_.find(info.type_index); - if (it != dummy_val_buffer_map_.end()) { - info.dummy_val_buffer = it->second; - } else { - info.dummy_val_buffer = - Buffer(ret_var_, info.expr.dtype(), {1}, {1}, ConstInt32(0), - ret_var_->name_hint, 0, 0, kDefault); - dummy_val_buffer_map_[info.type_index] = info.dummy_val_buffer; - } - - // The type_index is always a 32-bit int, so we don't need to have a - // separate map. - if (!dummy_tcode_buffer_.defined()) { - dummy_tcode_buffer_ = - Buffer(ret_tcode_, DataType::Int(32), {1}, {1}, ConstInt32(0), - ret_tcode_->name_hint, 0, 0, kDefault); - } - info.dummy_tcode_buffer = dummy_tcode_buffer_; - return info; } - Stmt WriteToOut(const PrimExpr &val) { + Stmt WriteToOut(PrimExpr val) { auto info = ConvertForFFI(val); - Stmt store_val = BufferStore(info.dummy_val_buffer, info.expr, {0}); - Stmt store_tcode = - BufferStore(info.dummy_tcode_buffer, info.type_index, {0}); + Stmt store_tindex = tir::Evaluate( + tir::Call(DataType::Int(32), tir::builtin::tvm_struct_set(), + {ret_var_, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyTypeIndex), + IntImm(DataType::Int(32), info.type_index)})); + Stmt store_zero_padding = tir::Evaluate(tir::Call( + DataType::Int(32), tir::builtin::tvm_struct_set(), + {ret_var_, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyZeroPadding), + IntImm(DataType::Int(32), 0)})); + Stmt store_val = tir::Evaluate(tir::Call( + DataType::Int(32), tir::builtin::tvm_struct_set(), + {ret_var_, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyUnionValue), + info.expr})); Stmt ret_zero = Evaluate(tvm::ret(0)); - return SeqStmt({store_val, store_tcode, ret_zero}); + return SeqStmt({store_tindex, store_zero_padding, store_val, ret_zero}); } Var ret_var_; - Var ret_tcode_; int in_parallel_{0}; - - std::unordered_map dummy_val_buffer_map_; - Buffer dummy_tcode_buffer_; }; -Stmt RewriteReturn(Stmt 
body, Var ret_var, Var ret_tcode) { - ReturnRewriter rewriter(std::move(ret_var), std::move(ret_tcode)); - return rewriter(std::move(body)); -} - class SubroutineCallRewriter : public StmtExprMutator { public: - static Optional Apply(const Map &packed_func_methods, - Stmt stmt) { + static ffi::Optional + Apply(const ffi::Map &packed_func_methods, + Stmt stmt) { SubroutineCallRewriter rewriter(packed_func_methods); stmt = rewriter.VisitStmt(stmt); if (rewriter.made_change_) { @@ -162,16 +147,16 @@ public: private: explicit SubroutineCallRewriter( - const Map &packed_func_methods) + const ffi::Map &packed_func_methods) : packed_func_methods(packed_func_methods) {} PrimExpr VisitExpr_(const CallNode *op) override { auto node = Downcast(StmtExprMutator::VisitExpr_(op)); if (auto *gvar_ptr = node->op.as()) { - auto gvar = tvm::ffi::GetRef(gvar_ptr); + auto gvar = ffi::GetRef(gvar_ptr); if (auto symbol = packed_func_methods.Get(gvar)) { - Array cpacked_args; + ffi::Array cpacked_args; cpacked_args.push_back(tir::StringImm(symbol.value())); for (auto arg : node->args) { cpacked_args.push_back(arg); @@ -187,19 +172,18 @@ private: return node; } - const Map &packed_func_methods; + const ffi::Map &packed_func_methods; bool made_change_{false}; }; } // namespace -inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, const std::string &msg) { - return AssertStmt(std::move(lhs) == std::move(rhs), tvm::tir::StringImm(msg), - Evaluate(0)); +inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { + return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } -inline Stmt MakeAssertNotNull(PrimExpr ptr, const std::string &msg) { - Call isnull(DataType::Bool(), builtin::isnullptr(), {std::move(ptr)}); +inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) { + Call isnull(DataType::Bool(), builtin::isnullptr(), {ptr}); return AssertStmt(!isnull, tvm::tir::StringImm(msg), Evaluate(0)); } @@ -254,21 +238,16 @@ PrimFunc MakePackedAPI(PrimFunc func) { } 
auto *func_ptr = func.CopyOnWrite(); + // set the global symbol to the packed function name const Stmt nop = Evaluate(0); int num_args = static_cast(func_ptr->params.size()); // Data field definitions // The packed fields + Var v_self_handle("self_handle", DataType::Handle()); Var v_packed_args("args", DataType::Handle()); - Buffer buf_packed_arg_type_ids = - decl_buffer({IntImm(DataType::Int(32), func_ptr->params.size())}, - DataType::Int(32), "arg_type_ids"); Var v_num_packed_args("num_args", DataType::Int(32)); - Var v_out_ret_value("out_ret_value", PointerType(PrimType(DataType::Void()))); - Var v_out_ret_tcode("out_ret_tcode", - PointerType(PrimType(DataType::Int(32)))); - Var v_resource_handle("resource_handle", DataType::Handle()); - // The arguments of the function. + Var v_result("result", PointerType(PrimType(DataType::Void()))); // The device context Var device_id("dev_id"); @@ -278,37 +257,24 @@ PrimFunc MakePackedAPI(PrimFunc func) { std::vector seq_init, seq_check, arg_buffer_declarations; std::unordered_map vmap; ArgBinder binder(&vmap); - std::vector shape_checks; - tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); - bool disable_dynamic_tail_split = - ctxt->GetConfig(kDisableDynamicTailSplit, Bool(true)).value(); // --------------------------- // local function definitions // load i-th argument as type t - auto f_arg_value = [&](DataType t, int i) { - Array call_args{ + auto f_load_arg_value = [&](DataType arg_type, int i) { + ffi::Array call_args{ v_packed_args, IntImm(DataType::Int(32), i), - IntImm(DataType::Int(32), builtin::kTVMValueContent)}; + IntImm(DataType::Int(32), builtin::kTVMFFIAnyUnionValue)}; // load 64 bit version - DataType api_type = APIType(t); + DataType api_type = APIType(arg_type); PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args); // cast to the target version. 
- if (api_type != t) { - res = Cast(t, res); + if (api_type != arg_type) { + res = Cast(arg_type, res); } return res; }; - // Find the device API context argument based on name - for (const auto ¶m : func_ptr->params) { - if (param->name_hint == kDeviceContextVar) { - num_args--; - v_resource_handle = param; - break; - } - } - // Assert correct type codes for each argument. This must be done // *before* any initialization steps produced by // `binder.BindDLTensor()`. The validity of those initialization @@ -321,12 +287,10 @@ PrimFunc MakePackedAPI(PrimFunc func) { return error_message.str(); }())); - seq_init.push_back(MakeAssertNotNull( - v_packed_args, name_hint + ": TVMValue* arg pointer was NULL")); - seq_init.push_back(MakeAssertNotNull( - buf_packed_arg_type_ids->data, name_hint + ": int* type_codes was NULL")); - - seq_init.emplace_back(DeclBuffer(buf_packed_arg_type_ids, nop)); + if (num_args > 0) { + seq_init.push_back( + MakeAssertNotNull(v_packed_args, name_hint + ": args pointer is NULL")); + } // Need to delay binding of the buffers, in case some arguments also // appear in the buffer. @@ -335,26 +299,17 @@ PrimFunc MakePackedAPI(PrimFunc func) { for (int i = 0; i < static_cast(func_ptr->params.size()); ++i) { Var param = func_ptr->params[i]; - - // Ignore the device context argument, as it will still be passed - // as a native argument. 
- if (param->name_hint == kDeviceContextVar) { - continue; - } - - var_def.emplace_back(f_arg_value(param.dtype(), i), param); - if (func_ptr->buffer_map.count(param)) { - buffer_def.emplace_back(param, func_ptr->buffer_map[param]); - } - - // type code checks - Var type_index(param->name_hint + ".code", DataType::Int(32)); - seq_init.emplace_back(LetStmt( + PrimExpr arg_value; + // type index checks + Var type_index(param->name_hint + ".type_index", DataType::Int(32)); + seq_init.push_back(LetStmt( type_index, - BufferLoad(buf_packed_arg_type_ids, {IntImm(DataType::Int(32), i)}), + tir::Call(DataType::Int(32), builtin::tvm_struct_get(), + {v_packed_args, IntImm(DataType::Int(32), i), + IntImm(DataType::Int(32), builtin::kTVMFFIAnyTypeIndex)}), nop)); - DataType t = param.dtype(); - if (t.is_handle()) { + DataType dtype = param.dtype(); + if (dtype.is_handle()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be pointer"; seq_init.emplace_back( @@ -363,23 +318,63 @@ PrimFunc MakePackedAPI(PrimFunc func) { type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr || type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin, tvm::tir::StringImm(msg.str()), nop)); - } else if (t.is_int() || t.is_uint()) { + // if type_index is Tensor, we need to add the offset of the DLTensor + // header which always equals 16 bytes, this ensures that T.handle always + // shows up as a DLTensor* + const int64_t object_cell_offset = sizeof(TVMFFIObject); + static_assert(object_cell_offset == 24); + arg_value = f_load_arg_value(param.dtype(), i); + PrimExpr handle_from_tensor = + Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(), + {arg_value, IntImm(DataType::Int(32), object_cell_offset)}); + arg_value = Select(type_index == ffi::TypeIndex::kTVMFFITensor, + handle_from_tensor, arg_value); + } else if (dtype.is_bool()) { + std::ostringstream msg; + msg << name_hint << ": Expect arg[" << i << "] to be boolean"; + seq_init.emplace_back( + AssertStmt(type_index 
== ffi::TypeIndex::kTVMFFIBool || + type_index == ffi::TypeIndex::kTVMFFIInt, + tvm::tir::StringImm(msg.str()), nop)); + arg_value = + Cast(DataType::Bool(), f_load_arg_value(DataType::Int(64), i)); + + } else if (dtype.is_int() || dtype.is_uint()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be int"; - seq_init.emplace_back(AssertStmt(type_index == kDLInt, - tvm::tir::StringImm(msg.str()), nop)); + seq_init.emplace_back( + AssertStmt(type_index == ffi::TypeIndex::kTVMFFIInt || + type_index == ffi::TypeIndex::kTVMFFIBool, + tvm::tir::StringImm(msg.str()), nop)); + arg_value = f_load_arg_value(param.dtype(), i); } else { - ICHECK(t.is_float()); + ICHECK(dtype.is_float()); std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be float"; - seq_init.emplace_back(AssertStmt(type_index == kDLFloat, - tvm::tir::StringImm(msg.str()), nop)); + seq_init.emplace_back( + AssertStmt(type_index == ffi::TypeIndex::kTVMFFIFloat || + type_index == ffi::TypeIndex::kTVMFFIInt || + type_index == ffi::TypeIndex::kTVMFFIBool, + tvm::tir::StringImm(msg.str()), nop)); + // use select so we can also handle int conversion to bool + arg_value = tir::Select( + type_index == ffi::TypeIndex::kTVMFFIFloat, + /* true_value = */ f_load_arg_value(param.dtype(), i), + /* false_value = */ + Cast(param.dtype(), f_load_arg_value(DataType::Int(64), i))); + } + var_def.emplace_back(arg_value, param); + if (func_ptr->buffer_map.count(param)) { + // buffer binding now depends on type index + // if the index is Tensor handle, we need to offset to get the DLTensor* + buffer_def.emplace_back(param, func_ptr->buffer_map[param]); } } - Array args{v_packed_args, buf_packed_arg_type_ids->data, - v_num_packed_args, v_out_ret_value, - v_out_ret_tcode, v_resource_handle}; + // signature: (void* handle, TVMFFIAny* packed_args, int num_args, TVMFFIAny* + // v_result) + ffi::Array args{v_self_handle, v_packed_args, v_num_packed_args, + v_result}; // Arg definitions are 
defined before buffer binding to avoid the use before // def errors. @@ -392,83 +387,57 @@ PrimFunc MakePackedAPI(PrimFunc func) { binder.Bind(param, expr, name_hint + "." + param->name_hint, true); } - for (const auto &kv : buffer_def) { - binder.BindDLTensor(kv.second, device_type, device_id, kv.first, - name_hint + "." + kv.first->name_hint); - arg_buffer_declarations.push_back(DeclBuffer(kv.second, nop)); + for (const auto &[var, buffer] : buffer_def) { + binder.BindDLTensor(buffer, device_type, device_id, var, + name_hint + "." + var->name_hint); + arg_buffer_declarations.push_back(DeclBuffer(buffer, nop)); } - - func = - WithAttrs(std::move(func), - {{tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)}, - {tvm::attr::kTarget, target_host}}); - Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode); + // reset global symbol to attach prefix + func = WithAttrs( + std::move(func), + {{tvm::attr::kCallingConv, static_cast(CallingConv::kCPackedFunc)}, + {tvm::attr::kTarget, target_host}, + {tvm::attr::kGlobalSymbol, + ffi::symbol::tvm_ffi_symbol_prefix + global_symbol.value()}}); + + Stmt body = ReturnRewriter(v_result)(func_ptr->body); body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::compute_scope, StringImm(name_hint + "_compute_"), body); // Set device context if (vmap.count(device_id.get())) { - auto node = String("default"); + ffi::Any node = ffi::String("default"); seq_check.push_back(AttrStmt(node, tir::attr::device_id, device_id, nop)); seq_check.push_back( AttrStmt(node, tir::attr::device_type, device_type, nop)); if (runtime::DeviceAPI::NeedSetDevice(target_device_type)) { Stmt set_device = - Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), + Evaluate(Call(DataType::Int(32), tir::builtin::tvm_call_packed(), {StringImm(runtime::symbol::tvm_set_device), device_type, device_id})); body = SeqStmt({set_device, body}); } } - // (zhengju) For dynamic constraint, we need to check the buffer shape and - // dtype 
to make sure the buffer can be vectorized. - for (const auto &kv : buffer_def) { - if (disable_dynamic_tail_split) { - Optional opt_dynamic_alignment = - ctxt->GetConfig(kDynamicAlignment, Optional()); - int dynamic_alignment = opt_dynamic_alignment.value_or(Integer(8))->value; - // The vectorize dimension will be the last dimension of the buffer - auto vectorize_dim = kv.second->shape[kv.second->shape.size() - 1]; - auto shape_vectorize_expr = [&]() -> PrimExpr { - PrimExpr result = IntImm(kv.second->DefaultIndexType(), 1); - result = result * vectorize_dim; - result = FloorMod(result, IntImm(result->dtype, dynamic_alignment)); - return result; - }(); - shape_checks.emplace_back(AssertStmt( - shape_vectorize_expr == 0, - tvm::tir::StringImm( - kv.second->name + - ": Vectorize dimension in buffer must be divisible by " + - std::to_string(dynamic_alignment)), - nop)); - } - } - // Return error code of zero on success body = SeqStmt({body, Evaluate(ret(Integer(0)))}); - if (!disable_dynamic_tail_split) { - body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts(), - arg_buffer_declarations}, - body); - } else { - body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts(), - arg_buffer_declarations, shape_checks}, - body); - } - + body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts(), + arg_buffer_declarations}, + body); func_ptr->body = body; func_ptr->params = args; - Array undefined = UndefinedVars(func_ptr->body, func_ptr->params); + ffi::Array undefined = UndefinedVars(body, func_ptr->params); + ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << name_hint << " variables " << undefined << " are used, but are not passed in as API arguments"; - func_ptr->buffer_map = Map(); - func_ptr->ret_type = PrimType(DataType::Int(32)); // return the function. + func_ptr->buffer_map = ffi::Map(); + func_ptr->ret_type = PrimType(DataType::Int(32)); + + // return the function. 
return func; } diff --git a/src/transform/simplify.cc b/src/transform/simplify.cc index d64c7016..5a83f0df 100644 --- a/src/transform/simplify.cc +++ b/src/transform/simplify.cc @@ -240,37 +240,42 @@ public: simplifier.MarkBufferMapShapes(func); func.CopyOnWrite()->body = simplifier(func->body); - // Begin to remove useless var and buffer - // First get used buffers - simplifier.used_buffers_ = CollectUsedBuffers(func); - - bool param_updated = false; - Array new_params; - Map new_buffer_map; - // Check whether each buffer is used - for (const auto &var : func->params) { - if (func->buffer_map.find(var) != func->buffer_map.end()) { - if (simplifier.used_buffers_.find(func->buffer_map[var].get()) != - simplifier.used_buffers_.end()) { - new_params.push_back(var); - new_buffer_map.Set(var, func->buffer_map[var]); - } else if (simplifier.used_in_buffer_def_.find( - func->buffer_map[var]->data.get()) != - simplifier.used_in_buffer_def_.end()) { - new_params.push_back(var); - new_buffer_map.Set(var, func->buffer_map[var]); + // Optionally remove unused buffer parameters + if (simplify_arguments) { + // First get used buffers + simplifier.used_buffers_ = CollectUsedBuffers(func); + + bool param_updated = false; + Array new_params; + Map new_buffer_map; + // Check whether each buffer is used + for (const auto &var : func->params) { + if (func->buffer_map.find(var) != func->buffer_map.end()) { + if (simplifier.used_buffers_.find(func->buffer_map[var].get()) != + simplifier.used_buffers_.end()) { + new_params.push_back(var); + new_buffer_map.Set(var, func->buffer_map[var]); + } else if (simplifier.used_in_buffer_def_.find( + func->buffer_map[var]->data.get()) != + simplifier.used_in_buffer_def_.end()) { + new_params.push_back(var); + new_buffer_map.Set(var, func->buffer_map[var]); + } else { + param_updated = true; + } } else { - param_updated = true; + // Non-buffer parameters (e.g., scalars) are always retained + new_params.push_back(var); } } - } - if (param_updated) { - 
return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type, - new_buffer_map, func->attrs, func->span); - } else { - return func; + if (param_updated) { + return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type, + new_buffer_map, func->attrs, func->span); + } } + // Either no change to params or argument simplification disabled + return func; } private: diff --git a/testing/python/debug/test_tilelang_debug_print.py b/testing/python/debug/test_tilelang_debug_print.py index 1bc76161..fcfae4ed 100644 --- a/testing/python/debug/test_tilelang_debug_print.py +++ b/testing/python/debug/test_tilelang_debug_print.py @@ -13,7 +13,7 @@ def debug_print_buffer(M=16, N=16, dtype="float16"): shared_buf = T.alloc_shared([M, N], dtype) T.print(shared_buf) - jit_kernel = tilelang.compile(program, target="cuda") + jit_kernel = tilelang.compile(program, target="cuda", execution_backend="tvm_ffi") profiler = jit_kernel.get_profiler() profiler.run_once() diff --git a/testing/python/dynamic/test_tilelang_dynamic_symbolic.py b/testing/python/dynamic/test_tilelang_dynamic_symbolic.py index 07f4d784..4b9dff71 100644 --- a/testing/python/dynamic/test_tilelang_dynamic_symbolic.py +++ b/testing/python/dynamic/test_tilelang_dynamic_symbolic.py @@ -514,5 +514,4 @@ def test_assert_tl_matmul_block_all_dynamic_with_pass_config(): if __name__ == "__main__": - # tilelang.testing.main() - assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16") + tilelang.testing.main() diff --git a/testing/python/jit/test_tilelang_jit_gemm_ctypes.py b/testing/python/jit/test_tilelang_jit_gemm_ctypes.py deleted file mode 100644 index fd5243f0..00000000 --- a/testing/python/jit/test_tilelang_jit_gemm_ctypes.py +++ /dev/null @@ -1,411 +0,0 @@ -from tilelang import tvm as tvm -import tilelang.language as T -import tilelang.testing -import tilelang -import torch -from tilelang.utils.tensor import map_torch_type - - -def matmul( - M, - N, - K, - block_M, - block_N, - 
block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, -): - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - if trans_A: - T.copy(A[k * block_K, by * block_M], A_shared) - else: - T.copy(A[by * block_M, k * block_K], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -def run_gemm( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128, -): - program = matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - - stramp = "&*(XS)" - - @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True) - def tilelang_callback_cuda_postproc(code, _): - code = f"// {stramp}\n" + code - return code - - matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="ctypes") - - kernel_source = matmul_kernel.get_kernel_source() - - assert stramp in kernel_source, f"Expected {stramp} in the kernel source" - - -def test_gemm_f16f16f16_nn(): - run_gemm( - 512, - 1024, - 768, - False, 
- False, - "float16", - "float16", - "float16", - 128, - 256, - 32, - 2, - ) - - -def matmu_jit_kernel( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, -): - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - - import tilelang.language as T - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - if trans_A: - T.copy(A[k * block_K, by * block_M], A_shared) - else: - T.copy(A[by * block_M, k * block_K], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -def run_gemm_jit_kernel( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128, -): - program = matmu_jit_kernel( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - - matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="ctypes") - - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) - - A = torch.randn(M, K, dtype=in_dtype).cuda() - B = torch.randn(K, N, dtype=in_dtype).cuda() - - if trans_A: - A = A.T - if 
trans_B: - B = B.T - - def ref_program(A, B): - import torch - C = torch.matmul(A.to(torch.float), B.to(torch.float)) - C = C.to(out_dtype) - return C - - ref_C = ref_program(A, B) - C = matmul_kernel(A, B) - - tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) - - -def test_gemm_jit_kernel(): - run_gemm_jit_kernel( - 512, - 1024, - 768, - False, - False, - "float16", - "float16", - "float16", - 128, - 256, - 32, - 2, - ) - - -def run_ctypes_kernel_do_bench(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): - program = matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - - matmul_kernel = tilelang.compile(program, execution_backend="ctypes") - - profiler = matmul_kernel.get_profiler() - - ctypes_latency = profiler.do_bench(func=matmul_kernel) - print(f"Ctypes Latency: {ctypes_latency} ms") - - assert ctypes_latency is not None - - tvm_latency = profiler.do_bench() - print(f"TVM Latency: {tvm_latency} ms") - - assert tvm_latency is not None - - -def test_ctypes_kernel_do_bench(): - run_ctypes_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, - 256, 32, 2) - - -def run_ctypes_kernel_multi_stream(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): - program = matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - - matmul_kernel = tilelang.compile(program, execution_backend="ctypes") - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) - tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() - tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() - - if trans_A: - tensor_a = tensor_a.T - if trans_B: 
- tensor_b = tensor_b.T - tensor_c = torch.randn(M, N, dtype=out_dtype).cuda() - - num_streams = 4 - for _ in range(num_streams): - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - matmul_kernel(tensor_a, tensor_b, tensor_c) - - -def test_ctypes_kernel_multi_stream(): - run_ctypes_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", - 128, 256, 32, 2) - - -def run_ctypes_dynamic_shape(M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128): - program = matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - - matmul_kernel = tilelang.compile(program, execution_backend="ctypes") - if isinstance(M, T.Var): - M = 1024 - if isinstance(N, T.Var): - N = 1024 - if isinstance(K, T.Var): - K = 768 - - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) - - tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() - tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() - - if trans_A: - tensor_a = tensor_a.T - if trans_B: - tensor_b = tensor_b.T - tensor_c = torch.randn(M, N, dtype=out_dtype).cuda() - - matmul_kernel(tensor_a, tensor_b, tensor_c) - - tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) - tilelang.testing.torch_assert_close( - tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) - - -def test_ctypes_dynamic_shape(): - run_ctypes_dynamic_shape( - T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) - - run_ctypes_dynamic_shape( - T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, - 256, 32, 2) - - run_ctypes_dynamic_shape( - T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", - "float16", 128, 256, 32, 2) - - -if __name__ == "__main__": - # 
tilelang.testing.main() - test_gemm_f16f16f16_nn() diff --git a/testing/python/jit/test_tilelang_jit_nullptr.py b/testing/python/jit/test_tilelang_jit_nullptr.py index 6241ea90..07d4e04c 100644 --- a/testing/python/jit/test_tilelang_jit_nullptr.py +++ b/testing/python/jit/test_tilelang_jit_nullptr.py @@ -83,28 +83,27 @@ def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_ def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - func = ptr_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype) + kernel = ptr_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype) a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype)) b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype)) c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype)) d = torch.randn(N, device="cuda", dtype=map_torch_type(accum_dtype)) - - func(a, b, c, None, M, N, K, False) + kernel(a, b, c, None, M, N, K, False) ref_no_bias = (a @ b.T).to(map_torch_type(accum_dtype)) ref_with_bias = ref_no_bias + d torch.testing.assert_close(c, ref_no_bias, atol=1e-2, rtol=1e-2) - func(a, b, c, d, M, N, K, True) + kernel(a, b, c, d, M, N, K, True) torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2) - func = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype) - func(a, b, c, None, False) + kernel = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype) + kernel(a, b, c, None, False) torch.testing.assert_close(c, ref_no_bias, atol=1e-2, rtol=1e-2) - func(a, b, c, d, True) + kernel(a, b, c, d, True) torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2) diff --git a/testing/python/jit/test_tilelang_jit_tvm_ffi.py b/testing/python/jit/test_tilelang_jit_tvm_ffi.py new file mode 100644 index 00000000..cd5d9c75 --- /dev/null +++ b/testing/python/jit/test_tilelang_jit_tvm_ffi.py @@ -0,0 +1,589 @@ +from tilelang import tvm as tvm +import 
tilelang.language as T +import tilelang.testing +import tilelang +import torch +from tilelang.utils.tensor import map_torch_type + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + stramp = "&*(XS)" + + @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True) + def tilelang_callback_cuda_postproc(code, _): + code = f"// {stramp}\n" + code + return code + + matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="tvm_ffi") + + kernel_source = 
matmul_kernel.get_kernel_source() + + assert stramp in kernel_source, f"Expected {stramp} in the kernel source" + + +def test_gemm_f16f16f16_nn(): + run_gemm( + 512, + 1024, + 768, + False, + False, + "float16", + "float16", + "float16", + 128, + 256, + 32, + 2, + ) + + +def matmu_jit_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_jit_kernel( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmu_jit_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="tvm_ffi") + + in_dtype = map_torch_type(in_dtype) + out_dtype 
= map_torch_type(out_dtype) + + A = torch.randn(M, K, dtype=in_dtype).cuda() + B = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + A = A.T + if trans_B: + B = B.T + + def ref_program(A, B): + import torch + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(out_dtype) + return C + + ref_C = ref_program(A, B) + C = matmul_kernel(A, B) + + tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_gemm_jit_kernel(): + run_gemm_jit_kernel( + 512, + 1024, + 768, + False, + False, + "float16", + "float16", + "float16", + 128, + 256, + 32, + 2, + ) + + +def run_tvm_ffi_kernel_do_bench(M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, execution_backend="tvm_ffi") + + profiler = matmul_kernel.get_profiler() + + tvm_ffi_latency = profiler.do_bench(func=matmul_kernel) + print(f"tvm_ffi Latency: {tvm_ffi_latency} ms") + + assert tvm_ffi_latency is not None + + tvm_latency = profiler.do_bench() + print(f"TVM Latency: {tvm_latency} ms") + + assert tvm_latency is not None + + +def test_tvm_ffi_kernel_do_bench(): + run_tvm_ffi_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, + 256, 32, 2) + + +def run_tvm_ffi_kernel_multi_stream(M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, execution_backend="tvm_ffi") + in_dtype = map_torch_type(in_dtype) + out_dtype = 
map_torch_type(out_dtype) + tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() + tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + tensor_a = tensor_a.T + if trans_B: + tensor_b = tensor_b.T + tensor_c = torch.randn(M, N, dtype=out_dtype).cuda() + + num_streams = 4 + for _ in range(num_streams): + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + matmul_kernel(tensor_a, tensor_b, tensor_c) + + +def test_tvm_ffi_kernel_multi_stream(): + run_tvm_ffi_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", + 128, 256, 32, 2) + + +def run_tvm_ffi_dynamic_shape(M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, execution_backend="tvm_ffi") + if isinstance(M, T.Var): + M = 1024 + if isinstance(N, T.Var): + N = 1024 + if isinstance(K, T.Var): + K = 768 + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + + tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() + tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + tensor_a = tensor_a.T + if trans_B: + tensor_b = tensor_b.T + tensor_c = torch.randn(M, N, dtype=out_dtype).cuda() + + matmul_kernel(tensor_a, tensor_b, tensor_c) + + tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) + tilelang.testing.torch_assert_close( + tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_tvm_ffi_dynamic_shape(): + run_tvm_ffi_dynamic_shape( + T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + run_tvm_ffi_dynamic_shape( + T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, + 256, 32, 2) + 
+ run_tvm_ffi_dynamic_shape( + T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", + "float16", 128, 256, 32, 2) + + +def check_hopper(): + if not torch.cuda.is_available(): + return False + props = torch.cuda.get_device_properties(0) + compute_capability = props.major, props.minor + return compute_capability == (9, 0) + + +def convolution_im2col(N, + C, + H, + W, + F, + K, + S, + D, + P, + block_M, + block_N, + block_K, + num_stages, + threads, + dtype="float16", + accum_dtype="float"): + KH, KW = K, K + OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 + OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 + + @T.prim_func + def main( + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), + ): + with T.Kernel( + T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), + threads=threads) as (bx, by): + data_shared = T.alloc_shared((block_M, block_K), dtype) + kernel_shared = T.alloc_shared((block_K, block_N), dtype) + out_local = T.alloc_fragment((block_M, block_N), accum_dtype) + out_shared = T.alloc_shared((block_M, block_N), dtype) + + kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) + out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) + + T.annotate_layout({ + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + data_shared: tilelang.layout.make_swizzled_layout(data_shared), + kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), + }) + + T.clear(out_local) + for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): + T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P) + T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) + T.gemm(data_shared, kernel_shared, out_local) + + T.copy(out_local, out_shared) + T.copy(out_shared, out_flat[by * block_M, bx * block_N]) + + return main + + +def run_tvm_ffi_im2col_tma_desc(N, + C, + H, + W, + F, + K, + S, + D, + P, + block_M, + block_N, + block_K, + 
num_stages=3, + num_threads=256): + """Test im2col TMA descriptor functionality in tvm_ffi backend.""" + program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, + num_threads) + + conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="tvm_ffi") + + a = torch.randn(N, H, W, C).cuda().half() + b = torch.randn(K, K, C, F).cuda().half() + + out_c = conv_kernel(a, b) + + # Reference implementation using torch.conv2d + def ref_program(A, B): + A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W + B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W + C = torch.conv2d(A, B, stride=S, padding=P, dilation=D) + C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C + return C + + ref_c = ref_program(a, b) + tilelang.testing.torch_assert_close( + out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_tvm_ffi_im2col_tma_desc(): + """Test im2col TMA descriptor with tvm_ffi backend.""" + if not check_hopper(): + import pytest + pytest.skip("Test requires Hopper GPU (compute capability 9.0)") + + # Small test case for im2col TMA descriptor + run_tvm_ffi_im2col_tma_desc( + N=4, + C=64, + H=32, + W=32, + F=64, + K=3, + S=1, + D=1, + P=1, + block_M=64, + block_N=128, + block_K=32, + num_stages=3, + num_threads=256) + + +def test_tvm_ffi_l2_persistent_map(): + """Test L2 persistent cache annotation with elementwise add.""" + from tilelang.language import annotate_l2_hit_ratio + + M = 1024 + N = 1024 + + @tilelang.jit(out_idx=[-1], execution_backend="tvm_ffi") + def elementwise_add_with_l2_cache( + M, + N, + block_size=256, + dtype="float32", + ): + + @T.prim_func + def kernel( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(M * N // block_size, threads=block_size) as bx: + # Annotate L2 persistent cache for buffer B + # B will be accessed multiple times and benefit from L2 caching + annotate_l2_hit_ratio({B: 0.8}) + + for i in 
T.serial(block_size): + idx = bx * block_size + i + if idx < M * N: + row = idx // N + col = idx % N + C[row, col] = A[row, col] + B[row, col] + + return kernel + + # Compile the kernel + kernel = elementwise_add_with_l2_cache(M, N) + + source = kernel.get_host_source() + assert "__tvm_cuda_stream_set_access_policy_window_packed" in source, "Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source" + assert "__tvm_cuda_stream_reset_access_policy_window_packed" in source, "Expected __tvm_cuda_stream_reset_access_policy_window_packed in the kernel source" + + # Create test tensors + a = torch.randn(M, N, dtype=torch.float32).cuda() + b = torch.randn(M, N, dtype=torch.float32).cuda() + + # Run kernel with out_idx=[-1], C is returned not passed in + c = kernel(a, b) + + # Verify correctness + ref_c = a + b + tilelang.testing.torch_assert_close(c, ref_c, atol=1e-5, rtol=1e-5) + + print("L2 persistent map test passed!") + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_alloc.py b/testing/python/language/test_tilelang_language_alloc.py index 202d6bfa..149a1c28 100644 --- a/testing/python/language/test_tilelang_language_alloc.py +++ b/testing/python/language/test_tilelang_language_alloc.py @@ -113,7 +113,6 @@ def run_alloc_var_with_initializer( kernel = tilelang.compile(program, out_idx=[1]) code = kernel.get_kernel_source() - print(code) assert f"= {init_value};" in code @@ -151,8 +150,7 @@ def run_alloc_multi_vars_with_initializer( program = alloc_multi_vars_with_initializer(N, block_N, dtype) kernel = tilelang.compile(program, out_idx=[1]) - code = kernel.get_kernel_source() - print(code) + code = kernel.get_kernel_source(kernel_only=True) assert code.count("= 1;") == 1 assert code.count("= 2;") == 1 diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py index b93c4448..3e401cc5 100644 --- a/tilelang/autotuner/param.py +++ b/tilelang/autotuner/param.py @@ -33,7 
+33,7 @@ class CompileArgs: """Compile arguments for the auto-tuner. Detailed description can be found in `tilelang.jit.compile`. Attributes: out_idx: List of output tensor indices. - execution_backend: Execution backend to use for kernel execution (default: "cython"). + execution_backend: Execution backend to use for kernel execution (default: "auto"). target: Compilation target, either as a string or a TVM Target object (default: "auto"). target_host: Target host for cross-compilation (default: None). verbose: Whether to enable verbose output (default: False). @@ -42,7 +42,7 @@ class CompileArgs: """ out_idx: list[int] | int | None = None - execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython" + execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto" target: Literal['auto', 'cuda', 'hip'] = 'auto' target_host: str | Target = None verbose: bool = False @@ -208,7 +208,7 @@ class AutotuneResult: target: str | Target = "auto", target_host: str | Target = None, out_idx: list[int] | int | None = None, - execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython", + execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi", pass_configs: dict = None, func: Callable = None, verbose: bool = False, diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py index 7138f4c1..47ac888c 100644 --- a/tilelang/autotuner/tuner.py +++ b/tilelang/autotuner/tuner.py @@ -139,8 +139,9 @@ class AutoTuner: def set_compile_args(self, out_idx: list[int] | int | None = None, - target: Literal['auto', 'cuda', 'hip'] = 'auto', - execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython", + target: Literal['auto', 'cuda', 'hip', 'metal'] = 'auto', + execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", + "torch"] = "auto", target_host: str | Target = None, verbose: bool = False, pass_configs: dict[str, Any] | None = None): @@ -157,10 +158,15 @@ class AutoTuner: 
Returns: AutoTuner: Self for method chaining. """ + # Normalize target to a concrete TVM Target and resolve execution backend + t = Target(determine_target(target)) + from tilelang.jit.execution_backend import resolve_execution_backend + resolved_backend = resolve_execution_backend(execution_backend, t) + self.compile_args = CompileArgs( out_idx=out_idx, - target=Target(determine_target(target)), - execution_backend=execution_backend, + target=t, + execution_backend=resolved_backend, target_host=target_host, verbose=verbose, pass_configs=pass_configs) @@ -591,7 +597,7 @@ class AutoTuner: func=best_kernel.prim_func, kernel=best_kernel) - if self.compile_args.execution_backend in ("dlpack", "torch"): + if self.compile_args.execution_backend in ("torch"): logger.warning("DLPack backend does not support cache saving to disk.") else: with self._lock: @@ -728,8 +734,9 @@ def autotune( # This is the new public interface Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto". target_host : Union[str, Target], optional Target host for cross-compilation. Defaults to None. - execution_backend : Literal["dlpack", "ctypes", "cython"], optional - Backend for kernel execution and argument passing. Defaults to "cython". + execution_backend : Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional + Backend for kernel execution and argument passing. Use "auto" to pick a sensible + default per target (cuda->tvm_ffi, metal->torch, others->cython). verbose : bool, optional Enables verbose logging during compilation. Defaults to False. 
pass_configs : Optional[Dict[str, Any]], optional diff --git a/tilelang/cache/__init__.py b/tilelang/cache/__init__.py index c338ce61..144c2729 100644 --- a/tilelang/cache/__init__.py +++ b/tilelang/cache/__init__.py @@ -18,7 +18,8 @@ def cached( *args, target: str | Target = "auto", target_host: str | Target = None, - execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] | None = "cython", + execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] + | None = "auto", verbose: bool | None = False, pass_configs: dict | None = None, compile_flags: list[str] | str | None = None, diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py index d0a801fb..74ecb278 100644 --- a/tilelang/cache/kernel_cache.py +++ b/tilelang/cache/kernel_cache.py @@ -13,14 +13,15 @@ from typing import Callable, Literal import cloudpickle from tvm.target import Target from tvm.tir import PrimFunc - +from tvm.runtime import Executable from tilelang.engine.param import KernelParam from tilelang import env from tilelang.jit import JITKernel from tilelang import __version__ -KERNEL_PATH = "kernel.cu" -WRAPPED_KERNEL_PATH = "wrapped_kernel.cu" +DEVICE_KERNEL_PATH = "device_kernel.cu" +HOST_KERNEL_PATH = "host_kernel.cu" +EXECUTABLE_PATH = "executable.so" KERNEL_LIB_PATH = "kernel_lib.so" KERNEL_CUBIN_PATH = "kernel.cubin" KERNEL_PY_PATH = "kernel.py" @@ -40,7 +41,7 @@ class KernelCache: _instance = None # For implementing singleton pattern _lock = threading.Lock() # For thread safety _memory_cache = {} # In-memory cache dictionary - execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython" + execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi" def __new__(cls): """ @@ -69,7 +70,7 @@ class KernelCache: self, func: Callable, out_idx: list[int], - execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", + execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", 
"torch"] = "tvm_ffi", args=None, target: str | Target = "auto", target_host: str | Target = None, @@ -117,7 +118,8 @@ class KernelCache: *args, target: str | Target = "auto", target_host: str | Target = None, - execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", + execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", + "torch"] = "auto", verbose: bool = False, pass_configs: dict = None, compile_flags: list[str] | str | None = None, @@ -135,12 +137,30 @@ class KernelCache: Returns: JITKernel: The compiled kernel, either freshly compiled or from cache """ + # Normalize target and resolve execution backend before proceeding + from tilelang.utils.target import determine_target as _determine_target + from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target + norm_target = Target(_determine_target(target)) if isinstance(target, str) else target + requested_backend = execution_backend + execution_backend = resolve_execution_backend(requested_backend, norm_target) + if verbose: + allowed_now = allowed_backends_for_target(norm_target, include_unavailable=False) + # Avoid duplicate logs when caller already resolved explicitly + if requested_backend in (None, "auto") or requested_backend != execution_backend: + self.logger.info( + "Execution backend resolved -> '%s' (requested='%s', target='%s', allowed: %s)", + execution_backend, + requested_backend, + norm_target.kind.name, + ", ".join(sorted(allowed_now)), + ) + if not env.is_cache_enabled(): return JITKernel( func, out_idx=out_idx, execution_backend=execution_backend, - target=target, + target=norm_target, target_host=target_host, verbose=verbose, pass_configs=pass_configs, @@ -152,7 +172,7 @@ class KernelCache: out_idx=out_idx, execution_backend=execution_backend, args=args, - target=target, + target=norm_target, target_host=target_host, pass_configs=pass_configs, compile_flags=compile_flags, @@ -168,7 +188,7 @@ class KernelCache: 
self.logger.debug(f"Checking disk cache for kernel {func.attrs['global_symbol']}") # Then check disk cache - kernel = self._load_kernel_from_disk(key, target, target_host, out_idx, + kernel = self._load_kernel_from_disk(key, norm_target, target_host, out_idx, execution_backend, pass_configs, compile_flags, func, verbose) if kernel is not None: @@ -186,18 +206,15 @@ class KernelCache: func, out_idx=out_idx, execution_backend=execution_backend, - target=target, + target=norm_target, target_host=target_host, verbose=verbose, pass_configs=pass_configs, compile_flags=compile_flags, ) - if execution_backend in ("dlpack", "torch"): - self.logger.warning("DLPack or torch backend does not support cache saving to disk.") - else: - with self._lock: - if env.is_cache_enabled(): - self._save_kernel_to_disk(key, kernel, func, verbose) + with self._lock: + if env.is_cache_enabled(): + self._save_kernel_to_disk(key, kernel, func, verbose) # Store in memory cache after compilation self._memory_cache[key] = kernel @@ -239,6 +256,12 @@ class KernelCache: # Use atomic POSIX replace, so other processes cannot see a partial write os.replace(temp_path, path) + @staticmethod + def _safe_write_executable(executable: Executable, path: str): + temp_path = os.path.join(env.TILELANG_TMP_DIR, f"{os.getpid()}_{uuid.uuid4()}.so") + executable.export_library(temp_path) + os.replace(temp_path, path) + def _save_kernel_to_disk(self, key: str, kernel: JITKernel, @@ -265,41 +288,46 @@ class KernelCache: # Save kernel source code try: - kernel_path = os.path.join(cache_path, KERNEL_PATH) + device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH) if verbose: - self.logger.debug(f"Saving kernel source code to file: {kernel_path}") + self.logger.debug(f"Saving kernel source code to file: {device_kernel_path}") if kernel.kernel_source is not None: - KernelCache._safe_write_file(kernel_path, "w", + KernelCache._safe_write_file(device_kernel_path, "w", lambda file: file.write(kernel.kernel_source)) 
except Exception as e: self.logger.error(f"Error saving kernel source code to disk: {e}") # Save wrapped kernel source code try: - wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH) + host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH) if verbose: - self.logger.debug( - f"Saving wrapped kernel source code to file: {wrapped_kernel_path}") - KernelCache._safe_write_file( - wrapped_kernel_path, "w", - lambda file: file.write(kernel.adapter.get_kernel_source())) + self.logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}") + if self.execution_backend == "tvm_ffi": + KernelCache._safe_write_file( + host_kernel_path, "w", + lambda file: file.write(kernel.adapter.get_host_source())) + else: + KernelCache._safe_write_file( + host_kernel_path, "w", + lambda file: file.write(kernel.adapter.get_kernel_source())) except Exception as e: - self.logger.error(f"Error saving wrapped kernel source code to disk: {e}") + self.logger.error(f"Error saving host kernel source code to disk: {e}") # Save the kernel library try: # Save CUBIN or SO file - kernel_lib_path = KERNEL_CUBIN_PATH if self.execution_backend == "nvrtc" else KERNEL_LIB_PATH + if self.execution_backend == "nvrtc": + kernel_lib_path = KERNEL_CUBIN_PATH + elif self.execution_backend == "tvm_ffi": + kernel_lib_path = EXECUTABLE_PATH + else: + kernel_lib_path = KERNEL_LIB_PATH + kernel_lib_path = os.path.join(cache_path, kernel_lib_path) - src_lib_path = kernel.adapter.libpath - if verbose: - self.logger.debug(f"Saving kernel library to file: {kernel_lib_path}") - KernelCache._safe_write_file( - kernel_lib_path, "wb", - lambda file: file.write(KernelCache._load_binary(src_lib_path))) # Save an extra Python file for NVRTC if self.execution_backend == "nvrtc": + src_lib_path = kernel.adapter.libpath kernel_py_path = os.path.join(cache_path, KERNEL_PY_PATH) src_lib_path = src_lib_path.replace(".cubin", ".py") if verbose: @@ -307,6 +335,19 @@ class KernelCache: 
KernelCache._safe_write_file( kernel_py_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path))) + elif self.execution_backend == "tvm_ffi": + executable = kernel.adapter.executable + if verbose: + self.logger.debug(f"Saving kernel executable to file: {executable}") + KernelCache._safe_write_executable(executable, kernel_lib_path) + else: + src_lib_path = kernel.adapter.libpath + if verbose: + self.logger.debug(f"Saving kernel library to file: {kernel_lib_path}") + KernelCache._safe_write_file( + kernel_lib_path, "wb", + lambda file: file.write(KernelCache._load_binary(src_lib_path))) + except Exception as e: self.logger.error(f"Error saving kernel library to disk: {e}") @@ -326,7 +367,7 @@ class KernelCache: target: str | Target = "auto", target_host: str | Target = None, out_idx: list[int] = None, - execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", + execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi", pass_configs: dict = None, compile_flags: list[str] | str | None = None, func: Callable = None, @@ -349,25 +390,39 @@ class KernelCache: JITKernel: The loaded kernel if found, None otherwise. 
""" cache_path = self._get_cache_path(key) - wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH) - kernel_lib_path = os.path.join( - cache_path, KERNEL_CUBIN_PATH if self.execution_backend == "nvrtc" else KERNEL_LIB_PATH) + device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH) + host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH) + if self.execution_backend == "nvrtc": + kernel_lib_path = KERNEL_CUBIN_PATH + elif self.execution_backend == "tvm_ffi": + kernel_lib_path = EXECUTABLE_PATH + else: + kernel_lib_path = KERNEL_LIB_PATH + kernel_lib_path = os.path.join(cache_path, kernel_lib_path) params_path = os.path.join(cache_path, PARAMS_PATH) if not all([os.path.exists(file) for file in (kernel_lib_path, params_path)]): return None - kernel_global_source: str | None = None + device_kernel_source: str | None = None + host_kernel_source: str | None = None kernel_params: list[KernelParam] | None = None # Load the kernel source file (optional) + try: + if verbose: + self.logger.debug(f"Loading kernel source code from file: {device_kernel_path}") + with open(device_kernel_path) as f: + device_kernel_source = f.read() + except Exception as e: + self.logger.error(f"Error loading kernel source code from disk: {e}") try: if verbose: self.logger.debug( - f"Loading wrapped kernel source code from file: {wrapped_kernel_path}") - with open(wrapped_kernel_path) as f: - kernel_global_source = f.read() + f"Loading wrapped kernel source code from file: {host_kernel_path}") + with open(host_kernel_path) as f: + host_kernel_source = f.read() except Exception as e: - self.logger.error(f"Error loading wrapped kernel source code from disk: {e}") + self.logger.error(f"Error loading host kernel source code from disk: {e}") # Load kernel parameters try: @@ -378,10 +433,11 @@ class KernelCache: except Exception as e: self.logger.error(f"Error loading kernel parameters from disk: {e}") - if kernel_global_source and kernel_params: + if host_kernel_source and 
device_kernel_source and kernel_params: return JITKernel.from_database( func=func, - kernel_global_source=kernel_global_source, + host_kernel_source=host_kernel_source, + device_kernel_source=device_kernel_source, kernel_lib_path=kernel_lib_path, params=kernel_params, target=target, @@ -392,6 +448,7 @@ class KernelCache: compile_flags=compile_flags, ) else: + # TODO(lei): report what the reason is. return None def _clear_disk_cache(self): diff --git a/tilelang/contrib/dlpack.py b/tilelang/contrib/dlpack.py index e61d80ce..6772fe11 100644 --- a/tilelang/contrib/dlpack.py +++ b/tilelang/contrib/dlpack.py @@ -59,23 +59,3 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func): return tvm_func(*args) return _wrapper - - -def to_pytorch_func(tvm_func): - """Convert a tvm function into one that accepts PyTorch tensors - - Parameters - ---------- - tvm_func: Function - Built tvm function operating on arrays - - Returns - ------- - wrapped_func: Function - Wrapped tvm function that operates on PyTorch tensors - """ - # pylint: disable=import-outside-toplevel - import torch - import torch.utils.dlpack - - return convert_func(tvm_func, torch.Tensor, torch.utils.dlpack.to_dlpack) diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index d0c27b4c..c2a14552 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -146,7 +146,7 @@ def host_codegen(host_mod: tvm.IRModule, target_host: Target) -> tvm.IRModule: if target_host.kind.name == "llvm": host_mod = tvm.ffi.get_global_func("target.build.llvm")(host_mod, target_host) elif target_host.kind.name == "c": - host_mod = tvm.ffi.get_global_func("target.build.c")(host_mod, target_host) + host_mod = tvm.ffi.get_global_func("target.build.tilelang_c")(host_mod, target_host) else: raise ValueError(f"Target host {target_host.kind.name} is not supported") return host_mod diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index 24378ac8..9f0e25f4 100644 --- a/tilelang/jit/__init__.py +++ 
b/tilelang/jit/__init__.py @@ -23,7 +23,6 @@ except ImportError: # Python < 3.10 from typing_extensions import ParamSpec from tilelang import tvm as tvm from tilelang.language.v2 import PrimFunc -from tilelang.jit.adapter.utils import is_metal_target from tvm.target import Target from tilelang.jit.kernel import JITKernel @@ -46,7 +45,8 @@ _T = TypeVar('_T') def compile( func: PrimFunc[_KP, _T] = None, out_idx: list[int] | int | None = None, - execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", + "torch"] = "auto", target: str | Target = "auto", target_host: str | Target | None = None, verbose: bool = False, @@ -61,8 +61,9 @@ def compile( The TileLang TIR function to compile and wrap. out_idx : Union[List[int], int], optional Index(es) of the output tensors to return (default: None). - execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional - Execution backend to use for kernel execution (default: "cython"). + execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional + Execution backend to use for kernel execution. Use "auto" to pick a sensible + default per target (cuda->tvm_ffi, metal->torch, others->cython). target : Union[str, Target], optional Compilation target, either as a string or a TVM Target object (default: "auto"). target_host : Union[str, Target], optional @@ -80,8 +81,19 @@ def compile( # This path is not a performance critical path, so we can afford to convert the target. 
target = Target(determine_target(target)) - if is_metal_target(target): - assert execution_backend == 'torch', 'Currently metal target only support `tl.jit(execution_backend="torch")`' + # Resolve execution backend (handles aliases, auto, validation per target) + requested_backend = execution_backend + from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target + execution_backend = resolve_execution_backend(requested_backend, target) + if verbose: + allowed_now = allowed_backends_for_target(target, include_unavailable=False) + logger.info( + "Execution backend resolved -> '%s' (requested='%s', target='%s', allowed: %s)", + execution_backend, + requested_backend, + target.kind.name, + ", ".join(sorted(allowed_now)), + ) return cached( func=func, @@ -97,7 +109,8 @@ def compile( def par_compile(funcs: Iterable[PrimFunc[_KP, _T]], out_idx: list[int] | int | None = None, - execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", + "torch"] = "auto", target: str | Target = "auto", target_host: str | Target | None = None, verbose: bool = False, @@ -113,8 +126,9 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]], The TileLang TIR functions to compile and wrap. out_idx : Union[List[int], int], optional Index(es) of the output tensors to return (default: None). - execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional - Execution backend to use for kernel execution (default: "cython"). + execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional + Execution backend to use for kernel execution. Use "auto" to pick a sensible + default per target (cuda->tvm_ffi, metal->torch, others->cython). target : Union[str, Target], optional Compilation target, either as a string or a TVM Target object (default: "auto"). 
target_host : Union[str, Target], optional @@ -165,7 +179,7 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]], class JITImpl(Generic[_P, _KP, _T]): func: Callable[_P, _T] | PrimFunc[_KP, _T] out_idx: list[int] | int | None - execution_backend: Literal["dlpack", "ctypes", "cython"] + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] target: str | Target target_host: str | Target verbose: bool @@ -286,7 +300,8 @@ def jit( out_idx: Any = None, target: str | Target = "auto", target_host: str | Target = None, - execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", + "torch"] = "auto", verbose: bool = False, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, @@ -301,7 +316,8 @@ def jit( # This is the new public interface out_idx: Any = None, target: str | Target = "auto", target_host: str | Target = None, - execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", + "torch"] = "auto", verbose: bool = False, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, @@ -322,8 +338,9 @@ def jit( # This is the new public interface Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto". target_host : Union[str, Target], optional Target host for cross-compilation. Defaults to None. - execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional - Backend for kernel execution and argument passing. Defaults to "cython". + execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional + Backend for kernel execution and argument passing. Use "auto" to pick a sensible + default per target (cuda->tvm_ffi, metal->torch, others->cython). 
verbose : bool, optional Enables verbose logging during compilation. Defaults to False. pass_configs : Optional[Dict[str, Any]], optional diff --git a/tilelang/jit/adapter/__init__.py b/tilelang/jit/adapter/__init__.py index 0e8fb98c..dcfdaf5b 100644 --- a/tilelang/jit/adapter/__init__.py +++ b/tilelang/jit/adapter/__init__.py @@ -1,5 +1,5 @@ from .base import BaseKernelAdapter # noqa: F401 -from .dlpack import TorchDLPackKernelAdapter # noqa: F401 +from .tvm_ffi import TVMFFIKernelAdapter # noqa: F401 from .ctypes import CtypesKernelAdapter # noqa: F401 from .cython import CythonKernelAdapter # noqa: F401 from .nvrtc import NVRTCKernelAdapter # noqa: F401 diff --git a/tilelang/jit/adapter/base.py b/tilelang/jit/adapter/base.py index 9d998bc9..6bd69cff 100644 --- a/tilelang/jit/adapter/base.py +++ b/tilelang/jit/adapter/base.py @@ -4,6 +4,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import Any, Callable from tilelang.engine.param import KernelParam +import torch class BaseKernelAdapter(ABC): @@ -46,11 +47,54 @@ class BaseKernelAdapter(ABC): def _convert_torch_func(self) -> callable: pass + # --- Common helpers to align with PyTorch stream/device semantics --- + @staticmethod + def get_current_stream_functor() -> Callable[[], int]: + """Return a callable that reads Torch's current CUDA stream pointer. + + The returned lambda yields the raw CUDA stream handle of the current + PyTorch stream on the active device. It's a thunk (evaluated at call + time) so that any upstream stream guards are respected. If CUDA is + unavailable, it returns a lambda that yields 0. 
+ """ + if torch.cuda.is_available(): + try: + torch.cuda._lazy_init() + current_device = torch._C._cuda_getDevice + get_stream = torch._C._cuda_getCurrentRawStream + return lambda: get_stream(current_device()) + except Exception: + # Fallback to Python API if internal handles are unavailable + return lambda: int(torch.cuda.current_stream().cuda_stream) + # CPU or CUDA unavailable: no stream semantics + return lambda: 0 + + @staticmethod + def get_current_device_functor() -> Callable[[], torch.device]: + """Return a callable that yields Torch's current device. + + Similar to the stream functor, we capture a callable that, when called, + fetches the current device according to PyTorch. On CPU or when CUDA is + unavailable, returns ``torch.device('cpu')``. + """ + if torch.cuda.is_available(): + try: + torch.cuda._lazy_init() + current_device = torch._C._cuda_getDevice + return lambda: torch.device("cuda", current_device()) + except Exception: + return lambda: torch.device("cuda", torch.cuda.current_device()) + # CPU fallback + return lambda: torch.device("cpu") + def __call__(self, *args: Any, **kwds: Any) -> Any: return self.func(*args, **kwds) - def get_kernel_source(self) -> str: - return self.mod.imported_modules[0].get_source() + def get_kernel_source(self, kernel_only: bool = True) -> str: + if kernel_only: + return self.mod.imports[0].inspect_source() + else: + return self.mod.inspect_source() + "\n\n" + self.mod.imports[0].inspect_source() def _post_init(self): self.func = self._convert_torch_func() diff --git a/tilelang/jit/adapter/ctypes/adapter.py b/tilelang/jit/adapter/ctypes/adapter.py index bf0aef51..e2677305 100644 --- a/tilelang/jit/adapter/ctypes/adapter.py +++ b/tilelang/jit/adapter/ctypes/adapter.py @@ -14,6 +14,7 @@ from tilelang.utils.target import determine_target from tilelang.utils.language import retrieve_func_from_module +# TODO(lei): remove ctypes adapter. 
class CtypesKernelAdapter(BaseKernelAdapter): """Adapter class that converts TVM/TIR functions to callable CUDA kernels using ctypes. @@ -28,9 +29,9 @@ class CtypesKernelAdapter(BaseKernelAdapter): ir_module: tvm.IRModule | None = None # The global source code of the kernel -> global means the source code of the kernel # that is not wrapped by the wrapper code - kernel_global_source: str | None = None + host_kernel_source: str | None = None + device_kernel_source: str | None = None lib: ctypes.CDLL | None = None # Compiled library handle - wrapped_source: str | None = None # Generated C++ wrapper code # Maps symbolic variables to their corresponding buffer and shape indices dynamic_symbolic_map: dict[tir.Var, tuple[int, int]] | None = None # Pass configs for the compiler @@ -47,7 +48,8 @@ class CtypesKernelAdapter(BaseKernelAdapter): func_or_mod: tir.PrimFunc | tvm.IRModule, host_mod: tvm.IRModule | None = None, device_mod: tvm.IRModule | None = None, - kernel_global_source: str | None = None, + host_kernel_source: str | None = None, + device_kernel_source: str | None = None, verbose: bool = False, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None): @@ -62,7 +64,8 @@ class CtypesKernelAdapter(BaseKernelAdapter): """ self.params = params self.result_idx = self._legalize_result_idx(result_idx) - self.kernel_global_source = kernel_global_source + self.host_kernel_source = host_kernel_source + self.device_kernel_source = device_kernel_source if isinstance(func_or_mod, tir.PrimFunc): self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) @@ -111,7 +114,8 @@ class CtypesKernelAdapter(BaseKernelAdapter): result_idx: list[int], target: str, func_or_mod: tir.PrimFunc | tvm.IRModule, - kernel_global_source: str, + host_kernel_source: str, + device_kernel_source: str, kernel_lib_path: str, verbose: bool = False, pass_configs: dict[str, Any] | None = None, @@ -119,8 +123,9 @@ class 
CtypesKernelAdapter(BaseKernelAdapter): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) - adapter.kernel_global_source = kernel_global_source - adapter.wrapped_source = kernel_global_source + adapter.host_kernel_source = host_kernel_source + adapter.device_kernel_source = device_kernel_source + adapter.wrapped_source = device_kernel_source + "\n\n" + host_kernel_source adapter.pass_configs = pass_configs if isinstance(func_or_mod, tir.PrimFunc): @@ -288,7 +293,7 @@ class CtypesKernelAdapter(BaseKernelAdapter): def get_kernel_source(self, kernel_only: bool = False): """Returns the source code of the compiled kernel.""" if kernel_only: - return self.kernel_global_source + return self.device_kernel_source else: - assert self.wrapped_source is not None, "Wrapped source is not available" - return self.wrapped_source + # Wrapper only has host kernel source + return self.host_kernel_source diff --git a/tilelang/jit/adapter/cython/adapter.py b/tilelang/jit/adapter/cython/adapter.py index bc43533b..fe8fe5bd 100644 --- a/tilelang/jit/adapter/cython/adapter.py +++ b/tilelang/jit/adapter/cython/adapter.py @@ -48,9 +48,9 @@ class CythonKernelAdapter(BaseKernelAdapter): ir_module: tvm.IRModule | None = None # The global source code of the kernel -> global means the source code of the kernel # that is not wrapped by the wrapper code - kernel_global_source: str | None = None + host_kernel_source: str | None = None + device_kernel_source: str | None = None lib: ctypes.CDLL | None = None # Compiled library handle - wrapped_source: str | None = None # Generated C++ wrapper code # Maps symbolic variables to their corresponding buffer and shape indices dynamic_symbolic_map: dict[tir.Var, tuple[int, int]] | None = None # Maps pointer arguments to their corresponding (buffer_index, shape_dimension) @@ -77,7 +77,7 @@ class CythonKernelAdapter(BaseKernelAdapter): func_or_mod: tir.PrimFunc | tvm.IRModule, host_mod: 
tvm.IRModule | None = None, device_mod: tvm.IRModule | None = None, - kernel_global_source: str | None = None, + device_kernel_source: str | None = None, verbose: bool = False, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None): @@ -92,7 +92,7 @@ class CythonKernelAdapter(BaseKernelAdapter): """ self.params = params self.result_idx = self._legalize_result_idx(result_idx) - self.kernel_global_source = kernel_global_source + self.device_kernel_source = device_kernel_source if isinstance(func_or_mod, tir.PrimFunc): self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) @@ -121,9 +121,9 @@ class CythonKernelAdapter(BaseKernelAdapter): self.wrapper.assign_pass_configs(pass_configs) self.wrapper.assign_host_module(host_mod) self.wrapper.assign_device_module(device_mod) - self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True)) + self.host_kernel_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True)) - self.lib_generator.update_lib_code(self.wrapped_source) + self.lib_generator.update_lib_code(self.host_kernel_source) self.lib_generator.compile_lib() self.lib = self.lib_generator.load_lib() @@ -150,7 +150,8 @@ class CythonKernelAdapter(BaseKernelAdapter): result_idx: list[int], target: str, func_or_mod: tir.PrimFunc | tvm.IRModule, - kernel_global_source: str, + host_kernel_source: str, + device_kernel_source: str, kernel_lib_path: str, verbose: bool = False, pass_configs: dict[str, Any] | None = None, @@ -158,8 +159,8 @@ class CythonKernelAdapter(BaseKernelAdapter): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) - adapter.kernel_global_source = kernel_global_source - adapter.wrapped_source = kernel_global_source + adapter.host_kernel_source = host_kernel_source + adapter.device_kernel_source = device_kernel_source adapter.pass_configs = pass_configs if isinstance(func_or_mod, tir.PrimFunc): @@ -382,7 +383,8 @@ 
class CythonKernelAdapter(BaseKernelAdapter): def get_kernel_source(self, kernel_only: bool = False): """Returns the source code of the compiled kernel.""" if kernel_only: - return self.kernel_global_source + return self.device_kernel_source else: - assert self.wrapped_source is not None, "Wrapped source is not available" - return self.wrapped_source + # Wrapper only has host kernel source + assert self.host_kernel_source is not None, "Wrapped source is not available" + return self.host_kernel_source diff --git a/tilelang/jit/adapter/dlpack.py b/tilelang/jit/adapter/dlpack.py deleted file mode 100644 index 402dfb2f..00000000 --- a/tilelang/jit/adapter/dlpack.py +++ /dev/null @@ -1,40 +0,0 @@ -"""The profiler and convert to torch utils""" -import torch -from tilelang.contrib.dlpack import to_pytorch_func -from .base import BaseKernelAdapter - - -class TorchDLPackKernelAdapter(BaseKernelAdapter): - - def _convert_torch_func(self) -> callable: - torch_func = to_pytorch_func(self.mod) - - def func(*ins: list[torch.Tensor]): - if len(ins) + len(self.result_idx) != len(self.params): - raise ValueError( - f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs" - ) - ins_idx = 0 - args = [] - - # use the device of the first input tensor if available - device = ins[0].device if len(ins) > 0 else torch.cuda.current_device() - - for i in range(len(self.params)): - if i in self.result_idx: - dtype = self.params[i].dtype - shape = list(map(int, self.params[i].shape)) - tensor = torch.empty(*shape, dtype=dtype, device=device) - else: - tensor = ins[ins_idx] - ins_idx += 1 - args.append(tensor) - - torch_func(*args) - - if len(self.result_idx) == 1: - return args[self.result_idx[0]] - else: - return [args[i] for i in self.result_idx] - - return func diff --git a/tilelang/jit/adapter/nvrtc/adapter.py b/tilelang/jit/adapter/nvrtc/adapter.py index 5f8a2827..4a465d33 100644 --- 
a/tilelang/jit/adapter/nvrtc/adapter.py +++ b/tilelang/jit/adapter/nvrtc/adapter.py @@ -34,7 +34,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter): func_or_mod: tir.PrimFunc | tvm.IRModule, host_mod: tvm.IRModule | None = None, device_mod: tvm.IRModule | None = None, - kernel_global_source: str | None = None, + device_kernel_source: str | None = None, verbose: bool = False, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None): @@ -43,7 +43,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter): self.params = params self.result_idx = self._legalize_result_idx(result_idx) - self.kernel_global_source = kernel_global_source + self.device_kernel_source = device_kernel_source if isinstance(func_or_mod, tir.PrimFunc): self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) @@ -74,10 +74,10 @@ class NVRTCKernelAdapter(BaseKernelAdapter): self.wrapper.assign_pass_configs(pass_configs) self.wrapper.assign_host_module(host_mod) self.wrapper.assign_device_module(device_mod) - self.host_func, self.function_names = self.wrapper.wrap(kernel_global_source) + self.host_func, self.function_names = self.wrapper.wrap(device_kernel_source) self.lib_generator = NVRTCLibraryGenerator(self.target, self.verbose) - self.lib_generator.update_lib_code(self.kernel_global_source) + self.lib_generator.update_lib_code(self.device_kernel_source) self.lib_generator.update_host_func(self.host_func) self.lib_generator.assign_compile_flags(compile_flags) self.lib_generator.compile_lib() @@ -97,7 +97,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter): result_idx: list[int], target: str, func_or_mod: tir.PrimFunc | tvm.IRModule, - kernel_global_source: str, + host_kernel_source: str, + device_kernel_source: str, kernel_lib_path: str, verbose: bool = False, pass_configs: dict[str, Any] | None = None, @@ -105,7 +106,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = 
adapter._legalize_result_idx(result_idx) - adapter.kernel_global_source = kernel_global_source + adapter.host_kernel_source = host_kernel_source + adapter.device_kernel_source = device_kernel_source if isinstance(func_or_mod, tir.PrimFunc): adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) @@ -167,7 +169,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter): dynamic_symbolic_map[shape] = (i, j) return dynamic_symbolic_map - def get_kernel_source(self) -> str | None: + def get_kernel_source(self, kernel_only: bool = True) -> str | None: """Get the CUDA kernel source code. Returns @@ -175,7 +177,10 @@ class NVRTCKernelAdapter(BaseKernelAdapter): Optional[str] The kernel source code, or None if not available """ - return self.kernel_global_source + if kernel_only: + return self.device_kernel_source + else: + return self.host_func def _forward_from_prebuild_lib(self, *args, stream: int | None = None): """Low-level function to call the compiled CUDA kernel. diff --git a/tilelang/jit/adapter/tvm_ffi.py b/tilelang/jit/adapter/tvm_ffi.py new file mode 100644 index 00000000..e06e9862 --- /dev/null +++ b/tilelang/jit/adapter/tvm_ffi.py @@ -0,0 +1,321 @@ +"""Utilities to adapt TVM FFI kernels to Torch tensors. + +This adapter intentionally captures PyTorch's current CUDA stream and device +via light-weight callables so that, when the wrapped function is invoked, +the execution observes the same stream context as the active Torch code. +On non-CUDA builds, the stream/device fall back to 0/CPU semantics. 
+""" +from __future__ import annotations + +from typing import Callable, Any + +import torch +from tilelang import tvm +from tvm import runtime, tir +from tvm.target import Target +from tvm.relax import TensorType +from tilelang.utils.target import determine_target +from tilelang.jit.adapter.base import BaseKernelAdapter +from tilelang.utils.language import retrieve_func_from_module +from tilelang.engine.param import KernelParam + + +class TVMFFIKernelAdapter(BaseKernelAdapter): + """Adapter that runs a TVM runtime.Executable with Torch tensors. + + Notes + - We capture the "current" PyTorch CUDA stream/device as thunks (callables) + rather than materializing them at construction time. This ensures the + actual stream/device is read just-in-time when the function runs, matching + the user's current Torch context (e.g., after a stream guard/switch). + - The stream pointer returned is a raw CUDA stream handle compatible with + TVM's device API; on CPU or when CUDA is unavailable, we return 0. 
+ """ + # Class attributes to store compiled kernel information + target: str | Target = "cuda" + ir_module: tvm.IRModule | None = None + # The global source code of the kernel -> global means the source code of the kernel + # that is not wrapped by the wrapper code + host_kernel_source: str | None = None + device_kernel_source: str | None = None + executable: tvm.runtime.Executable | None = None + # Pass configs for the compiler + pass_configs: dict[str, Any] | None = None + # host_mod + host_mod: tvm.IRModule | None = None + # device_mod + device_mod: tvm.IRModule | None = None + # rt_mod + rt_mod: tvm.runtime.Module | None = None + # Maps symbolic variables to their corresponding buffer and shape indices + dynamic_symbolic_map: dict[tir.Var, tuple[int, int, int]] | None = None + + # Stream/device functors are inherited from BaseKernelAdapter + def __init__(self, + params: list[KernelParam], + result_idx: list[int], + target: str | Target, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: tvm.IRModule | None = None, + rt_mod: tvm.runtime.Module | None = None, + host_kernel_source: str | None = None, + device_kernel_source: str | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None): + """Initialize the adapter with the given TIR function or module. 
+ + Args: + params: List of tensor types for inputs/outputs + result_idx: Indices of output tensors + target: Target platform (e.g., 'cuda') + func_or_mod: TIR function or module to be compiled + verbose: Enable verbose logging + """ + self.params = params + self.result_idx = self._legalize_result_idx(result_idx) + self.host_kernel_source = host_kernel_source + self.device_kernel_source = device_kernel_source + + if isinstance(func_or_mod, tir.PrimFunc): + self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) + else: + self.ir_module = func_or_mod + + self.target = Target.canon_target(determine_target(target)) + + self.host_mod = host_mod + self.device_mod = device_mod + self.rt_mod = rt_mod + self.verbose = verbose + self.pass_configs = pass_configs + self.compile_flags = compile_flags + self.dynamic_symbolic_map = self._process_dynamic_symbolic() + + self._post_init() + + def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int]]: + """Extract information about dynamic shapes from the TIR function. + + Maps symbolic variables to their corresponding (id, buffer_index, dimension) + for runtime shape resolution. 
+ id represents shape or stride, 0 represents shape, 1 represents stride + """ + func = self.prim_func + params = func.params + buffer_map = func.buffer_map + dynamic_symbolic_map = {} + for i, param in enumerate(params): + if param in buffer_map: + buffer = buffer_map[param] + for j, shape in enumerate(buffer.shape): + if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and + (shape not in params)): + dynamic_symbolic_map[shape] = (0, i, j) + for i, param in enumerate(params): + if param in buffer_map: + buffer = buffer_map[param] + for j, stride in enumerate(buffer.strides): + if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and + (stride not in params)): + dynamic_symbolic_map[stride] = (1, i, j) + return dynamic_symbolic_map + + def _convert_torch_func(self) -> Callable[..., Any]: + # Capture thunks that reflect Torch's current stream and device. + # These are evaluated at call time to align TVM execution with the + # caller's active PyTorch stream/device. 
+ # current_stream_functor = self.get_current_stream_functor() + current_device_functor = self.get_current_device_functor() + + # Convert TVM types to native Python types during initialization + param_dtypes = [param.dtype for param in self.params] + # Convert TVM shape arrays to native Python lists + param_shapes = [] + + for param in self.params: + native_shape = [] + for dim in param.shape: + if isinstance(dim, tir.IntImm): + native_shape.append(int(dim)) + elif isinstance(dim, tir.Var): + native_shape.append(dim) # Keep tir.Var for dynamic dimensions + else: + native_shape.append(dim) + param_shapes.append(native_shape) + + if self.executable is None: + self.executable = runtime.Executable(self.rt_mod) + + dynamic_symbolic_map = self._process_dynamic_symbolic() + executable = self.executable + + # Prepare helpers for friendly dtype error messages + prim_func = self.prim_func + buffer_map = prim_func.buffer_map + params = prim_func.params + # Expected dtype string per parameter index (for buffers only) + expected_dtype_strs: list[str | None] = [] + # Track whether each param is a buffer (has dtype) vs scalar + is_buffer_param: list[bool] = [] + for p in params: + if p in buffer_map: + expected_dtype_strs.append(str(buffer_map[p].dtype)) + is_buffer_param.append(True) + else: + expected_dtype_strs.append(None) + is_buffer_param.append(False) + # Global function name used in error messages + global_symbol = str(prim_func.attrs.get("global_symbol", "main")) + + # Map torch dtype to TVM-style dtype string + def torch_dtype_to_tvm_str(dtype: torch.dtype) -> str: + try: + import torch as _torch + except Exception: # pragma: no cover + # Fallback, though torch should always be available here + return str(dtype) + fp8_e4m3fn = getattr(_torch, "float8_e4m3fn", None) + fp8_e4m3fnuz = getattr(_torch, "float8_e4m3fnuz", None) + fp8_e5m2 = getattr(_torch, "float8_e5m2", None) + fp8_e5m2fnuz = getattr(_torch, "float8_e5m2fnuz", None) + if fp8_e4m3fn is not None and dtype == 
fp8_e4m3fn: + return "float8_e4m3" + if fp8_e4m3fnuz is not None and dtype == fp8_e4m3fnuz: + return "float8_e4m3fnuz" + if fp8_e5m2 is not None and dtype == fp8_e5m2: + return "float8_e5m2" + if fp8_e5m2fnuz is not None and dtype == fp8_e5m2fnuz: + return "float8_e5m2" + # Strip torch. prefix for readability + s = str(dtype) + return s[6:] if s.startswith("torch.") else s + + def func(*inputs: torch.Tensor | Any): + # Validate input count strictly + expected_inputs = len(self.params) - len(self.result_idx) + if len(inputs) != expected_inputs: + raise ValueError( + f"Expected {expected_inputs} inputs, got {len(inputs)} (params={len(self.params)}, outputs={len(self.result_idx)})" + ) + + # Resolve the device used for outputs. Prefer the first tensor input's device + # if available, otherwise use PyTorch's current device. + out_device: torch.device | None = None + + # Stitch the full positional argument list expected by the TVM executable + ins_idx: int = 0 + tensor_list: list[torch.Tensor] = [] + + # Prepare input and output tensors + for i in range(len(self.params)): + if i in self.result_idx: + dtype = param_dtypes[i] + shape = [] + # Now working with native Python list, no FFI calls needed + for s in param_shapes[i]: + if isinstance(s, tir.Var): + for key in dynamic_symbolic_map: + if (str(s) == str(key)): + ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[ + key] + shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx]) + else: # Already converted to Python int during initialization + shape.append(s) + + if out_device is None: + out_device = current_device_functor() + + if len(shape) == 0: + param_name = self.params[i].name if hasattr(self.params[i], + 'name') else f'parameter_{i}' + raise ValueError( + f"Cannot create output tensor (name={param_name}) - 0-dimensional tensors are not supported. 
" + f"Expected shape: {shape}") + tensor = torch.empty(*shape, dtype=dtype, device=out_device) + else: + tensor = inputs[ins_idx] + # Input dtype validation with clear error message + if is_buffer_param[i]: + expected_dtype_str = expected_dtype_strs[i] + expected_torch_dtype = param_dtypes[i] + # Only check when the argument is a tensor and expected dtype is known + if isinstance( + tensor, torch.Tensor + ) and expected_dtype_str is not None and tensor.dtype != expected_torch_dtype: + param_var = params[i] + # Reconstruct TVM-like handle name A_handle for error clarity + handle_name = f"{param_var.name}_handle" + actual_dtype_str = torch_dtype_to_tvm_str(tensor.dtype) + raise RuntimeError( + f"{global_symbol}.{handle_name}.dtype is expected to be {expected_dtype_str}, but got {actual_dtype_str}" + ) + ins_idx += 1 + tensor_list.append(tensor) + + executable(*tensor_list) + + # Return outputs in the requested form + if len(self.result_idx) == 1: + return tensor_list[self.result_idx[0]] + return [tensor_list[i] for i in self.result_idx] + + return func + + @classmethod + def from_database(cls, + params: list[TensorType], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None): + adapter = cls.__new__(cls) + adapter.params = params + adapter.result_idx = adapter._legalize_result_idx(result_idx) + adapter.host_kernel_source = host_kernel_source + adapter.device_kernel_source = device_kernel_source + adapter.wrapped_source = device_kernel_source + "\n\n" + host_kernel_source + adapter.pass_configs = pass_configs + + if isinstance(func_or_mod, tir.PrimFunc): + adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) + else: + adapter.ir_module = func_or_mod + + target = determine_target(target, return_object=True) + 
adapter.target = Target.canon_target(determine_target(target)) + + adapter.verbose = verbose + adapter.executable = runtime.load_module(kernel_lib_path) + adapter._post_init() + return adapter + + def get_host_source(self): + """Returns the source code of the host module.""" + if self.host_kernel_source is not None: + return self.host_kernel_source + return self.rt_mod.inspect_source() + + def get_device_source(self): + """Returns the source code of the device module.""" + if self.device_kernel_source is not None: + return self.device_kernel_source + return self.rt_mod.imports[0].inspect_source() + + def get_kernel_source(self, kernel_only: bool = False): + """Returns the source code of the compiled kernel.""" + if kernel_only: + return self.get_device_source() + else: + return self.get_device_source() + "\n\n" + self.get_host_source() + + @property + def prim_func(self) -> tir.PrimFunc: + """Returns the primary TIR function from the IR module.""" + return retrieve_func_from_module(self.ir_module) diff --git a/tilelang/jit/execution_backend.py b/tilelang/jit/execution_backend.py new file mode 100644 index 00000000..fe600002 --- /dev/null +++ b/tilelang/jit/execution_backend.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from collections.abc import Iterable + +from tvm.target import Target + +# Canonical names for execution backends used internally +_CANONICAL_MAP = { + "dlpack": "tvm_ffi", # historical alias +} + + +def _canon_backend(name: str | None) -> str | None: + if name is None: + return None + key = str(name).lower() + return _CANONICAL_MAP.get(key, key) + + +def _target_kind(target: Target) -> str: + # tvm.target.Target always has kind.name + return target.kind.name + + +def allowed_backends_for_target(target: Target, *, include_unavailable: bool = True) -> list[str]: + """Return allowed execution backends for a given TVM target kind. 
+ + include_unavailable: if False, this will filter out backends that are known + to be unavailable at runtime (e.g., NVRTC without cuda-python installed). + """ + kind = _target_kind(target) + + if kind == "cuda": + allowed = ["tvm_ffi", "nvrtc", "cython", "ctypes"] + elif kind == "hip": + allowed = ["tvm_ffi", "cython", "ctypes"] + elif kind == "metal": + allowed = ["torch"] + elif kind == "c": # CPU C backend + allowed = ["cython", "ctypes", "tvm_ffi"] + else: + # Fallback: prefer portable hosts + allowed = ["cython", "ctypes", "tvm_ffi"] + + if not include_unavailable: + # Drop NVRTC if not importable + try: + from tilelang.jit.adapter.nvrtc import is_nvrtc_available # lazy + if not is_nvrtc_available and "nvrtc" in allowed: + allowed = [b for b in allowed if b != "nvrtc"] + except Exception: + # Be conservative and keep nvrtc if detection itself fails + pass + + return allowed + + +def _format_options(options: Iterable[str]) -> str: + return ", ".join(sorted(options)) + + +def resolve_execution_backend(requested: str | None, target: Target) -> str: + """Resolve an execution backend string to a concrete backend. + + - Supports the alias "dlpack" -> "tvm_ffi". + - Supports the sentinel "auto" which selects a sensible default per target. + - Validates the combination (target, backend) and raises with helpful + alternatives when invalid. 
+ """ + req = _canon_backend(requested) + allowed_all = allowed_backends_for_target(target, include_unavailable=True) + allowed_avail = allowed_backends_for_target(target, include_unavailable=False) + + # Default selection for auto/None + if req in (None, "auto"): + kind = _target_kind(target) + if kind == "cuda": + choice = "tvm_ffi" + elif kind == "metal": + choice = "torch" + else: + choice = "cython" + # If the chosen default is not available (very rare), fall back to first available + if choice not in allowed_avail and allowed_avail: + choice = allowed_avail[0] + return choice + + # Validate against allowed + if req not in allowed_all: + raise ValueError( + f"Invalid execution backend '{requested}' for target '{_target_kind(target)}'. " + f"Allowed: {_format_options(allowed_all)}. Tip: use execution_backend='auto'.") + + # Promote to availability-aware set for nicer errors (e.g., nvrtc not installed) + if req not in allowed_avail: + raise ValueError( + f"Execution backend '{requested}' requires extra dependencies and is not available now. 
" + f"Try one of: {_format_options(allowed_avail)}.") + + return req diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index 6f5eb0b5..22cecf99 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -15,7 +15,7 @@ from tilelang import tvm from tilelang import env from tilelang.engine.param import CompiledArtifact, KernelParam from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, - TorchDLPackKernelAdapter, MetalKernelAdapter) + TVMFFIKernelAdapter, MetalKernelAdapter) from tilelang.profiler import Profiler, TensorSupplyType from tilelang.utils.target import determine_target from tilelang.contrib import nvcc as tl_nvcc @@ -55,7 +55,7 @@ class JITKernel(Generic[_P, _T]): self, func: PrimFunc = None, out_idx: list[int] | int = None, - execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", + execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi", target: str | Target = "auto", target_host: str | Target = None, verbose: bool = False, @@ -72,8 +72,8 @@ class JITKernel(Generic[_P, _T]): The TileLang TIR function to compile and wrap. out_idx : Union[List[int], int], optional Index(es) of the output tensors to return (default: None). - execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional - Execution backend to use for kernel execution (default: "cython"). + execution_backend : Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional + Execution backend to use for kernel execution. target : Union[str, Target], optional Compilation target, either as a string or a TVM Target object (default: "auto"). target_host : Union[str, Target], optional @@ -102,7 +102,7 @@ class JITKernel(Generic[_P, _T]): # Validate the execution backend. 
assert execution_backend in [ - "dlpack", + "tvm_ffi", "ctypes", "cython", "nvrtc", @@ -143,13 +143,14 @@ class JITKernel(Generic[_P, _T]): def from_database( cls, func: PrimFunc, - kernel_global_source: str, + host_kernel_source: str, + device_kernel_source: str, kernel_lib_path: str, params: list[KernelParam], target: str | Target, target_host: str | Target, out_idx: list[int] | int, - execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"], + execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None, ): @@ -172,7 +173,8 @@ class JITKernel(Generic[_P, _T]): params=params, result_idx=out_idx, target=target, - kernel_global_source=kernel_global_source, + host_kernel_source=host_kernel_source, + device_kernel_source=device_kernel_source, kernel_lib_path=kernel_lib_path, pass_configs=pass_configs, compile_flags=compile_flags, @@ -223,8 +225,8 @@ class JITKernel(Generic[_P, _T]): compile_flags = self.compile_flags # Compile the function with TVM, optimizing with shared memory lowering. - enable_host_codegen = execution_backend == "dlpack" - enable_device_compile = execution_backend == "dlpack" + enable_host_codegen = execution_backend == "tvm_ffi" + enable_device_compile = execution_backend == "tvm_ffi" with tvm.transform.PassContext(opt_level=3, config=pass_configs), self.target: artifact = tilelang.lower( tilelang_func, @@ -236,13 +238,23 @@ class JITKernel(Generic[_P, _T]): self.artifact = artifact # Create an adapter based on the specified execution backend. - if execution_backend == "dlpack": - # Use TorchDLPackKernelAdapter for interoperability with PyTorch via DLPack. + if execution_backend == "tvm_ffi": + # Use TVMFFIKernelAdapter for interoperability with PyTorch via DLPack. # But we need to ensure that the runtime is enabled and the runtime module is not None. - assert tvm.runtime.enabled("llvm"), "DLPack backend requires LLVM runtime." 
- assert (artifact.rt_mod is not None), "DLPack backend requires a runtime module." - adapter = TorchDLPackKernelAdapter( - artifact.rt_mod, params=artifact.params, result_idx=out_idx) + assert (artifact.rt_mod is not None), "tvm_ffi backend requires a runtime module." + adapter = TVMFFIKernelAdapter( + params=artifact.params, + result_idx=out_idx, + target=target, + func_or_mod=tilelang_func, + host_mod=artifact.host_mod, + device_mod=artifact.device_mod, + rt_mod=artifact.rt_mod, + device_kernel_source=artifact.kernel_source, + verbose=verbose, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) elif execution_backend == "ctypes": adapter = CtypesKernelAdapter( params=artifact.params, @@ -251,7 +263,7 @@ class JITKernel(Generic[_P, _T]): func_or_mod=tilelang_func, host_mod=artifact.host_mod, device_mod=artifact.device_mod, - kernel_global_source=artifact.kernel_source, + device_kernel_source=artifact.kernel_source, verbose=verbose, pass_configs=pass_configs, compile_flags=compile_flags, @@ -264,7 +276,7 @@ class JITKernel(Generic[_P, _T]): func_or_mod=tilelang_func, host_mod=artifact.host_mod, device_mod=artifact.device_mod, - kernel_global_source=artifact.kernel_source, + device_kernel_source=artifact.kernel_source, verbose=verbose, pass_configs=pass_configs, compile_flags=compile_flags, @@ -278,7 +290,7 @@ class JITKernel(Generic[_P, _T]): func_or_mod=tilelang_func, host_mod=artifact.host_mod, device_mod=artifact.device_mod, - kernel_global_source=artifact.kernel_source, + device_kernel_source=artifact.kernel_source, verbose=verbose, pass_configs=pass_configs, compile_flags=compile_flags, @@ -308,7 +320,8 @@ class JITKernel(Generic[_P, _T]): result_idx: list[int] | int, target: str | Target, func_or_mod: PrimFunc | tvm.runtime.Module, - kernel_global_source: str, + host_kernel_source: str, + device_kernel_source: str, kernel_lib_path: str, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None) -> BaseKernelAdapter: @@ 
-316,15 +329,26 @@ class JITKernel(Generic[_P, _T]): execution_backend = self.execution_backend # Create an adapter based on the specified execution backend. - if execution_backend == "dlpack": - raise ValueError("DLPack backend is not supported for TileLang JIT.") + if execution_backend == "tvm_ffi": + adapter = TVMFFIKernelAdapter.from_database( + params=params, + result_idx=result_idx, + target=target, + func_or_mod=func_or_mod, + host_kernel_source=host_kernel_source, + device_kernel_source=device_kernel_source, + kernel_lib_path=kernel_lib_path, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) elif execution_backend == "ctypes": adapter = CtypesKernelAdapter.from_database( params=params, result_idx=result_idx, target=target, func_or_mod=func_or_mod, - kernel_global_source=kernel_global_source, + host_kernel_source=host_kernel_source, + device_kernel_source=device_kernel_source, kernel_lib_path=kernel_lib_path, pass_configs=pass_configs, compile_flags=compile_flags, @@ -335,7 +359,8 @@ class JITKernel(Generic[_P, _T]): result_idx=result_idx, target=target, func_or_mod=func_or_mod, - kernel_global_source=kernel_global_source, + host_kernel_source=host_kernel_source, + device_kernel_source=device_kernel_source, kernel_lib_path=kernel_lib_path, pass_configs=pass_configs, ) @@ -346,7 +371,8 @@ class JITKernel(Generic[_P, _T]): result_idx=result_idx, target=target, func_or_mod=func_or_mod, - kernel_global_source=kernel_global_source, + host_kernel_source=host_kernel_source, + device_kernel_source=device_kernel_source, kernel_lib_path=kernel_lib_path, pass_configs=pass_configs, compile_flags=compile_flags, @@ -394,7 +420,7 @@ class JITKernel(Generic[_P, _T]): return Profiler(self.params, self.out_idx, tensor_supply_type).with_default_adapter(self.adapter) - def get_kernel_source(self) -> str: + def get_kernel_source(self, kernel_only: bool = True) -> str: """ Returns the source code of the compiled kernel function. 
@@ -403,14 +429,17 @@ class JITKernel(Generic[_P, _T]): str The source code of the compiled kernel function. """ - if self.execution_backend in {"ctypes", "cython", "nvrtc"}: - return self.adapter.get_kernel_source() + if self.execution_backend in {"ctypes", "cython", "nvrtc", "tvm_ffi"}: + return self.adapter.get_kernel_source(kernel_only=kernel_only) return self.artifact.kernel_source def get_host_source(self) -> str: """ Returns the source code of the host function. """ + if self.execution_backend in {"ctypes", "cython", "nvrtc", "tvm_ffi"}: + return self.adapter.get_host_source() + assert self.artifact.host_mod is not None, "host_mod is not available" return str(self.artifact.host_mod) def run_once(self, func: Callable | None = None) -> None: diff --git a/tilelang/profiler/__init__.py b/tilelang/profiler/__init__.py index 3ff2baab..5af1fc2b 100644 --- a/tilelang/profiler/__init__.py +++ b/tilelang/profiler/__init__.py @@ -10,7 +10,6 @@ from tilelang.utils.tensor import ( get_tensor_supply, TensorSupplyType, torch_assert_close, - adapt_torch2tvm, ) from tilelang.engine.param import KernelParam from tilelang.jit.adapter import BaseKernelAdapter @@ -274,9 +273,8 @@ class Profiler: device = tvm.cuda(0) if target == "cuda" else tvm.rocm(0) time_evaluator = self.mod.time_evaluator( self.mod.entry_name, device, number=rep, repeat=n_repeat) - tvm_inputs = [adapt_torch2tvm(inp) for inp in ins] # Transform Latency to ms - return time_evaluator(*tvm_inputs).mean * 1e3 + return time_evaluator(*ins).mean * 1e3 else: raise ValueError(f"Unknown profiler: {profiler}") diff --git a/tilelang/utils/tensor.py b/tilelang/utils/tensor.py index 79947750..b275708c 100644 --- a/tilelang/utils/tensor.py +++ b/tilelang/utils/tensor.py @@ -1,9 +1,7 @@ """The profiler and convert to torch utils""" from enum import Enum import torch -from tvm import runtime from tvm import tir -from torch.utils.dlpack import to_dlpack import numpy as np @@ -37,23 +35,6 @@ def map_torch_type(intype: str) -> 
torch.dtype: return getattr(torch, intype) -def adapt_torch2tvm(arg): - float8_dtype_map = { - torch.float8_e4m3fn: "float8_e4m3", - torch.float8_e4m3fnuz: "float8_e4m3", - torch.float8_e5m2: "float8_e5m2", - torch.float8_e5m2fnuz: "float8_e5m2", - } - if isinstance(arg, torch.Tensor): - if arg.dtype in { - torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz - }: - return runtime.from_dlpack(to_dlpack(arg.view(torch.int8)))._create_view( - shape=arg.shape, dtype=float8_dtype_map[arg.dtype]) - return runtime.from_dlpack(to_dlpack(arg)) - return arg - - def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer): from tilelang.engine.param import KernelParam -- GitLab From 4c8b9adab435f3e6fa05a4da4aaaec4a8f66c2d9 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Wed, 19 Nov 2025 14:09:35 +0800 Subject: [PATCH 019/139] [Bugfix] Supply missing `T.print` for bool type (#1279) * fix for bool dtype * lint fix * fix * ci fix --- 3rdparty/tvm | 2 +- src/tl_templates/cuda/debug.h | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index f4105f89..f4affc7f 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit f4105f89a646622acc9818584d1d91e2ca3f533d +Subproject commit f4affc7f31e36e7f88c0fe1c715b03215c6a0c62 diff --git a/src/tl_templates/cuda/debug.h b/src/tl_templates/cuda/debug.h index 7dbb31ea..e8976874 100644 --- a/src/tl_templates/cuda/debug.h +++ b/src/tl_templates/cuda/debug.h @@ -29,6 +29,14 @@ __device__ void debug_print_var(const char *msg, signed char var) { threadIdx.z, var); } +// Specialization for plain char type +template <> __device__ void debug_print_var(const char *msg, char var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=char " + "value=%d\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, (int)var); +} + // 
Specialization for unsigned char type template <> __device__ void debug_print_var(const char *msg, @@ -58,6 +66,14 @@ __device__ void debug_print_var(const char *msg, threadIdx.z, var); } +// Specialization for bool type +template <> __device__ void debug_print_var(const char *msg, bool var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=bool " + "value=%s\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, var ? "true" : "false"); +} + // Specialization for float type template <> __device__ void debug_print_var(const char *msg, float var) { printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=float " -- GitLab From cd681e6384c72fb8fd0375e21b58791e549ce8fc Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Wed, 19 Nov 2025 14:17:45 +0800 Subject: [PATCH 020/139] [Fix] Fix memory leak bug (#1281) * add typing stub for tir.ir * remove idents * minor update * [Refactor] add numpy conversion for dtype * fix lint error * remove unused np.float_ in dtype conversion * fix type in np.int_ * fix typo * minor fix * remove debug files * fix memory leak bug * fix lint error * add comments * fix lint error * remove duplicated, because tilelang doesn't dependent deprecated --- .../python/language/test_tilelang_capture.py | 40 ++++++++++++++++ tilelang/language/v2/ast.py | 39 ++++++++++++--- tilelang/language/v2/builder.py | 48 +++++++++++-------- tilelang/language/v2/utils.py | 20 -------- 4 files changed, 101 insertions(+), 46 deletions(-) create mode 100644 testing/python/language/test_tilelang_capture.py diff --git a/testing/python/language/test_tilelang_capture.py b/testing/python/language/test_tilelang_capture.py new file mode 100644 index 00000000..875fa681 --- /dev/null +++ b/testing/python/language/test_tilelang_capture.py @@ -0,0 +1,40 @@ +import tilelang.language as T +import tilelang.testing +import torch +import weakref +import gc + + +def 
test_tilelang_capture(): + + @tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + },) + def get_dummy_kernel(): + + @T.prim_func + def dummy_kernel(a: T.Tensor[(1,), T.float32],): + with T.Kernel(1) as _: + a[0] = 1 + + return dummy_kernel + + a = torch.randn(1, 1024) + a_weak = weakref.ref(a) + _kernel = get_dummy_kernel() + del a + torch.cuda.empty_cache() + gc.collect() + torch.cuda.empty_cache() + a_upgrade = a_weak() + assert a_upgrade is None, "A is not garbage collected" + + # use objgraph to debug + # if a_upgrade is not None: + # objgraph.show_backrefs([a_upgrade], max_depth=5) + + +if __name__ == '__main__': + tilelang.testing.main() diff --git a/tilelang/language/v2/ast.py b/tilelang/language/v2/ast.py index cf879ee5..307efdac 100644 --- a/tilelang/language/v2/ast.py +++ b/tilelang/language/v2/ast.py @@ -248,8 +248,9 @@ class BaseBuilder: class DSLMutator(ast.NodeTransformer): - def __init__(self): + def __init__(self, closure_names: list[str]): self.tmp_counter = 0 + self.closure_names = closure_names def get_tmp(self) -> str: name = f"__{self.tmp_counter}" @@ -494,9 +495,11 @@ class DSLMutator(ast.NodeTransformer): node.body = stmts + node.body node.decorator_list.clear() return quote1( - f"def {node.name}(__tb):\n" - " range = __tb.override('range')\n" - " pass\n" + f"def make_closure({', '.join(self.closure_names)}):\n" + f" def {node.name}(__tb):\n" + " range = __tb.override('range')\n" + " pass\n" + f" return {node.name}\n" f" return {node.name}", passes=[node], ) @@ -595,7 +598,29 @@ def mutate(func: Callable[_P, _T]) -> IRGenerator[_P, _T]: tree = utils.get_ast(func) filename = inspect.getsourcefile(func) or inspect.getfile(func) - tree = DSLMutator().visit(tree) - fn = utils.get_compiled_object(tree, func.__name__, filename, - utils.inspect_function_capture(func)) + nonlocals = utils.get_func_nonlocals(func) + + # DSLMutator generates a function named 
`make_closure` + # it accepts all names inside nonlocal, and returns the mutated function + # this is because we must separate the closure namespace form the global namespace + # if we directly inject closure variables into the global namespace, + # it generates a new `globals` dict, and the dict owns all reference to the original globalns + # which makes memory leak, because the original globalns cannot be freed + # ```py + # a = 123 + # def foo(): + # x = foo.__globals__ # OK, globals are maintained by python + # x = {**foo.__globals__, } # Not OK: globals are copied, and the original globals cannot be freed + # def bar(): x + # return bar + # ``` + tree = DSLMutator(nonlocals.keys()).visit(tree) + + make_closure = utils.get_compiled_object( + tree, + 'make_closure', + filename, + func.__globals__, # use the original globalns + ) + fn = make_closure(**nonlocals) return IRGenerator(gen=fn, source=ast.unparse(tree)) diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 684880b7..6931c5af 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -18,6 +18,7 @@ try: except ImportError: # Python < 3.11 for Self, < 3.10 for ParamSpec from typing_extensions import ParamSpec, Self from . import dtypes as dt +from . import utils import threading import logging @@ -593,22 +594,27 @@ def get_type_hints(func): # Build eval namespaces from function globals plus captured closure variables # This lets annotations reference symbols like `n`, `h`, or dtype vars # defined in the outer scope of a nested function. 
- globalns = dict(getattr(func, '__globals__', {})) - localns = dict(globalns) - try: - freevars = getattr(func.__code__, 'co_freevars', ()) - cells = getattr(func, '__closure__', ()) or () - closure_bindings = { - name: cell.cell_contents for name, cell in zip(freevars, cells) if name not in localns - } - if closure_bindings: - localns.update(closure_bindings) - # Also update globals so ForwardRef eval sees them uniformly - globalns.update(closure_bindings) - except Exception: - # Be permissive: absence or access issues with closure shouldn't crash - pass - + globalns = func.__globals__ + # Here we add nonlocals into localns, to capture the parameters declared in the parent function + # ```py + # def foo(): + # n = 128 # n is nonlocal + # def bar( + # A: T.Tensor(n, T.float32) # we add nonlocal in its eval context + # ): + # for i in range(n): ... + # ``` + # + # This is incomplete and buggy + # the only bug scenario the function body doesn't use the the parameters + # but such define-no-use scenario is very rare in writing kernels + # + # ```py + # def foo(): + # n = 128 + # def bar(A: T.Tensor((n,), T.float32)): + # ... 
# empty function, do not use `n` + localns = utils.get_func_nonlocals(func) for name, value in annot.items(): if name == 'return': continue @@ -618,8 +624,10 @@ def get_type_hints(func): if value is None: value = type(None) if isinstance(value, str): - # Handle simple dtype aliases like T.float32 appearing as strings - # Evaluate directly only when it matches known dtypes + # if the annotation is string, is can be: (i) a T.float32 like annotations, (ii) a ForwardRef object + # typing doesn't handle (i), it will try to interpret T.float32 + # typing see: T.float32 is str('float32'), and there is no object named `flaot32` and give a NameError + # here we manually interpret it to return T.float32 object try: _, v = value.split('.', maxsplit=1) except ValueError: @@ -631,7 +639,9 @@ def get_type_hints(func): except Exception: pass value = ForwardRef(value, is_argument=True, is_class=False) - hints[name] = _eval_type(value, globalns=globalns, localns=localns) + hints[name] = _eval_type(value, globalns=globalns, localns=localns) + else: + hints[name] = value return hints diff --git a/tilelang/language/v2/utils.py b/tilelang/language/v2/utils.py index 739ecd1e..84f06145 100644 --- a/tilelang/language/v2/utils.py +++ b/tilelang/language/v2/utils.py @@ -53,26 +53,6 @@ def get_func_nonlocals(func): return nonlocal_vars -def inspect_function_capture(func: Callable) -> dict[str, Any]: - """Capture function non-locals and global variables. - - Parameters - ---------- - func : Callable - The function to inspect. - - Returns - ------- - res : Dict[str, Any] - The function variables map with non-local or global variables. 
- """ - captured = { - **func.__globals__, # type: ignore - **get_func_nonlocals(func), - } - return captured - - def get_ast(func: Callable): _, start = inspect.getsourcelines(func) filename = inspect.getsourcefile(func) or inspect.getfile(func) -- GitLab From 551ac60d19369df615aef578faad2048a521ed99 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Wed, 19 Nov 2025 16:27:44 +0800 Subject: [PATCH 021/139] [Enhancement] Enhance CUDA compilation by integrating pass context configuration (#1283) - Updated the `tilelang_callback_cuda_compile` function to accept a `pass_config` parameter, allowing for more flexible compilation options. - Introduced handling for fast math and PTXAS options based on the provided pass configuration. - Modified the CUDA build process in `rt_mod_cuda.cc` to utilize the current pass context, improving the integration of compilation settings. - Refactored NVCC command construction to use a dedicated function for better clarity and maintainability. 
--- src/target/rt_mod_cuda.cc | 6 +++++- tilelang/contrib/nvcc.py | 9 +-------- tilelang/engine/lower.py | 42 ++++++++++++++++++++++++++++----------- 3 files changed, 36 insertions(+), 21 deletions(-) diff --git a/src/target/rt_mod_cuda.cc b/src/target/rt_mod_cuda.cc index cbef0e64..a5e9b299 100644 --- a/src/target/rt_mod_cuda.cc +++ b/src/target/rt_mod_cuda.cc @@ -2,6 +2,7 @@ #include "runtime/cuda/cuda_module.h" #include "runtime/pack_args.h" #include +#include namespace tvm { namespace codegen { @@ -66,7 +67,10 @@ ffi::Module BuildTileLangCUDA(IRModule mod, Target target) { std::string ptx; if (const auto f = ffi::Function::GetGlobal("tilelang_callback_cuda_compile")) { - ptx = (*f)(code, target).cast(); + // Fetch current pass context config and pass into the compile callback + tvm::transform::PassContext pass_ctx = + tvm::transform::PassContext::Current(); + ptx = (*f)(code, target, pass_ctx->config).cast(); if (ptx[0] != '/') fmt = "cubin"; } else { diff --git a/tilelang/contrib/nvcc.py b/tilelang/contrib/nvcc.py index 202e0f3b..0d55cbf7 100644 --- a/tilelang/contrib/nvcc.py +++ b/tilelang/contrib/nvcc.py @@ -78,7 +78,7 @@ def compile_cuda(code, out_file.write(code) file_target = path_target if path_target else temp_target - cmd = ["nvcc"] + cmd = [get_nvcc_compiler()] cmd += [f"--{target_format}", "-O3"] if kernels_output_dir is not None: cmd += ["-lineinfo"] @@ -332,13 +332,6 @@ def get_cuda_version(cuda_path=None): raise RuntimeError("Cannot read cuda version file") -@tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True) -def tilelang_callback_cuda_compile(code, target): # pylint: disable=unused-argument - """use nvcc to generate fatbin code for better optimization""" - ptx = compile_cuda(code, target_format="fatbin") - return ptx - - @tvm_ffi.register_global_func("tilelang_callback_libdevice_path", override=True) def find_libdevice_path(arch): """Utility function to find libdevice diff --git a/tilelang/engine/lower.py 
b/tilelang/engine/lower.py index c2a14552..63391f77 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -11,6 +11,8 @@ import tvm_ffi from tvm.ir import CallingConv from tvm.target import Target from tilelang.contrib import hipcc, nvcc +from tilelang.transform import PassConfigKey +from tilelang.utils.deprecated import deprecated_warning from tilelang.engine.param import KernelParam, CompiledArtifact from tilelang.utils.target import determine_target from tilelang.engine.phase import ( @@ -54,7 +56,7 @@ def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]: @tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True) -def tilelang_callback_cuda_compile(code, target): +def tilelang_callback_cuda_compile(code, target, pass_config=None): project_root = osp.join(osp.dirname(__file__), "../..") if "TL_TEMPLATE_PATH" in os.environ: tl_template_path = os.environ["TL_TEMPLATE_PATH"] @@ -69,21 +71,37 @@ def tilelang_callback_cuda_compile(code, target): target_arch = nvcc.get_target_arch(nvcc.get_target_compute_version(target)) arch = [f"-arch=sm_{target_arch}"] - format = "cubin" + compile_format = "cubin" + + # Read pass-config keys (string-valued) like in jit.adapter.libgen.compile_lib + cfg = pass_config or {} + if cfg.get(PassConfigKey.TL_DISABLE_FAST_MATH.value, False): + deprecated_warning("TL_DISABLE_FAST_MATH", "TL_ENABLE_FAST_MATH", "0.1.7") + disable_fast_math = bool(cfg.get(PassConfigKey.TL_DISABLE_FAST_MATH.value, True)) + enable_fast_math = not disable_fast_math + else: + enable_fast_math = bool(cfg.get(PassConfigKey.TL_ENABLE_FAST_MATH.value, False)) + + ptxas_usage_level = cfg.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL.value, None) + verbose_ptxas_output = bool(cfg.get(PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT.value, False)) + + options = [ + "-std=c++17", + "-I" + tl_template_path, + "-I" + cutlass_path, + ] + if enable_fast_math: + options.append("--use_fast_math") + if ptxas_usage_level 
is not None: + options.append(f"--ptxas-options=--register-usage-level={ptxas_usage_level}") + if verbose_ptxas_output: + options.append("--ptxas-options=--verbose") - # printing out number of registers - debug_option = "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage" ptx = nvcc.compile_cuda( code, - format, + compile_format, arch, - options=[ - "-std=c++17", - debug_option, - "--use_fast_math", - "-I" + tl_template_path, - "-I" + cutlass_path, - ], + options=options, verbose=False, ) -- GitLab From 49f353935cb5006b92f6dfd96bf7f64c80c0bdd0 Mon Sep 17 00:00:00 2001 From: liu yuhao Date: Wed, 19 Nov 2025 17:21:39 +0800 Subject: [PATCH 022/139] Fix the bug in issue #1266 (#1284) Co-authored-by: cheeryBloosm --- examples/deepseek_nsa/example_tilelang_nsa_fwd.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py index f8a7ebfb..0b71779b 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py @@ -156,13 +156,14 @@ def main(): DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda') block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device='cuda') + block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device='cuda') for b in range(B): for t in range(SEQ_LEN): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] block_indices[b, t, h, :len(i_i)] = i_i + block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() block_indices = block_indices.sort(-1)[0] - block_counts = torch.randint(1, S + 1, (B, SEQ_LEN, H), device='cuda') out = kernel(Q, K, V, block_indices.to(torch.int32)) -- GitLab From 9e67b861c94be93d66badd06b19fbc5e415e56dd Mon Sep 17 00:00:00 2001 From: Chaofan Lin Date: Thu, 20 Nov 2025 01:30:20 +0800 Subject: [PATCH 023/139] [Language][UX] Nested loop checker in pre-lowering 
stage (#1288) * [Language][UX] Nested loop checker in pre-lowering stage * rename * comment * address comments --- src/transform/loop_partition.cc | 3 +- .../test_tilelang_language_nested_loop.py | 554 ++++++++++++++++++ tilelang/__init__.py | 1 + tilelang/analysis/__init__.py | 3 + tilelang/analysis/nested_loop_checker.py | 110 ++++ tilelang/engine/lower.py | 4 + tilelang/engine/phase.py | 11 + 7 files changed, 685 insertions(+), 1 deletion(-) create mode 100644 testing/python/language/test_tilelang_language_nested_loop.py create mode 100644 tilelang/analysis/__init__.py create mode 100644 tilelang/analysis/nested_loop_checker.py diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc index fe1fe036..b4236c6d 100644 --- a/src/transform/loop_partition.cc +++ b/src/transform/loop_partition.cc @@ -93,7 +93,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer, } for (int i = 0; i < old_loop_depth; i++) { const ForNode *loop = body.as(); - ICHECK(loop != nullptr); + ICHECK(loop != nullptr) + << "No extra statements are allowed between nested parallel loops."; vmap.Set(loop->loop_var, indices[i]); loop_mins.push_back(loop->min); loop_extents.push_back(loop->extent); diff --git a/testing/python/language/test_tilelang_language_nested_loop.py b/testing/python/language/test_tilelang_language_nested_loop.py new file mode 100644 index 00000000..b572a707 --- /dev/null +++ b/testing/python/language/test_tilelang_language_nested_loop.py @@ -0,0 +1,554 @@ +import tilelang +import tilelang.language as T +import torch +import tilelang.testing +import pytest + +tilelang.testing.set_random_seed() + + +def _require_cuda_tensor(shape, dtype=torch.float32): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + try: + return torch.randn(*shape, device="cuda", dtype=dtype) + except RuntimeError as err: + pytest.skip(f"CUDA runtime unavailable: {err}") + + +""" +Nested Parallel cases: + +T.Parallel + T.Parallel + +Rule: + - 
continuous parallels are allowed and will be merged into one T.Parallel.
T.Pipeline + +is OK. +""" + + +def matmul_nested_pipelines(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, + out_dtype, accum_dtype, threads, order, stage, extra_pipeline_repeats): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + for _ in T.Pipelined(extra_pipeline_repeats): + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_nested_pipelines( + order, + stage, + extra_pipeline_repeats, +): + M = 1024 + N = 1024 + K = 1024 + block_M = 128 + block_N = 128 + block_K = 32 + trans_A = False + trans_B = False + in_dtype = "float16" + out_dtype = "float16" + dtypeAccum = "float32" + num_threads = 128 + program = matmul_nested_pipelines( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_threads, + order, + stage, + extra_pipeline_repeats, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + 
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + profiler = kernel.get_profiler() + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + if in_dtype == "float32": + # Convert float32 to tfloat32 because tfloat32 mma cannot truncate + # float32 automatically, -0x1000 meas + A = ((A.view(torch.int32) - 0x1000)).view(torch.float32) + B = ((B.view(torch.int32) - 0x1000)).view(torch.float32) + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_nested_pipelines(): + run_gemm_nested_pipelines(order=[0, 1, 2], stage=[0, 0, 1], extra_pipeline_repeats=3) + + +""" +Nested serial cases: + +T.serial + T.serial + +is OK. +""" + + +@tilelang.jit(out_idx=[1]) +def nested_continuous_serials(length=256, block=16, dtype="float32"): + + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.serial(length // block): + for j in T.serial(block): + B[i * block + j] = A[i * block + j] + 1.0 + + return main + + +@tilelang.jit(out_idx=[1]) +def nested_noncontinuous_serials(length=256, block=16, dtype="float32"): + + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.serial(length // block): + B[i] = 0 + for j in T.serial(block): + B[i * block + j] = A[i * block + j] + 1.0 + + return main + + +def test_nested_serials(): + kernel1 = nested_continuous_serials(length=256, block=16) + data = _require_cuda_tensor((256,), torch.float32) + result1 = kernel1(data) + torch.testing.assert_close(result1, data + 1.0, atol=1e-5, rtol=1e-5) + + # This is valid + nested_noncontinuous_serials(length=256, block=16) + + +""" +Mixed serial and Parallel loops: + +(S-P) +T.serial + T.Parallel + +(P-S) +T.Parallel + T.serial + +Rule: 
+ - No Parallel - * - Parallel +""" + + +@tilelang.jit(out_idx=[1]) +def nested_continuous_sp(length=256, block=16, dtype="float32"): + + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.serial(length // block): + for j in T.Parallel(block): + B[i * block + j] = A[i * block + j] + 1.0 + + return main + + +@tilelang.jit(out_idx=[1]) +def nested_continuous_ps(length=256, block=16, dtype="float32"): + + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.Parallel(length // block): + for j in T.serial(block): + B[i * block + j] = A[i * block + j] + 1.0 + + return main + + +@tilelang.jit(out_idx=[1]) +def nested_continuous_psp(length=256, block1=8, block2=2, dtype="float32"): + + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.Parallel(length // block1 // block2): + for j in T.serial(block1): + for k in T.Parallel(block2): + B[i * block1 * block2 + j * block2 + + k] = A[i * block1 * block2 + j * block2 + k] + 1.0 + + return main + + +@tilelang.jit(out_idx=[1]) +def nested_continuous_sps(length=256, block1=8, block2=2, dtype="float32"): + + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.serial(length // block1 // block2): + for j in T.Parallel(block1): + for k in T.serial(block2): + B[i * block1 * block2 + j * block2 + + k] = A[i * block1 * block2 + j * block2 + k] + 1.0 + + return main + + +def test_mixed_sp(): + kernel1 = nested_continuous_sp(length=256, block=16) + kernel2 = nested_continuous_ps(length=256, block=16) + data = _require_cuda_tensor((256,), torch.float32) + result1 = kernel1(data) + result2 = kernel2(data) + torch.testing.assert_close(result1, data + 1.0, 
atol=1e-5, rtol=1e-5) + torch.testing.assert_close(result2, data + 1.0, atol=1e-5, rtol=1e-5) + + # This should be invalid (Undefined behaviour) + with pytest.raises(ValueError): + nested_continuous_psp(length=256, block1=16, block2=8) + + kernel3 = nested_continuous_sps(length=256, block1=8, block2=2) + result3 = kernel3(data) + torch.testing.assert_close(result3, data + 1.0, atol=1e-5, rtol=1e-5) + + +""" +Mixed Pipelined and Parallel loops: + +(Pi-Pa) +T.Pipelined + T.Parallel + +(Pa-Pi) +T.Parallel + T.Pipelined + +Rule: + - Pi-Pa is ok where Pa-Pi is not allowed. + - For more nested cases, refer to the rule of T.Parallel. +""" + + +def matmul_nested_pipa( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + threads, + order, + stage, +): + A_shape = (M, K) + B_shape = (K, N) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage): + for i, j in T.Parallel(block_M, block_K): + A_shared[i, j] = A[by * block_M + i, k * block_K + j] + for i, j in T.Parallel(block_K, block_N): + B_shared[i, j] = B[k * block_K + i, bx * block_N + j] + + # T.copy(A[by * block_M, k * block_K], A_shared) + # T.copy(B[k * block_K, bx * block_N], B_shared) + + T.gemm(A_shared, B_shared, C_local, False, False) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def matmul_nested_papipa( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + threads, + order, + stage, +): + A_shape = (M, K) + B_shape = 
(K, N) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for _ in T.Parallel(1): + for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage): + for i, j in T.Parallel(block_M, block_K): + A_shared[i, j] = A[by * block_M + i, k * block_K + j] + for i, j in T.Parallel(block_K, block_N): + B_shared[i, j] = B[k * block_K + i, bx * block_N + j] + + # T.copy(A[by * block_M, k * block_K], A_shared) + # T.copy(B[k * block_K, bx * block_N], B_shared) + + T.gemm(A_shared, B_shared, C_local, False, False) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_mixed_pp( + order, + stage, +): + M = 1024 + N = 1024 + K = 1024 + block_M = 128 + block_N = 128 + block_K = 32 + in_dtype = "float16" + out_dtype = "float16" + dtypeAccum = "float32" + num_threads = 128 + + program = matmul_nested_pipa( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + dtypeAccum, + num_threads, + order, + stage, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + profiler = kernel.get_profiler() + + def ref_program(A, B): + import torch + + if in_dtype == "float32": + # Convert float32 to tfloat32 because tfloat32 mma cannot truncate + # float32 automatically, -0x1000 meas + A = ((A.view(torch.int32) - 0x1000)).view(torch.float32) + B = ((B.view(torch.int32) - 0x1000)).view(torch.float32) + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = 
C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + program1 = matmul_nested_papipa( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + dtypeAccum, + num_threads, + order, + stage, + ) + with pytest.raises(ValueError): + tilelang.compile( + program1, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + + +def test_mixed_pp(): + run_gemm_mixed_pp(order=[0, 1, 2], stage=[0, 0, 1]) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/__init__.py b/tilelang/__init__.py index e4be0129..2eae5cdb 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -133,6 +133,7 @@ from .layout import ( Fragment, # noqa: F401 ) from . import ( + analysis, # noqa: F401 transform, # noqa: F401 language, # noqa: F401 engine, # noqa: F401 diff --git a/tilelang/analysis/__init__.py b/tilelang/analysis/__init__.py new file mode 100644 index 00000000..b72fc2ba --- /dev/null +++ b/tilelang/analysis/__init__.py @@ -0,0 +1,3 @@ +"""Tilelang IR analysis & visitors.""" + +from .nested_loop_checker import NestedLoopChecker # noqa: F401 diff --git a/tilelang/analysis/nested_loop_checker.py b/tilelang/analysis/nested_loop_checker.py new file mode 100644 index 00000000..4b9741c3 --- /dev/null +++ b/tilelang/analysis/nested_loop_checker.py @@ -0,0 +1,110 @@ +from tvm import tir +from tvm.tir import ( + For, + PrimFunc, + PyStmtExprVisitor, +) +from tvm.tir.transform import prim_func_pass + + +def is_pipelined_for(op: For) -> bool: + """Check if a for loop is pipelined.""" + + anno_keys = [ + "num_stages", "tl_pipeline_order", "tl_pipeline_stage", "tl_pipeline_sync", + "tl_pipeline_group" + ] + return any(key in op.annotations for key in anno_keys) + + +@tir.functor.visitor +class _NestedLoopCheckVisitor(PyStmtExprVisitor): + + def __init__(self) -> None: + super().__init__() + 
self.in_parallel_context = False + + def visit_for_(self, op: For) -> None: + if op.kind == tir.ForKind.PARALLEL: + child = op.body + + # Special case: continuous nested parallel loop is allowed. + if isinstance(child, tir.For) and child.kind == tir.ForKind.PARALLEL: + self.visit_stmt(child) + return + + # Otherwise + if self.in_parallel_context: + raise ValueError("Nested parallel loops are not allowed. " + "Please check your loop structure.") + self.in_parallel_context = True + self.visit_stmt(child) + self.in_parallel_context = False + return + elif is_pipelined_for(op): + if self.in_parallel_context: + raise ValueError("Pipelined loop cannot be nested inside a parallel loop. " + "Please check your loop structure.") + + self.visit_stmt(op.body) + + +def NestedLoopChecker(): + """ + User-friendly pass which identifies any invalid any nested-loop pattern. + + Nested loops is an annoying problem in tilelang or other polyhedral-style compilers. + It contains many corner cases and undefined behaviours. + + In tilelang, there are four loops: + T.serial + T.Parallel (T.vectorized) + T.Pipelined + T.Persistent + + T.Persistent is a new feature which we do not consider here. + + We define the following rules: + - (Rule 1) T.serial can be nested inside any other loop type without restriction. + - (Rule 2) Consecutive T.Parallel nested loops are not allowed. Including any TileOp (T.copy, etc.) which has + "parallel" behaviours is also forbidden. + + Examples: + for i in T.Parallel(M): + stmt + for j in T.Parallel(N): + ... + + for i in T.Parallel(M): + T.copy(A, B) # forbidden! + + **Only a special case is allowed: strict continuous Parallel loops.** Since we can fuse them into a single T.Parallel loop. + Example: + + for i in T.Parallel(M): + for j in T.Parallel(N): + ... # allowed + - (Rule 3) T.Pipelined inside a T.Parallel is forbidden. + + Examples: + for i in T.Parallel(M): + for j in T.Pipelined(K): # forbidden! + ... 
+ + for i in T.Pipelined(K): + for j in T.Parallel(N): # allowed, ok + ... + + In summary, the problem mainly lies in the "T.Parallel". We highly recommend to use + T.Parallel to implement a tiled operator inside a kernel (e.g. T.gemm level) instead of other usages. + This guideline can help you avoid most of the issues. + + Returns: + A prim_func_pass that applies the transformation + """ + + def pass_fn(func: PrimFunc, mod, ctx): + _NestedLoopCheckVisitor().visit_stmt(func.body) + return func + + return prim_func_pass(pass_fn, opt_level=0) diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index 63391f77..88d89dcc 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -16,6 +16,7 @@ from tilelang.utils.deprecated import deprecated_warning from tilelang.engine.param import KernelParam, CompiledArtifact from tilelang.utils.target import determine_target from tilelang.engine.phase import ( + PreLowerSemanticCheck, LowerAndLegalize, OptimizeForTarget, ) @@ -242,6 +243,9 @@ def lower( _is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target)) _is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target)) + # Before lowering, do semantic check + PreLowerSemanticCheck(mod) + # Phase 1: Lower and legalize the IR mod = LowerAndLegalize(mod, target) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index a7cc99f8..35c16a43 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -67,6 +67,17 @@ def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool: return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_FORCE_LET_INLINE, False)) +def PreLowerSemanticCheck(mod: IRModule) -> None: + """ + Check whether the module is valid before lowering. If not, raise a user-friendly error + in Python side instead of letting the error dive into the complicated TVM/C++ stack. + Note: This is a validation-only pipeline of passes and does not modify or return the module. 
+ """ + + # Check if there are any invalid nested loops. + tilelang.analysis.NestedLoopChecker()(mod) + + def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: # Bind the target device information to the module """ -- GitLab From bef7e52e32bb3280a4ad82dcdc61da9f0fc39001 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Thu, 20 Nov 2025 13:05:40 +0800 Subject: [PATCH 024/139] [Compatibility] Support CUDA 11.3 (#1290) --- src/tl_templates/cuda/atomic.h | 41 ++++++++++++++++++++++++++++++-- src/tl_templates/cuda/debug.h | 9 +++++++ src/tl_templates/cuda/gemm_mma.h | 1 - 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/src/tl_templates/cuda/atomic.h b/src/tl_templates/cuda/atomic.h index a573886b..0bbc4171 100644 --- a/src/tl_templates/cuda/atomic.h +++ b/src/tl_templates/cuda/atomic.h @@ -12,7 +12,11 @@ using cutlass::bfloat16_t; using cutlass::half_t; #define TL_DEVICE __forceinline__ __device__ - +#define TL_NOT_IMPLEMENTED() \ + { \ + printf("%s not implemented\n", __PRETTY_FUNCTION__); \ + asm volatile("brkpt;\n"); \ + } template struct normalize_atomic_type { using type = T; }; @@ -63,8 +67,12 @@ TL_DEVICE void AtomicMax(T1 &ref, T2 val, } } } else { +#if CUDART_VERSION >= 11080 cuda::atomic_ref aref(*address); aref.fetch_max(cuda_cast(val), cuda::memory_order(memory_order)); +#else + TL_NOT_IMPLEMENTED(); +#endif } } @@ -89,9 +97,13 @@ TL_DEVICE T1 AtomicMaxRet(T1 &ref, T2 val, } return static_cast(*reinterpret_cast(&old_val_ushort)); } else { +#if CUDART_VERSION >= 11080 cuda::atomic_ref aref(*address); return static_cast( aref.fetch_max(cuda_cast(val), cuda::memory_order(memory_order))); +#else + TL_NOT_IMPLEMENTED(); +#endif } } @@ -117,8 +129,13 @@ TL_DEVICE void AtomicMin(T1 &ref, T2 val, } } } else { +#if CUDART_VERSION >= 11080 cuda::atomic_ref aref(*address); - aref.fetch_min(cuda_cast(val), cuda::memory_order(memory_order)); + return static_cast( + aref.fetch_min(cuda_cast(val), 
cuda::memory_order(memory_order))); +#else + TL_NOT_IMPLEMENTED(); +#endif } } @@ -143,9 +160,13 @@ TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val, } return static_cast(*reinterpret_cast(&old_val_ushort)); } else { +#if CUDART_VERSION >= 11080 cuda::atomic_ref aref(*address); return static_cast( aref.fetch_min(cuda_cast(val), cuda::memory_order(memory_order))); +#else + TL_NOT_IMPLEMENTED(); +#endif } } @@ -216,8 +237,12 @@ TL_DEVICE void AtomicAdd(T1 &ref, T2 val, } } } else { +#if CUDART_VERSION >= 11080 cuda::atomic_ref aref(*address); aref.fetch_add(cuda_cast(val), cuda::memory_order(memory_order)); +#else + TL_NOT_IMPLEMENTED(); +#endif } } @@ -290,9 +315,13 @@ TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val, } } } else { +#if CUDART_VERSION >= 11080 cuda::atomic_ref aref(*address); return static_cast( aref.fetch_add(cuda_cast(val), cuda::memory_order(memory_order))); +#else + TL_NOT_IMPLEMENTED(); +#endif } } @@ -618,13 +647,21 @@ AtomicAddx4Ret(float *ref, float *val, #endif template TL_DEVICE T AtomicLoad(T &ref, int memory_order) { +#if CUDART_VERSION >= 11080 cuda::atomic_ref aref(ref); return aref.load(cuda::memory_order(memory_order)); +#else + TL_NOT_IMPLEMENTED(); +#endif } template TL_DEVICE void AtomicStore(T1 &ref, T2 value, int memory_order) { using NT1 = typename normalize_atomic_type::type; +#if CUDART_VERSION >= 11080 cuda::atomic_ref aref(ref); aref.store(cuda_cast(value), cuda::memory_order(memory_order)); +#else + TL_NOT_IMPLEMENTED(); +#endif } diff --git a/src/tl_templates/cuda/debug.h b/src/tl_templates/cuda/debug.h index e8976874..2724a814 100644 --- a/src/tl_templates/cuda/debug.h +++ b/src/tl_templates/cuda/debug.h @@ -1,6 +1,9 @@ #pragma once +#if __CUDA_ARCH_LIST__ >= 890 #include "./cuda_fp8.h" +#endif + #include "common.h" #ifndef __CUDACC_RTC__ @@ -117,6 +120,7 @@ __device__ void debug_print_var(const char *msg, double var) { threadIdx.z, var); } +#if __CUDA_ARCH_LIST__ >= 890 // Specialization for fp8_e4_t type template <> __device__ void 
debug_print_var(const char *msg, fp8_e4_t var) { @@ -137,6 +141,8 @@ __device__ void debug_print_var(const char *msg, fp8_e5_t var) { threadIdx.z, (float)var); } +#endif + // Template declaration for device-side debug printing (buffer only) template __device__ void debug_print_buffer_value(const char *msg, const char *buf_name, @@ -242,6 +248,7 @@ __device__ void debug_print_buffer_value(const char *msg, } // Specialization for fp8_e4_t type +#if __CUDA_ARCH_LIST__ >= 890 template <> __device__ void debug_print_buffer_value(const char *msg, const char *buf_name, @@ -263,6 +270,8 @@ __device__ void debug_print_buffer_value(const char *msg, threadIdx.z, buf_name, index, (float)var); } +#endif + // Specialization for int16 type template <> __device__ void debug_print_buffer_value(const char *msg, diff --git a/src/tl_templates/cuda/gemm_mma.h b/src/tl_templates/cuda/gemm_mma.h index 71283173..25841a3b 100644 --- a/src/tl_templates/cuda/gemm_mma.h +++ b/src/tl_templates/cuda/gemm_mma.h @@ -8,7 +8,6 @@ #include #include "common.h" -#include "cuda_fp8.h" #include "intrin.h" namespace cute::tl_mma { -- GitLab From bccb6485e4003533bb0e21391dd09478e7074562 Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Thu, 20 Nov 2025 13:56:09 +0800 Subject: [PATCH 025/139] [Feat] Add support for using `T.Tensor(n * 2 + 1)` in function annotation (#1285) * [Feature] Add support for A: T.Tensor(n + 1) and A: T.Tensor(2*n) * issue fix * fix * fix * decreate nproc for debugging --------- Co-authored-by: Lei Wang --- .github/workflows/ci.yml | 2 +- .../test_tilelang_example_deepseek_v32.py | 1 + src/transform/arg_binder.cc | 76 ++++++++++++++++--- src/transform/arg_binder.h | 1 + .../python/jit/test_tilelang_jit_callback.py | 2 + .../python/jit/test_tilelang_jit_tvm_ffi.py | 62 --------------- .../language/test_tilelang_language_annot.py | 71 +++++++++++++++++ 7 files changed, 142 insertions(+), 73 deletions(-) create mode 100644 
testing/python/language/test_tilelang_language_annot.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f9fe3286..ee796602 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -352,7 +352,7 @@ jobs: uv run --no-project -m -- pytest --verbose --color=yes --durations=0 --showlocals --cache-clear ) - "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \ + "${PYTEST[@]}" --maxfail=3 --numprocesses=1 \ ../examples # NVIDIA CUDA tests diff --git a/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py b/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py index e10141b5..2dd27048 100644 --- a/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py +++ b/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py @@ -1,4 +1,5 @@ # ruff: noqa +import tilelang import tilelang.testing import topk_selector diff --git a/src/transform/arg_binder.cc b/src/transform/arg_binder.cc index 6a0909b8..361cfe90 100644 --- a/src/transform/arg_binder.cc +++ b/src/transform/arg_binder.cc @@ -29,8 +29,14 @@ #include #include +#include #include "tir/transforms/ir_utils.h" +#include "tvm/arith/int_solver.h" +#include "tvm/ffi/cast.h" +#include "tvm/ffi/container/array.h" +#include "tvm/tir/stmt.h" +#include "tvm/tir/stmt_functor.h" namespace tvm { namespace tl { @@ -51,6 +57,26 @@ void BinderAddAssert(arith::Analyzer *ana, PrimExpr cond, } } +std::vector ArgBinder::getUndefVars(const std::vector &args) { + std::unordered_set visit; + std::vector res; + for (const auto &arg : args) { + PostOrderVisit(arg, [&](ObjectRef r) { + if (auto var = r.as()) { + if (!visit.count(var)) { + visit.insert(var); + } + auto it = def_map_->find(var); + if (it == def_map_->end()) { + // res.push_back(var); + res.push_back(ffi::GetRef(var)); + } + } + }); + } + return res; +} + bool ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value, const std::string &arg_name, bool with_lets, const PrimExpr &nullable_guard) { @@ -60,20 +86,23 @@ bool 
ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value, // is_null || basic return Or(nullable_guard, basic); }; - ICHECK_EQ(arg.dtype(), value.dtype()) << "arg " << arg << " value " << value; + auto BindVar = [&](const VarNode *v, PrimExpr value) { + auto v_arg = ffi::GetRef(v); + defs_.emplace_back(v_arg); + if (with_lets) { + (*def_map_)[v] = value; + init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0))); + } else { + (*def_map_)[v] = value; + } + }; + // 1. simple binding var = value if (const VarNode *v = arg.as()) { auto it = def_map_->find(v); if (it == def_map_->end()) { + BindVar(v, value); // First time binding: identical behavior as Bind_ - Var v_arg = Downcast(arg); - defs_.emplace_back(v_arg); - if (with_lets) { - (*def_map_)[v] = arg; - init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0))); - } else { - (*def_map_)[v] = value; - } return true; } else { // Second or later binding: add is_null short-circuit @@ -81,7 +110,34 @@ bool ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value, BinderAddAssert(&analyzer_, cond, arg_name, &asserts_); } } else { - // For non-Var expressions, also add is_null short-circuit + // 2. 
complex binding expr = value + // get undefined variables + auto undefs = ffi::Array(getUndefVars({arg})); + if (!undefs.empty()) { + // if value is not integer, such as float, we are unable to solve it + if (!value.dtype().is_int() && !value.dtype().is_uint()) { + LOG(FATAL) << "Unable to solve non-integer variables " << undefs + << " from equation `" << value << "`"; + } + arith::IntConstraints constraints(undefs, {}, {arg == value}); + auto sol = arith::SolveLinearEquations(constraints); + if (!sol->dst->variables.empty()) { + LOG(FATAL) << "TVM is unable to solve variables " << undefs + << " from equation " << constraints; + } + for (const auto &v : undefs) { + auto value_opt = sol->src_to_dst.Get(v); + ICHECK(value_opt->defined()) + << "Unable to solve variable `" << v << "` from expression `" + << (arg == value) << "`"; + auto value = ffi::GetRef(sol->src_to_dst.Get(v)->get()); + BindVar(v.as(), value); + } + } + // we must add the assert again + // because the solved expression may contain floordiv (e.g. 
3 * m == n + // ==> m = n // 3) we re-compute the constraint to verify the solution + // is correct PrimExpr cond = MakeGuarded(arg == value); BinderAddAssert(&analyzer_, cond, arg_name, &asserts_); } diff --git a/src/transform/arg_binder.h b/src/transform/arg_binder.h index cf9f8466..793ada11 100644 --- a/src/transform/arg_binder.h +++ b/src/transform/arg_binder.h @@ -159,6 +159,7 @@ public: const PrimExpr &nullable_guard); private: + std::vector getUndefVars(const std::vector &arg); // Internal bind function bool Bind_(const PrimExpr &arg, const PrimExpr &value, const std::string &arg_name, bool with_lets); diff --git a/testing/python/jit/test_tilelang_jit_callback.py b/testing/python/jit/test_tilelang_jit_callback.py index d5aa00a4..e987368d 100644 --- a/testing/python/jit/test_tilelang_jit_callback.py +++ b/testing/python/jit/test_tilelang_jit_callback.py @@ -91,7 +91,9 @@ def run_gemm( code = f"// {stramp}\n" + code return code + tilelang.disable_cache() matmul_kernel = tilelang.compile(program, out_idx=-1) + tilelang.enable_cache() kernel_source = matmul_kernel.get_kernel_source() diff --git a/testing/python/jit/test_tilelang_jit_tvm_ffi.py b/testing/python/jit/test_tilelang_jit_tvm_ffi.py index cd5d9c75..f7bde6af 100644 --- a/testing/python/jit/test_tilelang_jit_tvm_ffi.py +++ b/testing/python/jit/test_tilelang_jit_tvm_ffi.py @@ -52,68 +52,6 @@ def matmul( return main -def run_gemm( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128, -): - program = matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - - stramp = "&*(XS)" - - @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True) - def tilelang_callback_cuda_postproc(code, _): - code = f"// {stramp}\n" + code - return code - - matmul_kernel = tilelang.compile(program, out_idx=-1, 
execution_backend="tvm_ffi") - - kernel_source = matmul_kernel.get_kernel_source() - - assert stramp in kernel_source, f"Expected {stramp} in the kernel source" - - -def test_gemm_f16f16f16_nn(): - run_gemm( - 512, - 1024, - 768, - False, - False, - "float16", - "float16", - "float16", - 128, - 256, - 32, - 2, - ) - - def matmu_jit_kernel( M, N, diff --git a/testing/python/language/test_tilelang_language_annot.py b/testing/python/language/test_tilelang_language_annot.py new file mode 100644 index 00000000..7425bf5c --- /dev/null +++ b/testing/python/language/test_tilelang_language_annot.py @@ -0,0 +1,71 @@ +import tilelang +import tilelang.language as T +import tilelang.testing +import torch + + +def test_tensor_annot_mul(): + + @tilelang.jit + def example_tensor_annot(): + n = T.symbolic('n') + + @T.prim_func + def kernel(A: T.Tensor((n * 4,), T.int32),): + with T.Kernel(1) as _: + for i in range(n * 4): + A[i] = 0 + + return kernel + + ker = example_tensor_annot() + A = torch.arange(16, dtype=torch.int32, device='cuda') + ker(A) + expected = torch.zeros(16, dtype=torch.int32, device='cuda') + assert torch.equal(A, expected) + + +def test_tensor_annot_add(): + + @tilelang.jit + def example_tensor_annot(): + n = T.symbolic('n') + + @T.prim_func + def kernel(A: T.Tensor((n + 1,), T.int32),): + with T.Kernel(1) as _: + for i in range(n + 1): + A[i] = 0 + + return kernel + + ker = example_tensor_annot() + A = torch.arange(16, dtype=torch.int32, device='cuda') + ker(A) + expected = torch.zeros(16, dtype=torch.int32, device='cuda') + assert torch.equal(A, expected) + + +def test_tensor_annot_mul_add(): + + @tilelang.jit + def example_tensor_annot(): + n = T.symbolic('n') + + @T.prim_func + def kernel(A: T.Tensor((n * 3 + 1,), T.int32),): + with T.Kernel(1) as _: + for i in range(n * 3 + 1): + A[i] = 0 + + return kernel + + ker = example_tensor_annot() + A = torch.arange(16, dtype=torch.int32, device='cuda') + ker(A) + expected = torch.zeros(16, dtype=torch.int32, 
device='cuda') + assert torch.equal(A, expected) + + +if __name__ == '__main__': + tilelang.testing.main() -- GitLab From dd7fdb8ee93cd134fd62636ab65122d7b03173a1 Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Thu, 20 Nov 2025 17:33:35 +0800 Subject: [PATCH 026/139] [Feat] add support for passing reference in T.Var annotation (#1291) --- .../test_tilelang_language_frontend_v2.py | 34 ++++++++++ tilelang/language/v2/builder.py | 63 ++++++++++--------- 2 files changed, 67 insertions(+), 30 deletions(-) diff --git a/testing/python/language/test_tilelang_language_frontend_v2.py b/testing/python/language/test_tilelang_language_frontend_v2.py index 1d9a20fe..41657dd7 100644 --- a/testing/python/language/test_tilelang_language_frontend_v2.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -361,5 +361,39 @@ def test_while_loop(): assert A[0].item() == sum(range(10)), f"Expected {sum(range(10))}, but got {A[0].item()}" +def test_var_macro(): + try: + + @T.macro + def macro_with_var(x: T.Var): + x = 1 # noqa: F841 + + @T.prim_func + def prim_call_macro(): + with T.Kernel(1): + x = T.alloc_var(T.int32) + macro_with_var(x) + + assert 'x[0] = 1' in prim_call_macro.script() + finally: + pass + + try: + + @T.macro + def macro_with_var(x: T.Var): + x = 1 # noqa: F841 + + @T.prim_func + def prim_call_macro(): + with T.Kernel(1): + x = 1 + macro_with_var(x) + + raise RuntimeError("Expect to report an error, x should not be passed as T.Var") + except ValueError: + pass + + if __name__ == '__main__': tilelang.testing.main() diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 6931c5af..e693f850 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -140,6 +140,7 @@ class Builder(BaseBuilder): self.frames: list[AnyFrame] = [] self.ir_builder = IRBuilder() self.name_inside_frame: dict[str, AnyFrame] = {} + self.arg_annotations = {} @classmethod def 
current(cls) -> Self: @@ -155,16 +156,17 @@ class Builder(BaseBuilder): yield @contextmanager - def macro(self, name=None): + def macro(self, name=None, annotations=None): if self.find_frame_idx(BoolOpFrame) is not None: raise RuntimeError( f"Macro `{name}` is used inside boolean expressions, " "please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs") - save = self.name_inside_frame + save = self.name_inside_frame, self.arg_annotations self.name_inside_frame = {} + self.arg_annotations = annotations or {} with self.with_frame(MacroFrame()): yield - self.name_inside_frame = save + self.name_inside_frame, self.arg_annotations = save def get(self): return self.ir_builder.get() @@ -313,32 +315,18 @@ class Builder(BaseBuilder): self.check_continue_break() locals = self.get_parent_locals() orig_value = locals.get(name, None) - # annotation like tl.float32 - # temporarily disable annotation based var declaration, for better pull request separation - # if callable(annot): - # annot_val = annot() - # if isinstance(annot_val, tir.Var): - # orig_value = tir.alloc_buffer((1,), dtype=annot_val.dtype, scope='local.var') - # IRBuilder.name(name, orig_value) - # if isinstance(value, EllipsisType) or value is self.empty: - # return orig_value - # elif isinstance(value, (int, float, IntImm, FloatImm)): - # tir.block_attr( - # {'tl.local_var_init': { - # orig_value.data: tvm.runtime.convert(value) - # }}) - # return orig_value # if orig_value is a local.var, we use buffer_store to modify it immutably - # however, if rvalue is also a local.var, this is a new binding, + # however, if rvalue is not a PrimExpr, such as buffer, # we should not use buffer_store, and bind it instead # ```py # a = tl.alloc_var('float32') # bind var `a` # a = tl.alloc_var('float32') # bind a new var `a_1` + # a = tl.alloc_shared((1,), T.float32) # bind a to new buffer # b = a # get value of var `b = a_1[0]`` # c = tl.alloc_var('float32') # bind var `c` # c = a # get and assign `c[0] = 
a_1[0]` # ``` - if is_var(orig_value) and not is_var(value): + if is_var(orig_value) and isinstance(value, (int, float, PrimExpr)): tir.buffer_store(orig_value, value, 0) return orig_value res = self.bind_immutable(name, value) @@ -486,22 +474,34 @@ class Builder(BaseBuilder): ) return self.unwrap_value(value) - def arg(self, name, value): - if self.find_frame_idx(MacroFrame) is not None: - if isinstance(value, (PrimExpr, int, float)): - return self.bind(name, value) - else: - return value + def macro_arg(self, name, value): + if self.arg_annotations.get(name, None) is Var: + is_var = isinstance(value, tvm.tir.BufferLoad) and value.buffer.scope() == 'local.var' + if not is_var: + raise ValueError( + f'Argument `{name}` is expected to be a variable allocated by `T.alloc_var`, but got {value}({type(value)})' + ) + return value.buffer + elif isinstance(value, (PrimExpr, int, float)): + return self.bind(name, value) + else: + return value + + def prim_func_arg(self, name, value): if isinstance(value, (Buffer, Var)): return tir.arg(name, value) elif value is self.empty: raise ValueError(f'Argument `{name}` is not annotated') - # elif isinstance(value, Hashable): - # return value else: raise TypeError( f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.") + def arg(self, name, value): + if self.find_frame_idx(MacroFrame) is not None: + return self.macro_arg(name, value) + else: + return self.prim_func_arg(name, value) + def override(self, name: str): from tilelang.language import serial if name == 'range': @@ -533,6 +533,7 @@ class Macro(Generic[_P, _T]): name: str orig_func: Callable[_P, _T] ir_gen: IRGenerator[_P, _T] + annotations: dict[str, Any] @property def source(self) -> str: @@ -540,7 +541,7 @@ class Macro(Generic[_P, _T]): def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: builder = Builder.current() - with builder.macro(self.name): + with builder.macro(self.name, self.annotations): res = self.ir_gen.gen(builder)(*args, 
**kwargs) return res @@ -578,7 +579,9 @@ def macro(func: Callable[_P, _T] = None) -> Macro[_P, _T]: """ def impl(func: Callable[_P, _T]) -> Macro[_P, _T]: - return Macro(name=func.__name__, orig_func=func, ir_gen=mutate(func)) + annotations = get_type_hints(func) + return Macro( + name=func.__name__, orig_func=func, ir_gen=mutate(func), annotations=annotations) return impl(func) if func is not None else impl -- GitLab From d4b6d0945e7a45db3883c13ed8d7049b568e0e94 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Thu, 20 Nov 2025 20:01:38 +0800 Subject: [PATCH 027/139] [Enhancement] Shared Memory Size Can be Dynamic (#1294) * bugfix * lint fix * test * lint fix * increate procs * recover --- .github/workflows/ci.yml | 2 +- 3rdparty/tvm | 2 +- src/tl_templates/cuda/atomic.h | 3 +- .../test_tilelang_language_atomic_add.py | 7 ++- ..._tilelang_runtime_dynamic_shared_memory.py | 52 +++++++++++++++++++ 5 files changed, 58 insertions(+), 8 deletions(-) create mode 100644 testing/python/runtime/test_tilelang_runtime_dynamic_shared_memory.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ee796602..f9fe3286 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -352,7 +352,7 @@ jobs: uv run --no-project -m -- pytest --verbose --color=yes --durations=0 --showlocals --cache-clear ) - "${PYTEST[@]}" --maxfail=3 --numprocesses=1 \ + "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \ ../examples # NVIDIA CUDA tests diff --git a/3rdparty/tvm b/3rdparty/tvm index f4affc7f..713e6ade 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit f4affc7f31e36e7f88c0fe1c715b03215c6a0c62 +Subproject commit 713e6ade56eaa72cc85d58d9228dd9f34cc2d03e diff --git a/src/tl_templates/cuda/atomic.h b/src/tl_templates/cuda/atomic.h index 0bbc4171..f724882e 100644 --- a/src/tl_templates/cuda/atomic.h +++ b/src/tl_templates/cuda/atomic.h @@ -131,8 +131,7 @@ TL_DEVICE void AtomicMin(T1 &ref, T2 val, } else 
{ #if CUDART_VERSION >= 11080 cuda::atomic_ref aref(*address); - return static_cast( - aref.fetch_min(cuda_cast(val), cuda::memory_order(memory_order))); + aref.fetch_min(cuda_cast(val), cuda::memory_order(memory_order)); #else TL_NOT_IMPLEMENTED(); #endif diff --git a/testing/python/language/test_tilelang_language_atomic_add.py b/testing/python/language/test_tilelang_language_atomic_add.py index 132e002a..2472c20f 100644 --- a/testing/python/language/test_tilelang_language_atomic_add.py +++ b/testing/python/language/test_tilelang_language_atomic_add.py @@ -374,10 +374,9 @@ def test_atomic_return_prev(): run_atomic_return_prev(32, 32, 8, 8) -# TODO(lei): test failed and this is experimental -# CC @dyq -# def test_tile_atomic_add(): -# run_tile_atomic_add(8, 128, 128, 32, 32) +def test_tile_atomic_add(): + run_tile_atomic_add(8, 128, 128, 32, 32) + if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/runtime/test_tilelang_runtime_dynamic_shared_memory.py b/testing/python/runtime/test_tilelang_runtime_dynamic_shared_memory.py new file mode 100644 index 00000000..7a42b23b --- /dev/null +++ b/testing/python/runtime/test_tilelang_runtime_dynamic_shared_memory.py @@ -0,0 +1,52 @@ +import pytest +import torch + +import tilelang +import tilelang.language as T +import tilelang.testing + + +@tilelang.jit +def dynamic_smem_kernel(): + # Symbolic length to drive dynamic shared memory allocation + length = T.symbolic("len", dtype="int32") # noqa: F821 + + @T.prim_func + def main(global_tensor: T.Tensor[(length,), "int32"]): # noqa: F821 + # Launch a simple kernel that copies from global memory into shared memory + # using a dynamically-sized allocation. No writes back to global_tensor. 
+ with T.Kernel(1, threads=32) as _: + buffer_shared = T.alloc_shared((length,), dtype="int32") # noqa: F821 + T.copy(buffer_shared, global_tensor) + + return main + + +def _require_cuda_tensor(shape, dtype): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + try: + return torch.randint(0, 100, shape, dtype=dtype, device="cuda") + except RuntimeError as err: + pytest.skip(f"CUDA runtime unavailable: {err}") + + +def _run_and_check(kernel, n): + a = _require_cuda_tensor((n,), torch.int32) + kernel(a) + torch.cuda.synchronize() + + +def test_dynamic_shared_memory_varies_across_calls(): + kernel = dynamic_smem_kernel() + + # Run with different dynamic shared memory sizes across invocations + _run_and_check(kernel, 100) + _run_and_check(kernel, 200) + # Repeat sizes to exercise attribute caching path + _run_and_check(kernel, 200) + _run_and_check(kernel, 100) + + +if __name__ == "__main__": + tilelang.testing.main() -- GitLab From 2426090fdbd9e3e5e6987efd5f37cd0519efee8b Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Fri, 21 Nov 2025 17:04:52 +0800 Subject: [PATCH 028/139] [Fix] Remove unused let_bindings_ in CodeGenC to fix #1300 (#1305) * [Feat] add missing support of uint32x2 * [Feat] Add `T.Ref` annotation and tests * fix lint error * minor update for error message on twice decl * Remove unused let_bindings_ in CodeGenC to fix #1300 --- 3rdparty/tvm | 2 +- .../python/language/test_tilelang_intimm.py | 28 ++++++++++++++++ .../test_tilelang_language_frontend_v2.py | 32 +++++++++++++++++++ tilelang/language/__init__.py | 1 + tilelang/language/proxy.py | 10 +++++- tilelang/language/v2/builder.py | 8 +++-- tilelang/language/v2/dtypes.py | 28 ++++++++++++++++ 7 files changed, 105 insertions(+), 4 deletions(-) create mode 100644 testing/python/language/test_tilelang_intimm.py diff --git a/3rdparty/tvm b/3rdparty/tvm index 713e6ade..bc31e7ad 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ 
-Subproject commit 713e6ade56eaa72cc85d58d9228dd9f34cc2d03e +Subproject commit bc31e7ad9f9fafd7659dfabafe359fd55a0ffc1e diff --git a/testing/python/language/test_tilelang_intimm.py b/testing/python/language/test_tilelang_intimm.py new file mode 100644 index 00000000..58fea31d --- /dev/null +++ b/testing/python/language/test_tilelang_intimm.py @@ -0,0 +1,28 @@ +import tilelang +import tilelang.testing +import tilelang.language as T + + +def test_tilelang_intimm(): + T.int32(0x7fffffff) + T.int32(-0x7fffffff - 1) + T.uint32(0xffffffff) + T.int64(0x7fffffffffffffff) + T.int64(-0x7fffffffffffffff - 1) + T.uint64(0xffffffffffffffff) + + a = T.int32() + a & 0x7fffffff + + a = T.uint32() + a & 0xffffffff + + a = T.int64() + a & 0x7fffffffffffffff + + a = T.uint64() + a & T.uint64(0xffffffffffffffff) + + +if __name__ == '__main__': + tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_frontend_v2.py b/testing/python/language/test_tilelang_language_frontend_v2.py index 41657dd7..2608e251 100644 --- a/testing/python/language/test_tilelang_language_frontend_v2.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -394,6 +394,38 @@ def test_var_macro(): except ValueError: pass + try: + + @T.macro + def macro_with_var(x: T.Ref): + x = 1 # noqa: F841 + + @T.prim_func + def prim_call_macro(): + with T.Kernel(1): + x = T.alloc_var(T.int32) + macro_with_var(x) + + assert 'x[0] = 1' in prim_call_macro.script() + finally: + pass + + try: + + @T.macro + def macro_with_var(x: T.Ref): + x = 1 # noqa: F841 + + @T.prim_func + def prim_call_macro(): + with T.Kernel(1): + x = 1 + macro_with_var(x) + + raise RuntimeError("Expect to report an error, x should not be passed as T.Var") + except ValueError: + pass + if __name__ == '__main__': tilelang.testing.main() diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 43c721bb..95488bdf 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ 
-22,6 +22,7 @@ from .proxy import ( FragmentBuffer, # noqa: F401 SharedBuffer, # noqa: F401 LocalBuffer, # noqa: F401 + Ref, # noqa: F401 ) from .loop import serial, Parallel, Persistent, Pipelined # noqa: F401 from .frame import has_let_value, get_let_value # noqa: F401 diff --git a/tilelang/language/proxy.py b/tilelang/language/proxy.py index e2f65e83..9e209a1b 100644 --- a/tilelang/language/proxy.py +++ b/tilelang/language/proxy.py @@ -1,7 +1,7 @@ """The language interface for tl programs.""" from __future__ import annotations -from typing import Any, SupportsIndex, TYPE_CHECKING +from typing import Any, SupportsIndex, TYPE_CHECKING, Generic, TypeVar from collections.abc import Sequence from typing_extensions import Self @@ -263,6 +263,11 @@ if TYPE_CHECKING: class LocalBuffer(BaseTensor): ... + + _T = TypeVar('_T') + + class Ref(Generic[_T], tir.Var): + ... else: Tensor = TensorProxy() # pylint: disable=invalid-name StridedTensor = StridedTensorProxy() # pylint: disable=invalid-name @@ -270,6 +275,9 @@ else: SharedBuffer = SharedBufferProxy() # pylint: disable=invalid-name LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name + class Ref: + ... + def ptr(dtype: str | None = None, storage_scope: str = "global", diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index e693f850..643994a4 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -335,7 +335,7 @@ class Builder(BaseBuilder): assert frame is not None, f"Variable `{name}` is not defined inside any control flow." 
if name in self.name_inside_frame and self.name_inside_frame[name] in self.frames: logger.warning( - f'Variable `{name}` shadows another declared value, Are you forgetting to allocate it as a var?', + f'Variable `{name}` is declared twice, are you looking for a T.alloc_var?', stack_info=True, stacklevel=2, ) @@ -475,7 +475,11 @@ class Builder(BaseBuilder): return self.unwrap_value(value) def macro_arg(self, name, value): - if self.arg_annotations.get(name, None) is Var: + from tilelang.language.proxy import Ref + annot_value = self.arg_annotations.get(name, None) + if annot_value is Var or annot_value is Ref: + if annot_value is Var: + logger.warning('Use `T.Var` as macro annotations is deprecated, please use `T.Ref`') is_var = isinstance(value, tvm.tir.BufferLoad) and value.buffer.scope() == 'local.var' if not is_var: raise ValueError( diff --git a/tilelang/language/v2/dtypes.py b/tilelang/language/v2/dtypes.py index 0702635a..75cf83dd 100644 --- a/tilelang/language/v2/dtypes.py +++ b/tilelang/language/v2/dtypes.py @@ -87,8 +87,12 @@ _STR_TO_TVM_DTYPE_CALL = { 'float8_e8m0fnu': 'Float8E8M0FNU' } +int_ = int + def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var: + if isinstance(expr, int_): + return tvm.tir.const(expr, dtype=self) if self in _STR_TO_TVM_DTYPE_CALL: attr = _STR_TO_TVM_DTYPE_CALL[self] call = getattr(tb_ffi, attr, None) @@ -151,6 +155,10 @@ if TYPE_CHECKING: class int16(dtype): ... class int32(dtype): ... class int64(dtype): ... + class int8x2(dtype): ... + class int16x2(dtype): ... + class int32x2(dtype): ... + class int64x2(dtype): ... class int8x4(dtype): ... class int16x4(dtype): ... class int32x4(dtype): ... @@ -175,6 +183,10 @@ if TYPE_CHECKING: class uint16(dtype): ... class uint32(dtype): ... class uint64(dtype): ... + class uint8x2(dtype): ... + class uint16x2(dtype): ... + class uint32x2(dtype): ... + class uint64x2(dtype): ... class uint8x4(dtype): ... class uint16x4(dtype): ... class uint32x4(dtype): ... 
@@ -308,6 +320,10 @@ else: int16 = dtype('int16') int32 = dtype('int32') int64 = dtype('int64') + int8x2 = dtype('int8x2') + int16x2 = dtype('int16x2') + int32x2 = dtype('int32x2') + int64x2 = dtype('int64x2') int8x4 = dtype('int8x4') int16x4 = dtype('int16x4') int32x4 = dtype('int32x4') @@ -332,6 +348,10 @@ else: uint16 = dtype('uint16') uint32 = dtype('uint32') uint64 = dtype('uint64') + uint8x2 = dtype('uint8x2') + uint16x2 = dtype('uint16x2') + uint32x2 = dtype('uint32x2') + uint64x2 = dtype('uint64x2') uint8x4 = dtype('uint8x4') uint16x4 = dtype('uint16x4') uint32x4 = dtype('uint32x4') @@ -464,6 +484,10 @@ _all_dtypes = { 'int16', 'int32', 'int64', + 'int8x2', + 'int16x2', + 'int32x2', + 'int64x2', 'int8x4', 'int16x4', 'int32x4', @@ -488,6 +512,10 @@ _all_dtypes = { 'uint16', 'uint32', 'uint64', + 'uint8x2', + 'uint16x2', + 'uint32x2', + 'uint64x2', 'uint8x4', 'uint16x4', 'uint32x4', -- GitLab From 17bbc0ca3d929411dfbd3908bc70085c15a56f07 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Fri, 21 Nov 2025 17:37:39 +0800 Subject: [PATCH 029/139] [Bugfix] Fallback to the old AtomicAdd implementation for legacy architectures (#1306) --- src/tl_templates/cuda/atomic.h | 59 ++++++++++++++++++++++++++++++---- 1 file changed, 53 insertions(+), 6 deletions(-) diff --git a/src/tl_templates/cuda/atomic.h b/src/tl_templates/cuda/atomic.h index f724882e..05421080 100644 --- a/src/tl_templates/cuda/atomic.h +++ b/src/tl_templates/cuda/atomic.h @@ -169,6 +169,7 @@ TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val, } } +#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 890)) template TL_DEVICE void AtomicAdd(T1 &ref, T2 val, int memory_order = int(cuda::memory_order_relaxed)) { @@ -236,14 +237,18 @@ TL_DEVICE void AtomicAdd(T1 &ref, T2 val, } } } else { -#if CUDART_VERSION >= 11080 - cuda::atomic_ref aref(*address); - aref.fetch_add(cuda_cast(val), cuda::memory_order(memory_order)); -#else - TL_NOT_IMPLEMENTED(); -#endif + 
atomicAdd(reinterpret_cast(address), cuda_cast(val)); } } +#else +template +TL_DEVICE void AtomicAdd(T1 &ref, T2 val, + int memory_order = int(cuda::memory_order_relaxed)) { + using NT1 = typename normalize_atomic_type::type; + (void)memory_order; + atomicAdd(reinterpret_cast(&ref), cuda_cast(val)); +} +#endif template TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val, @@ -643,6 +648,48 @@ AtomicAddx4Ret(float *ref, float *val, return ret_val; } } +#else +TL_DEVICE void AtomicAddx2(float *ref, float *val, + int memory_order = int(cuda::memory_order_relaxed)) { + (void)memory_order; + float2 add_val = *reinterpret_cast(val); + atomicAdd(ref + 0, add_val.x); + atomicAdd(ref + 1, add_val.y); +} + +TL_DEVICE float2 +AtomicAddx2Ret(float *ref, float *val, + int memory_order = int(cuda::memory_order_relaxed)) { + (void)memory_order; + float2 add_val = *reinterpret_cast(val); + float2 ret; + ret.x = atomicAdd(ref + 0, add_val.x); + ret.y = atomicAdd(ref + 1, add_val.y); + return ret; +} + +TL_DEVICE void AtomicAddx4(float *ref, float *val, + int memory_order = int(cuda::memory_order_relaxed)) { + (void)memory_order; + float4 add_val = *reinterpret_cast(val); + atomicAdd(ref + 0, add_val.x); + atomicAdd(ref + 1, add_val.y); + atomicAdd(ref + 2, add_val.z); + atomicAdd(ref + 3, add_val.w); +} + +TL_DEVICE float4 +AtomicAddx4Ret(float *ref, float *val, + int memory_order = int(cuda::memory_order_relaxed)) { + (void)memory_order; + float4 add_val = *reinterpret_cast(val); + float4 ret; + ret.x = atomicAdd(ref + 0, add_val.x); + ret.y = atomicAdd(ref + 1, add_val.y); + ret.z = atomicAdd(ref + 2, add_val.z); + ret.w = atomicAdd(ref + 3, add_val.w); + return ret; +} #endif template TL_DEVICE T AtomicLoad(T &ref, int memory_order) { -- GitLab From bf90a5f58c1ce9a3f20144368d72b02ed5fbeae6 Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Fri, 21 Nov 2025 20:27:14 +0800 Subject: [PATCH 030/139] [Fix] Fix frame scope error in T.macro (#1308) * [Fix] 
Fix #1307 by adding macro inside function * fix lint error * add comments and fix lint error * Remove debug print from enter_frame method Removed debug print statement from enter_frame method. --------- Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> --- .../test_tilelang_language_frontend_v2.py | 26 +++++++++++++++++++ tilelang/language/v2/builder.py | 22 ++++++++++++++-- 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/testing/python/language/test_tilelang_language_frontend_v2.py b/testing/python/language/test_tilelang_language_frontend_v2.py index 2608e251..349f3caf 100644 --- a/testing/python/language/test_tilelang_language_frontend_v2.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -427,5 +427,31 @@ def test_var_macro(): pass +def frame_inside_macro(): + + @tilelang.jit + def get_sample_kernel(): + + @T.macro + def transform(x): + return x + 1 + + @T.prim_func + def sample_kernel( + num_blocks: T.int32, + idx_out: T.Tensor[(32,), T.int32], + ): + with T.Kernel(num_blocks, threads=32) as block_idx: # noqa: F841 + fragment = T.alloc_fragment(32, 'int32') + T.copy(idx_out, fragment) + + for i in T.Parallel(32): + idx_out[i] = transform(fragment[i]) + + return sample_kernel + + kernel = get_sample_kernel() # noqa: F841 + + if __name__ == '__main__': tilelang.testing.main() diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 643994a4..c54b0701 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -80,6 +80,10 @@ class MacroFrame(Frame): ... +class ExitedMacroFrame(Frame): + ... + + class BoolOpFrame(Frame): ... 
@@ -164,8 +168,22 @@ class Builder(BaseBuilder): save = self.name_inside_frame, self.arg_annotations self.name_inside_frame = {} self.arg_annotations = annotations or {} - with self.with_frame(MacroFrame()): - yield + pos = len(self.frames) + # here we add a ExitedMacroFrame to preserve the frame stack inside macro + # because macro may bind some variable, and return it + # + # ```py + # @T.macro + # def foo(x): + # y = x + 1 + # return y + # @T.prim_func + # def bar(): + # c = foo(1) # macro generates let y = x + 1 + # d = c # d = c should lay inside frame of `let y = x + 1` + self.frames.append(MacroFrame()) + yield + self.frames[pos] = ExitedMacroFrame() self.name_inside_frame, self.arg_annotations = save def get(self): -- GitLab From 0d101c110f74ebf2ef8c11a5ece9dfb314b48baa Mon Sep 17 00:00:00 2001 From: Yunqian Fan Date: Fri, 21 Nov 2025 21:20:18 +0800 Subject: [PATCH 031/139] [WIP] support more dtypes for tcgen05 (#1229) support ld with pack for fp32 dtype add dump add tempalte expand remove unused dtype and change to rebased apis --- .../example_tilelang_gemm_fp8_sm100.py | 126 +++ src/op/copy.cc | 14 +- src/op/gemm_py.cc | 2 + src/op/tcgen5_meta.h | 38 +- src/tl_templates/cuda/copy_sm100.h | 35 +- src/tl_templates/cuda/gemm_sm100.h | 76 +- src/tl_templates/cuda/tcgen_05_ld.h | 755 +++++++++++++++++- tilelang/intrinsics/mma_macro_generator.py | 3 + .../intrinsics/tcgen05_macro_generator.py | 9 +- tilelang/jit/adapter/wrapper.py | 1 + tilelang/tileop/gemm/gemm_tcgen05.py | 5 +- 11 files changed, 976 insertions(+), 88 deletions(-) create mode 100644 examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py new file mode 100644 index 00000000..4628a997 --- /dev/null +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py @@ -0,0 +1,126 @@ +import torch +import tilelang +import tilelang.language as T +from tilelang.utils.tensor import 
map_torch_type + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm_v2( + A_shared, + B_shared, + C_tmem, + trans_A, + trans_B, + mbar=mbar, + wg_wait=-1, + clear_accum=(k == 0), + ) + T.mbarrier_wait_parity(mbar, k % 2) + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +M, N, K = 4096, 4096, 8192 +block_M, block_N, block_K = 64, 256, 32 +trans_A, trans_B = False, True +num_stages = 2 +threads = 256 +for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]: + for tvm_acc_dtype in ["float16", "float32"]: # , torch.float16]: + torch_fp8_dtype = map_torch_type(tvm_fp8_dtype) + torch_acc_dtype = map_torch_type(tvm_acc_dtype) + print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}") + in_dtype, out_dtype, accum_dtype 
= tvm_fp8_dtype, tvm_acc_dtype, tvm_acc_dtype + + func = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + ) + jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT: True, + }, + ) + # jit_kernel.export_ptx("./dump.ptx") + # jit_kernel.export_sources("./dump.cu") + + a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype) + b = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype) + + c = jit_kernel(a, b) + ref_c = (a.to(torch.half) @ b.T.to(torch.half)).float() + c = c.float() + diff = calc_diff(c, ref_c) + # assert diff < 1e-3, f"{diff}" + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] diff = {diff}") + + profiler = jit_kernel.get_profiler() + latency = profiler.do_bench() + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms") + print( + f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS" + ) diff --git a/src/op/copy.cc b/src/op/copy.cc index 5d352904..8ffef5ea 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -1117,6 +1117,11 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, bool is_ld = false; // tcgen05.ld (tensor memory -> register) bool is_st = false; // tcgen05.st (register -> tensor memory) bool is_cp = false; // tcgen05.cp (shared memory -> tensor memory) + bool src_needs_pack = + 16 == src->dtype.bits(); // if needs .pack::16b when is_ld + bool dst_needs_unpack = + 16 == dst->dtype.bits(); // if needs .unpack::16b when is_st + if (src.scope() == "shared.tmem" && dst.scope() == "local.fragment") { is_ld = true; } else if (src.scope() == "local.fragment" && dst.scope() == "shared.tmem") { @@ -1124,9 +1129,8 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, } 
else if (src.scope() == "shared.dyn" && dst.scope() == "shared.tmem") { is_cp = true; } else { - ICHECK(0) << "Unsupported tensor memory copy: " - << "src scope = " << src.scope() - << ", dst scope = " << dst.scope(); + ICHECK(0) << "Unsupported tensor memory copy: " << "src scope = " + << src.scope() << ", dst scope = " << dst.scope(); } // Currently tcgen05.cp is not supported // TODO (mzw) Support tcgen05.cp @@ -1246,8 +1250,10 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, : relative_wg_idx * (num_chunks_each_wg * meta.width); have_succeeded = true; Array args; + const char *bool_str = src_needs_pack ? "true" : "false"; args.push_back(StringImm(meta.intrinsics_name + "<" + - std::to_string(num_chunks_each_wg) + ">")); + std::to_string(num_chunks_each_wg) + ", " + + bool_str + ">")); args.push_back( BufferLoad(src, {(int)logical_row_min, (int)logical_col_min})); // Will be translated later diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index ac506ee0..6097998c 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -428,6 +428,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { result.push_back(Integer(meta.atom_m)); result.push_back(Integer(meta.atom_n)); result.push_back(Integer(meta.atom_k)); + result.push_back(Integer(meta.enable_ws)); + result.push_back(Integer(meta.enable_2cta)); } return result; }); diff --git a/src/op/tcgen5_meta.h b/src/op/tcgen5_meta.h index bb63c8dc..350a2bc8 100644 --- a/src/op/tcgen5_meta.h +++ b/src/op/tcgen5_meta.h @@ -15,16 +15,19 @@ using runtime::DataType; struct TCGEN5MMAMeta { int atom_m, atom_n, atom_k; + bool enable_ws, enable_2cta; }; inline std::pair GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { // TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA. 
#define FAIL \ - return { false, TCGEN5MMAMeta{0, 0, 0} } -#define SUCCESS(atom_m, atom_n, atom_k) \ return { \ - true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \ + false, TCGEN5MMAMeta { 0, 0, 0, false, false } \ + } +#define SUCCESS(atom_m, atom_n, atom_k, use_ws, use_2cta) \ + return { \ + true, TCGEN5MMAMeta { atom_m, atom_n, atom_k, use_ws, use_2cta } \ } std::vector ws_valid_atom_ns = {256, 128, 64}; if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) && @@ -34,39 +37,52 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { if (M % 128 == 0) { for (int atom_n = 256; atom_n >= 16; atom_n -= 16) if (N % atom_n == 0) - SUCCESS(128, atom_n, 16); + SUCCESS(128, atom_n, 16, false, false); FAIL; } else if (M % 64 == 0) { for (int atom_n : ws_valid_atom_ns) if (N % atom_n == 0) - SUCCESS(64, atom_n, 16); + SUCCESS(64, atom_n, 16, false, false); FAIL; } else if (M % 32 == 0) { for (int atom_n : ws_valid_atom_ns) if (N % atom_n == 0) - SUCCESS(32, atom_n, 16); + SUCCESS(32, atom_n, 16, false, false); FAIL; } else { FAIL; } - } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) && - (c_dtype.is_float() && c_dtype.bits() == 32)) { + } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e4m3() || + ab_dtype.is_float8_e5m2() || ab_dtype.is_float8_e5m2fnuz() || + ab_dtype.is_float6_e2m3fn() || ab_dtype.is_float6_e3m2fn() || + ab_dtype.is_float4_e2m1fn()) && + ((c_dtype.is_float() && c_dtype.bits() == 32) || + (c_dtype.is_float16() && c_dtype.bits() == 16))) { if (K % 32 != 0) FAIL; if (M % 128 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 32, true, false); for (int atom_n = 256; atom_n >= 16; atom_n -= 16) if (N % atom_n == 0) - SUCCESS(128, atom_n, 32); + SUCCESS(128, atom_n, 32, false, true); + for (int atom_n = 256; atom_n >= 8; atom_n -= 8) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 32, false, false); FAIL; } else if (M % 64 == 0) { for (int atom_n : ws_valid_atom_ns) 
if (N % atom_n == 0) - SUCCESS(64, atom_n, 32); + SUCCESS(64, atom_n, 32, true, false); + for (int atom_n = 256; atom_n >= 8; atom_n -= 8) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 32, false, false); FAIL; } else if (M % 32 == 0) { for (int atom_n : ws_valid_atom_ns) if (N % atom_n == 0) - SUCCESS(32, atom_n, 32); + SUCCESS(32, atom_n, 32, true, false); FAIL; } else { FAIL; diff --git a/src/tl_templates/cuda/copy_sm100.h b/src/tl_templates/cuda/copy_sm100.h index c4047c34..aa898bcc 100644 --- a/src/tl_templates/cuda/copy_sm100.h +++ b/src/tl_templates/cuda/copy_sm100.h @@ -51,6 +51,21 @@ __device__ __forceinline__ void st_global_256(fp8_e4_32_t *ptr, : : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); } +__device__ __forceinline__ ulonglong4 ld_global_256(const fp8_e5_32_t *ptr) { + ulonglong4 ret; + asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];" + : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w) + : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ void st_global_256(fp8_e5_32_t *ptr, + fp8_e5_32_t &val8) { + ulonglong4 &val = *((ulonglong4 *)&val8); + asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};" + : + : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); +} __device__ __forceinline__ unsigned long long pack_bfloat16x4(const bfloat16_t x, const bfloat16_t y, const bfloat16_t z, @@ -95,38 +110,38 @@ __device__ __forceinline__ void tcgen05_ld_core(uint32_t const &tmem_start_col, } } -template +template __device__ __forceinline__ void tcgen05_ld_32dp32bNx(uint32_t const &tmem_start_col, uint32_t const &tmem_col_offset, dst_t *dst_ptr) { - tcgen05_ld_core(tmem_start_col + tmem_col_offset, - dst_ptr); + tcgen05_ld_core, 7, N>( + tmem_start_col + tmem_col_offset, dst_ptr); tl::fence_view_async_tmem_load(); } -template +template __device__ __forceinline__ void tcgen05_ld_32dp64bNx(uint32_t const &tmem_start_col, uint32_t const &tmem_col_offset, dst_t *dst_ptr) { - tcgen05_ld_core(tmem_start_col + tmem_col_offset, - 
dst_ptr); + tcgen05_ld_core, 7, N>( + tmem_start_col + tmem_col_offset, dst_ptr); tl::fence_view_async_tmem_load(); } -template +template __device__ __forceinline__ void tcgen05_ld_32dp128bNx(uint32_t const &tmem_start_col, uint32_t const &tmem_col_offset, dst_t *dst_ptr) { - tcgen05_ld_core( + tcgen05_ld_core, 6, N>( tmem_start_col + tmem_col_offset, dst_ptr); tl::fence_view_async_tmem_load(); } -template +template __device__ __forceinline__ void tcgen05_ld_32dp256bNx(uint32_t const &tmem_start_col, uint32_t const &tmem_col_offset, dst_t *dst_ptr) { - tcgen05_ld_core( + tcgen05_ld_core, 5, N>( tmem_start_col + tmem_col_offset, dst_ptr); tl::fence_view_async_tmem_load(); } diff --git a/src/tl_templates/cuda/gemm_sm100.h b/src/tl_templates/cuda/gemm_sm100.h index 856d37dd..6c68c2c2 100644 --- a/src/tl_templates/cuda/gemm_sm100.h +++ b/src/tl_templates/cuda/gemm_sm100.h @@ -243,46 +243,96 @@ struct DispatchInstruction -struct DispatchInstruction> { - using MMA = MMA_Traits, - Int, integral_constant, + using MMA = + MMA_Traits, Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, integral_constant, integral_constant, integral_constant>; }; template -struct DispatchInstruction> { +struct DispatchInstruction> { using MMA = - MMA_Traits, - Int, integral_constant, + MMA_Traits, Int, integral_constant, integral_constant, integral_constant, integral_constant>; }; +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; template -struct DispatchInstruction> { - using MMA = MMA_Traits, - Int, integral_constant, + using MMA = + MMA_Traits, Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, integral_constant, 
integral_constant, integral_constant>; }; template -struct DispatchInstruction> { using MMA = - MMA_Traits, - Int, integral_constant, + MMA_Traits, Int, integral_constant, integral_constant, integral_constant, integral_constant>; }; +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; template class tmem_ld_32dp32bNx; + +template <> class tmem_ld_32dp32bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { @@ -180,9 +182,180 @@ public: } } }; +template <> class tmem_ld_32dp32bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, + "N must be a power of 2 and lies between 1 ~ 128"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.pack::16b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + 
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.pack::16b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.pack::16b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), 
"=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.pack::16b.x128.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), 
"=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + 
"=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; // 16 data path lanes, 64-bit pattern, repeated N times -class tmem_ld_16dp64bNx { +template class tmem_ld_16dp64bNx; +template <> class tmem_ld_16dp64bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { @@ -352,39 +525,43 @@ public: } } }; - -// 16 data path lanes, 128-bit pattern, repeated N times -class tmem_ld_16dp128bNx { +template <> class tmem_ld_16dp64bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 64, - "N must be a power of 2 and lies between 1 ~ 64"); + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, + "N must be a power of 2 and lies between 1 ~ 128"); if constexpr (N == 1) { - asm volatile("tcgen05.ld.sync.aligned.16x128b.x1.b32" + asm volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x2.b32" "{%0, %1}," "[%2];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) : "r"(src_addr)); - } else if constexpr (N == 2) { - asm volatile("tcgen05.ld.sync.aligned.16x128b.x2.b32" + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x4.b32" "{%0, %1, %2, %3}," "[%4];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), "=r"(dst_ptr[3]) : "r"(src_addr)); - } else if constexpr (N == 4) { - asm volatile("tcgen05.ld.sync.aligned.16x128b.x4.b32" + } else if constexpr (N == 8) { + asm 
volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x8.b32" "{%0, %1, %2, %3, %4, %5, %6, %7}," "[%8];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) : "r"(src_addr)); - } else if constexpr (N == 8) { + } else if constexpr (N == 16) { asm volatile( - "tcgen05.ld.sync.aligned.16x128b.x8.b32" + "tcgen05.ld.sync.aligned.16x64b.pack::16b.x16.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " "%14, %15}," "[%16];\n" @@ -395,9 +572,9 @@ public: "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), "=r"(dst_ptr[15]) : "r"(src_addr)); - } else if constexpr (N == 16) { + } else if constexpr (N == 32) { asm volatile( - "tcgen05.ld.sync.aligned.16x128b.x16.b32" + "tcgen05.ld.sync.aligned.16x64b.pack::16b.x32.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " "%26, %27, %28, %29, %30, %31}," @@ -414,9 +591,9 @@ public: "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) : "r"(src_addr)); - } else if constexpr (N == 32) { + } else if constexpr (N == 64) { asm volatile( - "tcgen05.ld.sync.aligned.16x128b.x32.b32" + "tcgen05.ld.sync.aligned.16x64b.pack::16b.x64.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " "%28, " @@ -449,9 +626,9 @@ public: "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), "=r"(dst_ptr[63]) : "r"(src_addr)); - } else if constexpr (N == 64) { + } else if constexpr (N == 128) { asm volatile( - "tcgen05.ld.sync.aligned.16x128b.x64.b32" + "tcgen05.ld.sync.aligned.16x64b.pack::16b.x128.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " "%28, " @@ -519,32 +696,39 @@ public: } }; -// 16 data path lanes, 256-bit pattern, repeated N times 
-class tmem_ld_16dp256bNx { +// 16 data path lanes, 128-bit pattern, repeated N times +template class tmem_ld_16dp128bNx; +template <> class tmem_ld_16dp128bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, - "N must be a power of 2 and lies between 1 ~ 32"); + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 64, + "N must be a power of 2 and lies between 1 ~ 64"); if constexpr (N == 1) { - asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32" + asm volatile("tcgen05.ld.sync.aligned.16x128b.x1.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.x2.b32" "{%0, %1, %2, %3}," "[%4];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), "=r"(dst_ptr[3]) : "r"(src_addr)); - } else if constexpr (N == 2) { - asm volatile("tcgen05.ld.sync.aligned.16x256b.x2.b32" + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.x4.b32" "{%0, %1, %2, %3, %4, %5, %6, %7}," "[%8];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) : "r"(src_addr)); - } else if constexpr (N == 4) { + } else if constexpr (N == 8) { asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x4.b32" + "tcgen05.ld.sync.aligned.16x128b.x8.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " "%14, %15}," "[%16];\n" @@ -555,9 +739,9 @@ public: "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), "=r"(dst_ptr[15]) : "r"(src_addr)); - } else if constexpr (N == 8) { + } else if constexpr (N == 16) { asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x8.b32" + "tcgen05.ld.sync.aligned.16x128b.x16.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " "%26, %27, %28, %29, %30, %31}," @@ 
-574,9 +758,9 @@ public: "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) : "r"(src_addr)); - } else if constexpr (N == 16) { + } else if constexpr (N == 32) { asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x16.b32" + "tcgen05.ld.sync.aligned.16x128b.x32.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " "%28, " @@ -609,9 +793,492 @@ public: "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), "=r"(dst_ptr[63]) : "r"(src_addr)); - } else if constexpr (N == 32) { + } else if constexpr (N == 64) { asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x32.b32" + "tcgen05.ld.sync.aligned.16x128b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), 
"=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), 
"=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; +template <> class tmem_ld_16dp128bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 64, + "N must be a power of 2 and lies between 1 ~ 64"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.pack::16b.x1.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.pack::16b.x2.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.pack::16b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.pack::16b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.pack::16b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), 
"=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.pack::16b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), 
"=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.pack::16b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + 
"=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; + +// 16 data path lanes, 256-bit pattern, repeated N times +template class tmem_ld_16dp256bNx; +template <> class tmem_ld_16dp256bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, 
uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, + "N must be a power of 2 and lies between 1 ~ 32"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x2.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), 
"=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, 
%16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), 
"=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; +template <> class tmem_ld_16dp256bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, + "N must be a power of 2 and lies between 1 ~ 32"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.pack::16b.x1.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.pack::16b.x2.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), 
"=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.pack::16b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.pack::16b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.pack::16b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + 
"[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.pack::16b.x32.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " "%28, " @@ -681,32 +1348,32 @@ public: // 32 data path lanes, 64-bit pattern, repeated N times // (conducted with 2x16dp64bNx) -class tmem_ld_32dp64bNx { +template class tmem_ld_32dp64bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - tmem_ld_16dp64bNx::copy(src_addr, dst_ptr); - tmem_ld_16dp64bNx::copy(src_addr + (16 << 16), dst_ptr + N); + tmem_ld_16dp64bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp64bNx::copy(src_addr + (16 
<< 16), dst_ptr + N); } }; // 32 data path lanes, 128-bit pattern, repeated N times -class tmem_ld_32dp128bNx { +template class tmem_ld_32dp128bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - tmem_ld_16dp128bNx::copy(src_addr, dst_ptr); - tmem_ld_16dp128bNx::copy(src_addr + (16 << 16), dst_ptr + N * 2); + tmem_ld_16dp128bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp128bNx::copy(src_addr + (16 << 16), dst_ptr + N * 2); } }; // 32 data path lanes, 256-bit pattern, repeated N times -class tmem_ld_32dp256bNx { +template class tmem_ld_32dp256bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - tmem_ld_16dp256bNx::copy(src_addr, dst_ptr); - tmem_ld_16dp256bNx::copy(src_addr + (16 << 16), dst_ptr + N * 4); + tmem_ld_16dp256bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp256bNx::copy(src_addr + (16 << 16), dst_ptr + N * 4); } }; diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index 8c546c63..bbfeb157 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -45,7 +45,10 @@ class TensorCoreIntrinEmitter: "int8": "int8", "int32": "int32", "float8_e4m3": "e4m3", + "float8_e4m3fn": "e4m3", + "float8_e4m3fnuz": "e4m3", "float8_e5m2": "e5m2", + "float8_e5m2fnuz": "e5m2", } # Represent the thread binding in the form of (tx, warp_n, warp_m) diff --git a/tilelang/intrinsics/tcgen05_macro_generator.py b/tilelang/intrinsics/tcgen05_macro_generator.py index e53ff7cb..966f4dc4 100644 --- a/tilelang/intrinsics/tcgen05_macro_generator.py +++ b/tilelang/intrinsics/tcgen05_macro_generator.py @@ -169,12 +169,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): accum_dtype_in_bits = DataType(accum_dtype).bits meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim) - if len(meta) != 3: + if len(meta) != 5: raise ValueError( f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, 
" f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}") - atom_m, atom_n, atom_k = (int(x) for x in meta) - enable_ws = atom_m != 128 + atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta) # by default, we utilize non-swizzle layout offset a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * @@ -382,10 +381,10 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): k = int(self.chunk) meta = self.get_tcgen5_mma_meta(m, n, k) - if len(meta) != 3: + if len(meta) != 5: raise ValueError(f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, " f"A dtype={self.a_dtype}, accum dtype={self.accum_dtype}") - atom_m, atom_n, _ = (int(x) for x in meta) + atom_m, atom_n, _, _, _ = (int(x) for x in meta) if m % atom_m != 0 or n % atom_n != 0: raise ValueError( diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py index 48b8e908..75607976 100644 --- a/tilelang/jit/adapter/wrapper.py +++ b/tilelang/jit/adapter/wrapper.py @@ -144,6 +144,7 @@ class TLCUDASourceWrapper: "float16": "half_t", "bfloat16": "bfloat16_t", "float8_e4m3": "fp8_e4_t", + "float8_e4m3fn": "fp8_e4_t", "float8_e5m2": "fp8_e5_t", "float64": "double", "int64": "int64_t", diff --git a/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/tileop/gemm/gemm_tcgen05.py index 52c192e5..1de9fe87 100644 --- a/tilelang/tileop/gemm/gemm_tcgen05.py +++ b/tilelang/tileop/gemm/gemm_tcgen05.py @@ -85,6 +85,9 @@ class GemmTCGEN5(GemmBase): raise ValueError(f"TCGEN5MMA currently only supports gemm_ss, got " f"A scope {self.A.scope()}, B scope {self.B.scope()}") + atom_m, atom_n, atom_k, enable_ws, enable_2cta = mma_emitter.get_tcgen5_mma_meta( + self.M, self.N, self.K) + if self.A.scope() not in {"shared", "shared.dyn", "shared.tmem"}: raise ValueError(f"Unsupported A scope for TCGEN5MMA: {self.A.scope()}") if self.B.scope() not in {"shared", "shared.dyn"}: @@ -103,7 +106,7 @@ class GemmTCGEN5(GemmBase): raise ValueError("TCGEN5MMA expects 2D 
coordinates for C buffer access") accum_dtype = str(self.C.dtype) - if accum_dtype != "float32": + if accum_dtype not in ["float32", 'float16']: raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}") A_shared = self.ARegion -- GitLab From 470eb74cac8e1ea4f99547de5ea5cb24feabb2c9 Mon Sep 17 00:00:00 2001 From: LJC00118 <77378439+LJC00118@users.noreply.github.com> Date: Sat, 22 Nov 2025 12:03:23 +0800 Subject: [PATCH 032/139] Improve memory access safety and `T.assume` handling (#1292) * Improve memory access safety and T.assume handling * Improve memory access safety and T.assume handling * bugfix * lint fix * bugfix * bugfix * refactor legalize safe memory access pass --------- Co-authored-by: Lei Wang --- src/transform/legalize_safe_memory_access.cc | 168 ++++++------------- src/transform/simplify.cc | 10 ++ 2 files changed, 58 insertions(+), 120 deletions(-) diff --git a/src/transform/legalize_safe_memory_access.cc b/src/transform/legalize_safe_memory_access.cc index 68a0cdbb..1a9da919 100644 --- a/src/transform/legalize_safe_memory_access.cc +++ b/src/transform/legalize_safe_memory_access.cc @@ -24,32 +24,6 @@ namespace tl { using namespace tir; using arith::IRMutatorWithAnalyzer; -// Helper class to find leaf For nodes in a given IR -class LeafForFinder : public StmtVisitor { -public: - std::vector leaf_for_nodes; - -private: - void VisitStmt_(const ForNode *op) final { - has_child_for_ = false; - bool parent_has_child_for = parent_has_child_for_; - parent_has_child_for_ = false; - - StmtVisitor::VisitStmt(op->body); - - if (!has_child_for_) { - leaf_for_nodes.push_back(tvm::ffi::GetRef(op)); - } - - parent_has_child_for_ = parent_has_child_for; - parent_has_child_for_ = true; - } - -private: - bool has_child_for_ = false; - bool parent_has_child_for_ = false; -}; - // GlobalMemChecker for a BufferLoad/BufferStore node: // 1. Identify BufferLoad and BufferStore nodes. // 2. Check if the buffer is in global scope. 
@@ -109,13 +83,16 @@ struct GlobalMemChecker : public StmtExprVisitor { PrimExpr index = indices[i]; PrimExpr shape_dim = buffer->shape[i]; - bool has_variable = false; + bool is_index_constant = true; PostOrderVisit(index, [&](const ObjectRef &obj) { if (const VarNode *v = obj.as()) { - has_variable = true; + is_index_constant = false; + } + if (const BufferLoadNode *v = obj.as()) { + is_index_constant = false; } }); - if (!has_variable) { + if (is_index_constant) { // If index is a constant, we can skip the check continue; } @@ -145,18 +122,31 @@ private: bool recursively_collect_conds_; }; -class SafeMemorysRewriter : public StmtExprMutator { - arith::Analyzer *analyzer_; - +class SafeMemorysRewriter : public IRMutatorWithAnalyzer { public: - explicit SafeMemorysRewriter(Map annotated_safe_value_map, - arith::Analyzer *analyzer) - : annotated_safe_value_map_(std::move(annotated_safe_value_map)), - analyzer_(analyzer) {} + // Static method to substitute and transform the given PrimFunc + static PrimFunc Substitute(PrimFunc f) { + arith::Analyzer analyzer; + // Create an instance of the legalizer with the analyzer + SafeMemorysRewriter substituter(&analyzer); + // Get a mutable copy of the function node + PrimFuncNode *fptr = f.CopyOnWrite(); + for (const auto &[_, buffer] : f->buffer_map) { + substituter.buffer_data_to_buffer_.Set(buffer->data, buffer); + } + // Apply the legalizer to the function body + fptr->body = substituter.VisitStmt(f->body); + return f; + } private: + // Constructor initializing the base class with the analyzer + SafeMemorysRewriter(arith::Analyzer *analyzer) + : arith::IRMutatorWithAnalyzer(analyzer) {} + // Constructor initializing the base class with the analyzer + PrimExpr VisitExpr_(const BufferLoadNode *op) final { - auto load = Downcast(StmtExprMutator::VisitExpr_(op)); + auto load = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); // For Load/Store, we only check the current node, not its children. 
// Since rewriter will recursively visit children. @@ -181,7 +171,7 @@ private: Stmt VisitStmt_(const BufferStoreNode *op) final { // Check if the buffer is in global scope - auto store = Downcast(StmtExprMutator::VisitStmt_(op)); + auto store = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false); checker(store); @@ -253,6 +243,25 @@ private: return evaluate; } + Stmt VisitStmt_(const BlockNode *op) final { + for (auto buffer : op->alloc_buffers) { + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + if (op->annotations.count(attr::kSafeValueMap)) { + auto map = op->annotations.Get(attr::kSafeValueMap) + ->as>() + .value(); + for (const auto &[var, safe_value] : map) { + ICHECK(buffer_data_to_buffer_.count(var)) + << "buffer " << var << " is not found in the block " + << buffer_data_to_buffer_; + auto buffer = buffer_data_to_buffer_[var]; + annotated_safe_value_map_.Set(buffer, safe_value); + } + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + bool IsLocalBuffer(const Buffer &buffer) { String scope = buffer.scope(); return scope == "local" || scope == "local.fragment" || @@ -276,87 +285,6 @@ private: return make_zero(buffer->dtype); } - Map annotated_safe_value_map_; -}; - -// Class to legalize safe memory access by transforming them appropriately -class SafeMemoryLegalizer : IRMutatorWithAnalyzer { -public: - // Static method to substitute and transform the given PrimFunc - static PrimFunc Substitute(PrimFunc f) { - arith::Analyzer analyzer; - // Create an instance of the legalizer with the analyzer - SafeMemoryLegalizer substituter(&analyzer); - // Get a mutable copy of the function node - PrimFuncNode *fptr = f.CopyOnWrite(); - for (const auto &[_, buffer] : f->buffer_map) { - substituter.buffer_data_to_buffer_.Set(buffer->data, buffer); - } - // Apply the legalizer to the function body - fptr->body = substituter.VisitStmt(f->body); - return f; - } - -private: - // Constructor 
initializing the base class with the analyzer - SafeMemoryLegalizer(arith::Analyzer *analyzer) - : arith::IRMutatorWithAnalyzer(analyzer) {} - - // Override the VisitStmt_ method to handle ForNode (loop statements) - Stmt VisitStmt_(const ForNode *op) final { - // Visit and potentially modify the loop node - For for_node = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); - auto has_inner_loop = HasInnerLoop(for_node->body); - if (!has_inner_loop) { - SafeMemorysRewriter rewriter(annotated_safe_value_map_, analyzer_); - for_node.CopyOnWrite()->body = rewriter(for_node->body); - // // Detect Buffer Load Node in the loop body, collect the indices and - // buffer size - - // // Run the checker on the loop body - // GlobalMemChecker checker(analyzer_); - // checker(for_node->body); - // Array conditions = checker.GetConditions(); - // auto body = for_node->body; - // // Note that we might have duplicate conditions - // // Which will be optimized by simplify pass - // // Replace the loop body with the new body - // for (auto cond : conditions) { - // body = IfThenElse(cond, body); - // } - // for_node.CopyOnWrite()->body = body; - return std::move(for_node); - } - - // Visit a For Node - return IRMutatorWithAnalyzer::VisitStmt_(op); - } - - Stmt VisitStmt_(const BlockNode *op) final { - for (auto buffer : op->alloc_buffers) { - buffer_data_to_buffer_.Set(buffer->data, buffer); - } - if (op->annotations.count(attr::kSafeValueMap)) { - auto map = op->annotations.Get(attr::kSafeValueMap) - ->as>() - .value(); - for (const auto &[var, safe_value] : map) { - ICHECK(buffer_data_to_buffer_.count(var)) - << "buffer " << var << " is not found in the block " - << buffer_data_to_buffer_; - auto buffer = buffer_data_to_buffer_[var]; - annotated_safe_value_map_.Set(buffer, safe_value); - } - } - return IRMutatorWithAnalyzer::VisitStmt_(op); - } - - static bool HasInnerLoop(const Stmt &stmt) { - LeafForFinder finder; - finder(stmt); - return !finder.leaf_for_nodes.empty(); - } - Map 
buffer_data_to_buffer_; Map annotated_safe_value_map_; }; @@ -371,7 +299,7 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() { if (disable_safe_memory_legalize) { return f; } - return SafeMemoryLegalizer::Substitute(std::move(f)); + return SafeMemorysRewriter::Substitute(std::move(f)); }; // Create and return a PrimFunc pass with the transformation function return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeSafeMemoryAccess", {}); diff --git a/src/transform/simplify.cc b/src/transform/simplify.cc index 5a83f0df..c10d5687 100644 --- a/src/transform/simplify.cc +++ b/src/transform/simplify.cc @@ -465,6 +465,16 @@ private: return std::move(store); } + Stmt VisitStmt_(const AttrStmtNode *op) override { + if (op->attr_key == "tl.assume") { + PrimExpr condition = this->VisitExpr(Downcast(op->node)); + auto n = CopyOnWrite(op); + n->node = std::move(condition); + return Parent::VisitStmt_(n.get()); + } + return Parent::VisitStmt_(op); + } + private: bool ArrayDeepEqual(const Array &lhs, const Array &rhs) { if (lhs.size() != rhs.size()) { -- GitLab From 721baedb7821c9be2950d45dad05a736a3590dfd Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Sat, 22 Nov 2025 19:24:45 +0800 Subject: [PATCH 033/139] [Bugfix] Fix autotune cache (#1315) --- tilelang/autotuner/param.py | 198 ++++++++++++++++++++++++++++-------- 1 file changed, 153 insertions(+), 45 deletions(-) diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py index 3e401cc5..4c8d9a94 100644 --- a/tilelang/autotuner/param.py +++ b/tilelang/autotuner/param.py @@ -13,18 +13,25 @@ from pathlib import Path from tilelang.jit import JITKernel import cloudpickle import os -import shutil from tilelang.engine.param import KernelParam from tilelang import logger import json import hashlib +import uuid +from tilelang import env +from tvm.runtime import Executable BEST_CONFIG_PATH = "best_config.json" FUNCTION_PATH = "function.pkl" LATENCY_PATH = "latency.json" 
-KERNEL_PATH = "kernel.cu" -WRAPPED_KERNEL_PATH = "wrapped_kernel.cu" + +# Align file names with cache/kernel_cache.py +DEVICE_KERNEL_PATH = "device_kernel.cu" +HOST_KERNEL_PATH = "host_kernel.cu" +EXECUTABLE_PATH = "executable.so" KERNEL_LIB_PATH = "kernel_lib.so" +KERNEL_CUBIN_PATH = "kernel.cubin" +KERNEL_PY_PATH = "kernel.py" PARAMS_PATH = "params.pkl" @@ -143,6 +150,31 @@ class AutotuneResult: func: Callable | None = None kernel: Callable | None = None + @staticmethod + def _load_binary(path: str): + with open(path, "rb") as file: + binary = file.read() + return binary + + @staticmethod + def _safe_write_file(path: str, mode: str, operation: Callable[[Any], None]): + # Random a temporary file within the same FS as the cache directory + tmp_dir = env.TILELANG_TMP_DIR + os.makedirs(tmp_dir, exist_ok=True) + temp_path = os.path.join(tmp_dir, f"{os.getpid()}_{uuid.uuid4()}") + with open(temp_path, mode) as temp_file: + operation(temp_file) + # Use atomic POSIX replace, so other processes cannot see a partial write + os.replace(temp_path, path) + + @staticmethod + def _safe_write_executable(executable: Executable, path: str): + tmp_dir = env.TILELANG_TMP_DIR + os.makedirs(tmp_dir, exist_ok=True) + temp_path = os.path.join(tmp_dir, f"{os.getpid()}_{uuid.uuid4()}.so") + executable.export_library(temp_path) + os.replace(temp_path, path) + def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: bool = False): """ Persists a compiled kernel to disk cache. 
@@ -161,34 +193,68 @@ class AutotuneResult: """ os.makedirs(cache_path, exist_ok=True) # Ensure directory exists - # Save kernel source code + # Save device kernel source code try: - kernel_path = os.path.join(cache_path, KERNEL_PATH) + device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH) if verbose: - logger.debug(f"Saving kernel source code to file: {kernel_path}") + logger.debug(f"Saving kernel source code to file: {device_kernel_path}") if kernel.kernel_source is not None: - with open(kernel_path, "w") as f: - f.write(kernel.kernel_source) + self._safe_write_file(device_kernel_path, "w", + lambda f: f.write(kernel.kernel_source)) except Exception as e: logger.error(f"Error saving kernel source code to disk: {e}") - # Save wrapped kernel source code + # Save host kernel source code (wrapped) try: - wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH) + host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH) if verbose: - logger.debug(f"Saving wrapped kernel source code to file: {wrapped_kernel_path}") - with open(wrapped_kernel_path, "w") as f: - f.write(kernel.get_kernel_source()) + logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}") + # Match kernel_cache behavior: use host source for tvm_ffi, otherwise wrapped kernel + if kernel.execution_backend == "tvm_ffi": + self._safe_write_file(host_kernel_path, "w", + lambda f: f.write(kernel.adapter.get_host_source())) + else: + self._safe_write_file(host_kernel_path, "w", + lambda f: f.write(kernel.adapter.get_kernel_source())) except Exception as e: logger.error(f"Error saving wrapped kernel source code to disk: {e}") - # Save kernel library + # Save kernel library (backend-specific) try: - kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH) - src_lib_path = kernel.adapter.libpath - if verbose: - logger.debug(f"Saving kernel library to file: {kernel_lib_path}") - shutil.copy(src_lib_path, kernel_lib_path) + if kernel.execution_backend == "nvrtc": + 
kernel_lib_file = KERNEL_CUBIN_PATH + elif kernel.execution_backend == "tvm_ffi": + kernel_lib_file = EXECUTABLE_PATH + else: + kernel_lib_file = KERNEL_LIB_PATH + + kernel_lib_path = os.path.join(cache_path, kernel_lib_file) + + if kernel.execution_backend == "nvrtc": + # Save cubin and python helper file + src_lib_path = kernel.adapter.libpath + kernel_py_path = os.path.join(cache_path, KERNEL_PY_PATH) + py_src_path = src_lib_path.replace(".cubin", ".py") + if verbose: + logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}") + self._safe_write_file(kernel_py_path, "wb", + lambda f: f.write(self._load_binary(py_src_path))) + if verbose: + logger.debug(f"Saving kernel library to file: {kernel_lib_path}") + self._safe_write_file(kernel_lib_path, "wb", + lambda f: f.write(self._load_binary(src_lib_path))) + elif kernel.execution_backend == "tvm_ffi": + executable = kernel.adapter.executable + if verbose: + logger.debug(f"Saving kernel executable to file: {kernel_lib_path}") + self._safe_write_executable(executable, kernel_lib_path) + else: + src_lib_path = kernel.adapter.libpath + if verbose: + logger.debug(f"Saving kernel library to file: {kernel_lib_path}") + self._safe_write_file(kernel_lib_path, "wb", + lambda f: f.write(self._load_binary(src_lib_path))) + except Exception as e: logger.error(f"Error saving kernel library to disk: {e}") @@ -197,8 +263,7 @@ class AutotuneResult: params_path = os.path.join(cache_path, PARAMS_PATH) if verbose: logger.debug(f"Saving kernel parameters to disk: {params_path}") - with open(params_path, "wb") as f: - cloudpickle.dump(kernel.params, f) + self._safe_write_file(params_path, "wb", lambda f: cloudpickle.dump(kernel.params, f)) except Exception as e: logger.error(f"Error saving kernel parameters to disk: {e}") @@ -210,6 +275,7 @@ class AutotuneResult: out_idx: list[int] | int | None = None, execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi", pass_configs: dict = None, + 
compile_flags: list[str] | str | None = None, func: Callable = None, verbose: bool = False, ) -> JITKernel: @@ -233,23 +299,46 @@ class AutotuneResult: if not os.path.exists(cache_path): return None - kernel_global_source: str | None = None + # Resolve backend to pick correct file names + if execution_backend == "nvrtc": + kernel_lib_file = KERNEL_CUBIN_PATH + elif execution_backend == "tvm_ffi": + kernel_lib_file = EXECUTABLE_PATH + else: + kernel_lib_file = KERNEL_LIB_PATH + + device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH) + host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH) + kernel_lib_path = os.path.join(cache_path, kernel_lib_file) + params_path = os.path.join(cache_path, PARAMS_PATH) + + if not all([os.path.exists(file) for file in (kernel_lib_path, params_path)]): + return None + + device_kernel_source: str | None = None + host_kernel_source: str | None = None kernel_params: list[KernelParam] | None = None + # Load optional device kernel source try: - wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH) if verbose: - logger.debug(f"Loading wrapped kernel source code from file: {wrapped_kernel_path}") - with open(wrapped_kernel_path) as f: - kernel_global_source = f.read() + logger.debug(f"Loading kernel source code from file: {device_kernel_path}") + with open(device_kernel_path) as f: + device_kernel_source = f.read() except Exception as e: - logger.error(f"Error loading wrapped kernel source code from disk: {e}") + logger.error(f"Error loading kernel source code from disk: {e}") - kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH) + # Load optional host kernel source + try: + if verbose: + logger.debug(f"Loading wrapped kernel source code from file: {host_kernel_path}") + with open(host_kernel_path) as f: + host_kernel_source = f.read() + except Exception as e: + logger.error(f"Error loading host kernel source code from disk: {e}") # Load kernel parameters try: - params_path = os.path.join(cache_path, 
PARAMS_PATH) if verbose: logger.debug(f"Loading kernel parameters from file: {params_path}") with open(params_path, "rb") as f: @@ -257,10 +346,11 @@ class AutotuneResult: except Exception as e: logger.error(f"Error loading kernel parameters from disk: {e}") - if kernel_global_source and kernel_params: + if host_kernel_source and device_kernel_source and kernel_params: return JITKernel.from_database( func=func, - kernel_global_source=kernel_global_source, + host_kernel_source=host_kernel_source, + device_kernel_source=device_kernel_source, kernel_lib_path=kernel_lib_path, params=kernel_params, target=target, @@ -268,6 +358,7 @@ class AutotuneResult: out_idx=out_idx, execution_backend=execution_backend, pass_configs=pass_configs, + compile_flags=compile_flags, ) else: return None @@ -276,26 +367,29 @@ class AutotuneResult: if not os.path.exists(path): os.makedirs(path) - # save best config + # save best config (atomic) if verbose: logger.debug(f"Saving best config to file: {path / BEST_CONFIG_PATH}") - with open(path / BEST_CONFIG_PATH, "w") as f: - json.dump(self.config, f) + self._safe_write_file( + str(path / BEST_CONFIG_PATH), "w", lambda f: json.dump(self.config, f)) - # save function + # save function (atomic) if verbose: logger.debug(f"Saving function to file: {path / FUNCTION_PATH}") - with open(path / FUNCTION_PATH, "wb") as f: - cloudpickle.dump(self.func, f) + self._safe_write_file( + str(path / FUNCTION_PATH), "wb", lambda f: cloudpickle.dump(self.func, f)) - # save ref latency + # save ref latency (atomic) if verbose: logger.debug(f"Saving latency to file: {path / LATENCY_PATH}") - with open(path / LATENCY_PATH, "w") as f: - json.dump({ + self._safe_write_file( + str(path / LATENCY_PATH), + "w", + lambda f: json.dump({ "latency": self.latency, "ref_latency": self.ref_latency, - }, f) + }, f), + ) # save kernel self._save_kernel_to_disk(path, self.kernel) @@ -306,6 +400,13 @@ class AutotuneResult: return None verbose = compile_args.verbose + # Normalize 
target and resolve execution backend for loading + from tilelang.utils.target import determine_target as _determine_target + from tilelang.jit.execution_backend import resolve_execution_backend + norm_target = Target(_determine_target(compile_args.target)) if isinstance( + compile_args.target, str) else compile_args.target + requested_backend = compile_args.execution_backend + resolved_backend = resolve_execution_backend(requested_backend, norm_target) # load best config if verbose: logger.debug(f"Loading best config from file: {path / BEST_CONFIG_PATH}") @@ -325,10 +426,17 @@ class AutotuneResult: latency = json.load(f) latency, ref_latency = latency["latency"], latency["ref_latency"] - kernel = cls._load_kernel_from_disk(cls, path, compile_args.target, - compile_args.target_host, compile_args.out_idx, - compile_args.execution_backend, - compile_args.pass_configs, func) + kernel = cls._load_kernel_from_disk( + cls, + path, + norm_target, + compile_args.target_host, + compile_args.out_idx, + resolved_backend, + compile_args.pass_configs, + None, # compile_flags not tracked here + func, + ) if kernel is None: return None kernel.update_tuner_result( -- GitLab From 9f7bac4c1c21d259c59f44114554256b39c3610b Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Sun, 23 Nov 2025 14:01:02 +0800 Subject: [PATCH 034/139] [Refactor] Backup Analyzer to get the appropriate arith informations (#1311) * [Refactor] Update Vectorization Functions to Accept Analyzer Parameter - Modified `VectorizeLoop` and related functions to accept an `arith::Analyzer` parameter, enhancing their capability to perform analysis during vectorization. - Updated multiple instances in `copy.cc`, `fill.cc`, `parallel.cc`, and layout inference files to utilize the new analyzer parameter for improved performance and correctness. 
- Ensured consistency across vectorization logic by integrating the analyzer into existing workflows, facilitating better optimization opportunities. * [Fix] Corrected PostOrderVisit call in loop_vectorize.cc - Updated the PostOrderVisit function to analyze the body of the loop node instead of the node itself, ensuring proper handling of nested loops during vectorization analysis. * fix * lint fix * fix --- 3rdparty/tvm | 2 +- src/op/copy.cc | 4 +- src/op/fill.cc | 6 +- src/op/parallel.cc | 3 +- src/transform/layout_inference.cc | 12 ++- src/transform/legalize_vectorized_loop.cc | 2 +- src/transform/loop_vectorize.cc | 99 +++++++++++++++-------- src/transform/loop_vectorize.h | 5 ++ 8 files changed, 87 insertions(+), 46 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index bc31e7ad..cd2b2b60 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit bc31e7ad9f9fafd7659dfabafe359fd55a0ffc1e +Subproject commit cd2b2b6013d155b5822300b0a0740fa65320dd9e diff --git a/src/op/copy.cc b/src/op/copy.cc index 8ffef5ea..c2dd06fc 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -852,7 +852,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, auto par_op = ParallelOp(transformed_loop); if (is_cpu_target) { - vectorized_thread_loop = VectorizeLoop(transformed_loop); + vectorized_thread_loop = VectorizeLoop(transformed_loop, analyzer); } else { std::vector levels = {InferLevel::kCommon, InferLevel::kStrict, InferLevel::kFree}; @@ -865,7 +865,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, auto thread_var = T.thread_var; auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout); - vectorized_thread_loop = VectorizeLoop(thread_loop); + vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer); } if (par_op->GetPredicate(T.thread_var).defined()) { diff --git a/src/op/fill.cc b/src/op/fill.cc index 83b0842d..93b3bca0 100644 --- a/src/op/fill.cc +++ b/src/op/fill.cc @@ -207,7 +207,7 @@ Stmt FillNode::Lower(const 
LowerArgs &T, arith::Analyzer *analyzer) const { InferLevel::kFree); auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, par_op->GetLoopLayout()); - auto vectorized_thread_loop = VectorizeLoop(thread_loop); + auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer); if (par_op->GetPredicate(T.thread_var).defined()) { return IfThenElse(par_op->GetPredicate(T.thread_var).value(), vectorized_thread_loop); @@ -215,7 +215,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return vectorized_thread_loop; } else if (dst.scope() == "local") { auto init_loop = MakeSIMTLoop(analyzer); - auto vectorized_thread_loop = VectorizeLoop(init_loop); + auto vectorized_thread_loop = VectorizeLoop(init_loop, analyzer); return vectorized_thread_loop; } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" || dst.scope() == "global") { @@ -225,7 +225,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { InferLevel::kFree); auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, par_op->GetLoopLayout()); - auto vectorized_thread_loop = VectorizeLoop(thread_loop); + auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer); if (par_op->GetPredicate(T.thread_var).defined()) { return IfThenElse(par_op->GetPredicate(T.thread_var).value(), vectorized_thread_loop); diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 81777aa5..0d09cc12 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -452,8 +452,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, // As the pass will do post processing to the layout auto maybe_remapped_root_ = IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map); - int vector_size = GetVectorizeSize(maybe_remapped_root_); - + int vector_size = GetVectorizeSize(maybe_remapped_root_, T.analyzer); DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n'; PrimExpr loop_total_size = 1; diff 
--git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index bd726b3d..be98b284 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -12,6 +12,7 @@ #include #include +#include #include #include "../layout/utils.h" @@ -85,6 +86,7 @@ public: auto &next = infer_list_[cur_infer_id]; auto iter_var = thread_var_vec_[cur_infer_id]; auto thread_bounds = thread_bounds_vec_[cur_infer_id]; + arith::Analyzer *cur_analyzer = analyzer_vec_[cur_infer_id].get(); auto buffer_oob = buffer_oob_vec_[cur_infer_id]; // Double-check that 'next' is valid ICHECK(next.defined()) << "infer_list_[" << cur_infer_id @@ -108,7 +110,7 @@ public: // Run InferLayout auto updates = next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map, - &analyzer_, buffer_oob}, + cur_analyzer, buffer_oob}, level); // Process the returned updates for (const auto &[buffer, layout] : updates) { @@ -266,6 +268,9 @@ public: ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size()) << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in " "length."; + ICHECK_EQ(analyzer_vec_.size(), infer_list_.size()) + << "Size mismatch: analyzer_vec_ and infer_list_ must match in " + "length."; ICHECK_EQ(buffer_oob_vec_.size(), infer_list_.size()) << "Size mismatch: buffer_oob_vec_ and infer_list_ must match in " "length."; @@ -452,6 +457,7 @@ private: } else { thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1)); } + analyzer_vec_.push_back(analyzer_.Clone()); // Compute buffer oob for each buffer in the op if (const auto *copy = p.as()) { @@ -542,6 +548,7 @@ private: } else { thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1)); } + analyzer_vec_.push_back(analyzer_.Clone()); buffer_oob_vec_.push_back(false); } else { IRVisitorWithAnalyzer::VisitStmt(op->body); @@ -683,6 +690,7 @@ private: IterVarType::kDataPar); std::vector thread_var_vec_; std::vector thread_bounds_vec_; + std::vector> analyzer_vec_; std::vector buffer_oob_vec_; Target 
target_; LayoutMap annotated_layout_map_; @@ -1024,7 +1032,7 @@ private: }); if ((has_non_local || has_cast_operations) && !has_reducer) { - for_node = VectorizeLoop(for_node); + for_node = VectorizeLoop(for_node, analyzer_); } if (result_.predicate_map.count(root) && parallel_loop) { diff --git a/src/transform/legalize_vectorized_loop.cc b/src/transform/legalize_vectorized_loop.cc index aa461784..4fd4ab91 100644 --- a/src/transform/legalize_vectorized_loop.cc +++ b/src/transform/legalize_vectorized_loop.cc @@ -73,7 +73,7 @@ private: // Change the loop kind from vectorized to serial for_node.CopyOnWrite()->kind = ForKind::kSerial; // Apply vectorization transformation to the loop - return VectorizeLoop(for_node); + return VectorizeLoop(for_node, analyzer_); } }; diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index 45283d90..e8a18b00 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -45,7 +45,7 @@ struct VectorizePlanResult { PrimExpr condition; }; -class VectorizeFindGlobalAccess : public arith::IRVisitorWithAnalyzer { +class VectorizeFindGlobalAccess : public StmtExprVisitor { public: VectorizeFindGlobalAccess() = default; @@ -60,19 +60,20 @@ private: void VisitStmt_(const BufferStoreNode *node) final { if (node->buffer.scope() == "global") has_global_access_ = true; - return arith::IRVisitorWithAnalyzer::VisitStmt_(node); + return StmtExprVisitor::VisitStmt_(node); } void VisitExpr_(const BufferLoadNode *node) final { if (node->buffer.scope() == "global") has_global_access_ = true; - return arith::IRVisitorWithAnalyzer::VisitExpr_(node); + return StmtExprVisitor::VisitExpr_(node); } }; -class VectorizePlanner : public arith::IRVisitorWithAnalyzer { +class VectorizePlanner : public arith::IRMutatorWithAnalyzer { public: - VectorizePlanner() = default; + explicit VectorizePlanner(arith::Analyzer *analyzer) + : arith::IRMutatorWithAnalyzer(analyzer) {} int Plan(const For &node) { 
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); @@ -92,21 +93,31 @@ public: } private: - void VisitStmt_(const ForNode *node) final { + Stmt VisitStmt_(const ForNode *node) final { inner_for_ = node; - auto extent_ptr = as_const_int(analyzer_.Simplify(node->extent)); - // Here I disable dynamic shape completely, - // In order to do it, the Planner should accept an analyzer with - // arithmetic info outside to prove the dividiblity of vector size - if (!extent_ptr) { - vector_size_ = 1; - return; + bool contains_nested_for = false; + // Must analysis vectorization on the innermost loop + PostOrderVisit(Downcast(node->body), [&](const ObjectRef &obj) { + if (obj.as()) { + contains_nested_for = true; + } + }); + + if (!contains_nested_for) { + auto extent_ptr = as_const_int(analyzer_->Simplify(node->extent)); + // Here I disable dynamic shape completely, + // In order to do it, the Planner should accept an analyzer with + // arithmetic info outside to prove the dividiblity of vector size + if (!extent_ptr) { + vector_size_ = 1; + return ffi::GetRef(node); + } + vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr); } - vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr); - arith::IRVisitorWithAnalyzer::VisitStmt_(node); + return arith::IRMutatorWithAnalyzer::VisitStmt_(node); } - void VisitExpr_(const BufferLoadNode *node) final { + PrimExpr VisitExpr_(const BufferLoadNode *node) final { if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" || node->buffer.scope() == "shared.dyn") has_nonlocal_memory_access_ = true; @@ -115,43 +126,44 @@ private: // constant buffer that tl hack to use as local register. 
auto boundary_check = node->buffer->shape[0].as(); if (boundary_check && boundary_check->value == 1) { - return arith::IRVisitorWithAnalyzer::VisitExpr_(node); + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } } UpdateVectorSize(node->indices, node->buffer); + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } - void VisitStmt_(const BufferStoreNode *node) final { + Stmt VisitStmt_(const BufferStoreNode *node) final { if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" || node->buffer.scope() == "shared.dyn") has_nonlocal_memory_access_ = true; UpdateVectorSize(node->indices, node->buffer); - return arith::IRVisitorWithAnalyzer::VisitExpr(node->value); + return arith::IRMutatorWithAnalyzer::VisitStmt_(node); } - void VisitStmt_(const IfThenElseNode *node) final { + Stmt VisitStmt_(const IfThenElseNode *node) final { CheckConditionVectorized(node->condition); - return arith::IRVisitorWithAnalyzer::VisitStmt_(node); + return arith::IRMutatorWithAnalyzer::VisitStmt_(node); } - void VisitExpr_(const CallNode *node) final { + PrimExpr VisitExpr_(const CallNode *node) final { if (node->op == builtin::if_then_else()) { CheckConditionVectorized(node->args[0]); } else if (node->op == builtin::call_extern()) { // do not vectorize extern calls vector_size_ = 1; } - return arith::IRVisitorWithAnalyzer::VisitExpr_(node); + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } void CheckConditionVectorized(const PrimExpr &cond) { // TODO: perform some checks here } - void VisitExpr_(const CastNode *node) final { + PrimExpr VisitExpr_(const CastNode *node) final { vector_size_ = arith::ZeroAwareGCD( vector_load_bits_max_ / node->dtype.bits(), vector_size_); - return arith::IRVisitorWithAnalyzer::VisitExpr_(node); + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } void UpdateVectorSize(const Array indices, const Buffer &buffer) { @@ -171,19 +183,16 @@ private: for (int i = 0; i < indices.size(); ++i) { elem_offset += indices[i] * 
strides[i]; } - // 2. If element offset is independent with loop_var, ignore it - if (CanProveIndependent(elem_offset, inner_for_->loop_var, &analyzer_)) { + if (CanProveIndependent(elem_offset, inner_for_->loop_var, analyzer_)) { return; } - // 3. Tight vectorize bound vector_size_ = arith::ZeroAwareGCD(vector_size_, vector_load_bits_max_ / buffer->dtype.bits()); - // 4. Try to vectorize buffer load while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var, - inner_for_->extent, vector_size_, &analyzer_)) { + inner_for_->extent, vector_size_, analyzer_)) { vector_size_ /= 2; } } @@ -235,7 +244,14 @@ private: const int vector_size_; }; -int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); } +int GetVectorizeSize(const For &loop) { + arith::Analyzer analyzer; + return VectorizePlanner(&analyzer).Plan(loop); +} + +int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer) { + return VectorizePlanner(analyzer).Plan(loop); +} bool CanProveIndependent(const PrimExpr &expr, Var var, arith::Analyzer *analyzer) { @@ -274,10 +290,10 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var, if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_size_for_iter), 0)) return false; - + auto simplified_expr = analyzer->Simplify(Substitute(expr, {{var, zero}})); // The base offset must be divisible - if (!analyzer->CanProveEqual( - FloorMod(Substitute(expr, {{var, zero}}), target_size_for_expr), 0)) { + if (!analyzer->CanProveEqual(FloorMod(simplified_expr, target_size_for_expr), + zero)) { return false; } @@ -308,7 +324,20 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var, For VectorizeLoop(const For &loop, int vectorize_hint) { if (vectorize_hint <= 0) { - VectorizePlanner planner; + arith::Analyzer analyzer; + VectorizePlanner planner(&analyzer); + vectorize_hint = planner.Plan(loop); + } + if (vectorize_hint == 1) + return loop; + auto rewriter = VectorizeRewriter(vectorize_hint); + return Downcast(rewriter(loop)); +} + +For 
VectorizeLoop(const For &loop, arith::Analyzer *analyzer, + int vectorize_hint) { + if (vectorize_hint <= 0) { + VectorizePlanner planner(analyzer); vectorize_hint = planner.Plan(loop); } if (vectorize_hint == 1) diff --git a/src/transform/loop_vectorize.h b/src/transform/loop_vectorize.h index 4ab20c66..a63c4b45 100644 --- a/src/transform/loop_vectorize.h +++ b/src/transform/loop_vectorize.h @@ -35,8 +35,13 @@ using namespace tir; int GetVectorizeSize(const For &loop); +int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer); + For VectorizeLoop(const For &loop, int vectorize_hint = -1); +For VectorizeLoop(const For &loop, arith::Analyzer *analyzer, + int vectorize_hint = -1); + // Can prove expr is independent with var, i.e. the value of expr doesn't change // when var changes bool CanProveIndependent(const PrimExpr &expr, Var var, -- GitLab From ca98cc391790d160cffcb0b997c2380c276b8e2e Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Mon, 24 Nov 2025 16:17:13 +0800 Subject: [PATCH 035/139] Revert "[WIP] support more dtypes for tcgen05 (#1229)" (#1323) This reverts commit 0d101c110f74ebf2ef8c11a5ece9dfb314b48baa. 
Co-authored-by: Zhiwen Mo --- .../example_tilelang_gemm_fp8_sm100.py | 126 --- src/op/copy.cc | 14 +- src/op/gemm_py.cc | 2 - src/op/tcgen5_meta.h | 38 +- src/tl_templates/cuda/copy_sm100.h | 35 +- src/tl_templates/cuda/gemm_sm100.h | 76 +- src/tl_templates/cuda/tcgen_05_ld.h | 753 +----------------- tilelang/intrinsics/mma_macro_generator.py | 3 - .../intrinsics/tcgen05_macro_generator.py | 9 +- tilelang/jit/adapter/wrapper.py | 1 - tilelang/tileop/gemm/gemm_tcgen05.py | 5 +- 11 files changed, 87 insertions(+), 975 deletions(-) delete mode 100644 examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py deleted file mode 100644 index 4628a997..00000000 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py +++ /dev/null @@ -1,126 +0,0 @@ -import torch -import tilelang -import tilelang.language as T -from tilelang.utils.tensor import map_torch_type - - -def matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, -): - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) - C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) - mbar = T.alloc_barrier(1) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_shared = T.alloc_shared((block_M, block_N), out_dtype) - - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * 
block_M, k * block_K], A_shared) - T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm_v2( - A_shared, - B_shared, - C_tmem, - trans_A, - trans_B, - mbar=mbar, - wg_wait=-1, - clear_accum=(k == 0), - ) - T.mbarrier_wait_parity(mbar, k % 2) - - T.copy(C_tmem, C_local) - T.copy(C_local, C_shared) - - T.copy(C_shared, C[by * block_M, bx * block_N]) - - return main - - -def calc_diff(x, y): - x, y = x.double(), y.double() - denominator = (x * x + y * y).sum() - sim = 2 * (x * y).sum() / denominator - return 1 - sim - - -M, N, K = 4096, 4096, 8192 -block_M, block_N, block_K = 64, 256, 32 -trans_A, trans_B = False, True -num_stages = 2 -threads = 256 -for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]: - for tvm_acc_dtype in ["float16", "float32"]: # , torch.float16]: - torch_fp8_dtype = map_torch_type(tvm_fp8_dtype) - torch_acc_dtype = map_torch_type(tvm_acc_dtype) - print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}") - in_dtype, out_dtype, accum_dtype = tvm_fp8_dtype, tvm_acc_dtype, tvm_acc_dtype - - func = matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, - ) - jit_kernel = tilelang.compile( - func, - out_idx=[2], - target="cuda", - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT: True, - }, - ) - # jit_kernel.export_ptx("./dump.ptx") - # jit_kernel.export_sources("./dump.cu") - - a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype) - b = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype) - - c = jit_kernel(a, b) - ref_c = (a.to(torch.half) @ b.T.to(torch.half)).float() - c = c.float() - diff = calc_diff(c, ref_c) - # assert diff < 1e-3, f"{diff}" - print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] diff = {diff}") - - profiler = jit_kernel.get_profiler() - latency = profiler.do_bench() - 
print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms") - print( - f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS" - ) diff --git a/src/op/copy.cc b/src/op/copy.cc index c2dd06fc..2584abce 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -1117,11 +1117,6 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, bool is_ld = false; // tcgen05.ld (tensor memory -> register) bool is_st = false; // tcgen05.st (register -> tensor memory) bool is_cp = false; // tcgen05.cp (shared memory -> tensor memory) - bool src_needs_pack = - 16 == src->dtype.bits(); // if needs .pack::16b when is_ld - bool dst_needs_unpack = - 16 == dst->dtype.bits(); // if needs .unpack::16b when is_st - if (src.scope() == "shared.tmem" && dst.scope() == "local.fragment") { is_ld = true; } else if (src.scope() == "local.fragment" && dst.scope() == "shared.tmem") { @@ -1129,8 +1124,9 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, } else if (src.scope() == "shared.dyn" && dst.scope() == "shared.tmem") { is_cp = true; } else { - ICHECK(0) << "Unsupported tensor memory copy: " << "src scope = " - << src.scope() << ", dst scope = " << dst.scope(); + ICHECK(0) << "Unsupported tensor memory copy: " + << "src scope = " << src.scope() + << ", dst scope = " << dst.scope(); } // Currently tcgen05.cp is not supported // TODO (mzw) Support tcgen05.cp @@ -1250,10 +1246,8 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, : relative_wg_idx * (num_chunks_each_wg * meta.width); have_succeeded = true; Array args; - const char *bool_str = src_needs_pack ? 
"true" : "false"; args.push_back(StringImm(meta.intrinsics_name + "<" + - std::to_string(num_chunks_each_wg) + ", " + - bool_str + ">")); + std::to_string(num_chunks_each_wg) + ">")); args.push_back( BufferLoad(src, {(int)logical_row_min, (int)logical_col_min})); // Will be translated later diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index 6097998c..ac506ee0 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -428,8 +428,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { result.push_back(Integer(meta.atom_m)); result.push_back(Integer(meta.atom_n)); result.push_back(Integer(meta.atom_k)); - result.push_back(Integer(meta.enable_ws)); - result.push_back(Integer(meta.enable_2cta)); } return result; }); diff --git a/src/op/tcgen5_meta.h b/src/op/tcgen5_meta.h index 350a2bc8..bb63c8dc 100644 --- a/src/op/tcgen5_meta.h +++ b/src/op/tcgen5_meta.h @@ -15,19 +15,16 @@ using runtime::DataType; struct TCGEN5MMAMeta { int atom_m, atom_n, atom_k; - bool enable_ws, enable_2cta; }; inline std::pair GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { // TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA. 
#define FAIL \ + return { false, TCGEN5MMAMeta{0, 0, 0} } +#define SUCCESS(atom_m, atom_n, atom_k) \ return { \ - false, TCGEN5MMAMeta { 0, 0, 0, false, false } \ - } -#define SUCCESS(atom_m, atom_n, atom_k, use_ws, use_2cta) \ - return { \ - true, TCGEN5MMAMeta { atom_m, atom_n, atom_k, use_ws, use_2cta } \ + true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \ } std::vector ws_valid_atom_ns = {256, 128, 64}; if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) && @@ -37,52 +34,39 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { if (M % 128 == 0) { for (int atom_n = 256; atom_n >= 16; atom_n -= 16) if (N % atom_n == 0) - SUCCESS(128, atom_n, 16, false, false); + SUCCESS(128, atom_n, 16); FAIL; } else if (M % 64 == 0) { for (int atom_n : ws_valid_atom_ns) if (N % atom_n == 0) - SUCCESS(64, atom_n, 16, false, false); + SUCCESS(64, atom_n, 16); FAIL; } else if (M % 32 == 0) { for (int atom_n : ws_valid_atom_ns) if (N % atom_n == 0) - SUCCESS(32, atom_n, 16, false, false); + SUCCESS(32, atom_n, 16); FAIL; } else { FAIL; } - } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e4m3() || - ab_dtype.is_float8_e5m2() || ab_dtype.is_float8_e5m2fnuz() || - ab_dtype.is_float6_e2m3fn() || ab_dtype.is_float6_e3m2fn() || - ab_dtype.is_float4_e2m1fn()) && - ((c_dtype.is_float() && c_dtype.bits() == 32) || - (c_dtype.is_float16() && c_dtype.bits() == 16))) { + } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) && + (c_dtype.is_float() && c_dtype.bits() == 32)) { if (K % 32 != 0) FAIL; if (M % 128 == 0) { - for (int atom_n : ws_valid_atom_ns) - if (N % atom_n == 0) - SUCCESS(128, atom_n, 32, true, false); for (int atom_n = 256; atom_n >= 16; atom_n -= 16) if (N % atom_n == 0) - SUCCESS(128, atom_n, 32, false, true); - for (int atom_n = 256; atom_n >= 8; atom_n -= 8) - if (N % atom_n == 0) - SUCCESS(128, atom_n, 32, false, false); + SUCCESS(128, atom_n, 32); FAIL; } else if (M % 64 == 0) { for (int atom_n : ws_valid_atom_ns) 
if (N % atom_n == 0) - SUCCESS(64, atom_n, 32, true, false); - for (int atom_n = 256; atom_n >= 8; atom_n -= 8) - if (N % atom_n == 0) - SUCCESS(128, atom_n, 32, false, false); + SUCCESS(64, atom_n, 32); FAIL; } else if (M % 32 == 0) { for (int atom_n : ws_valid_atom_ns) if (N % atom_n == 0) - SUCCESS(32, atom_n, 32, true, false); + SUCCESS(32, atom_n, 32); FAIL; } else { FAIL; diff --git a/src/tl_templates/cuda/copy_sm100.h b/src/tl_templates/cuda/copy_sm100.h index aa898bcc..c4047c34 100644 --- a/src/tl_templates/cuda/copy_sm100.h +++ b/src/tl_templates/cuda/copy_sm100.h @@ -51,21 +51,6 @@ __device__ __forceinline__ void st_global_256(fp8_e4_32_t *ptr, : : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); } -__device__ __forceinline__ ulonglong4 ld_global_256(const fp8_e5_32_t *ptr) { - ulonglong4 ret; - asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];" - : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w) - : "l"(ptr)); - return ret; -} - -__device__ __forceinline__ void st_global_256(fp8_e5_32_t *ptr, - fp8_e5_32_t &val8) { - ulonglong4 &val = *((ulonglong4 *)&val8); - asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};" - : - : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); -} __device__ __forceinline__ unsigned long long pack_bfloat16x4(const bfloat16_t x, const bfloat16_t y, const bfloat16_t z, @@ -110,38 +95,38 @@ __device__ __forceinline__ void tcgen05_ld_core(uint32_t const &tmem_start_col, } } -template +template __device__ __forceinline__ void tcgen05_ld_32dp32bNx(uint32_t const &tmem_start_col, uint32_t const &tmem_col_offset, dst_t *dst_ptr) { - tcgen05_ld_core, 7, N>( - tmem_start_col + tmem_col_offset, dst_ptr); + tcgen05_ld_core(tmem_start_col + tmem_col_offset, + dst_ptr); tl::fence_view_async_tmem_load(); } -template +template __device__ __forceinline__ void tcgen05_ld_32dp64bNx(uint32_t const &tmem_start_col, uint32_t const &tmem_col_offset, dst_t *dst_ptr) { - tcgen05_ld_core, 7, N>( - tmem_start_col + 
tmem_col_offset, dst_ptr); + tcgen05_ld_core(tmem_start_col + tmem_col_offset, + dst_ptr); tl::fence_view_async_tmem_load(); } -template +template __device__ __forceinline__ void tcgen05_ld_32dp128bNx(uint32_t const &tmem_start_col, uint32_t const &tmem_col_offset, dst_t *dst_ptr) { - tcgen05_ld_core, 6, N>( + tcgen05_ld_core( tmem_start_col + tmem_col_offset, dst_ptr); tl::fence_view_async_tmem_load(); } -template +template __device__ __forceinline__ void tcgen05_ld_32dp256bNx(uint32_t const &tmem_start_col, uint32_t const &tmem_col_offset, dst_t *dst_ptr) { - tcgen05_ld_core, 5, N>( + tcgen05_ld_core( tmem_start_col + tmem_col_offset, dst_ptr); tl::fence_view_async_tmem_load(); } diff --git a/src/tl_templates/cuda/gemm_sm100.h b/src/tl_templates/cuda/gemm_sm100.h index 6c68c2c2..856d37dd 100644 --- a/src/tl_templates/cuda/gemm_sm100.h +++ b/src/tl_templates/cuda/gemm_sm100.h @@ -243,96 +243,46 @@ struct DispatchInstruction -struct DispatchInstruction> { - using MMA = - MMA_Traits, Int, integral_constant, - integral_constant, - integral_constant, - integral_constant>; -}; -template -struct DispatchInstruction> { - using MMA = MMA_Traits, Int, - integral_constant, + using MMA = MMA_Traits, + Int, integral_constant, integral_constant, integral_constant, integral_constant>; }; template -struct DispatchInstruction> { +struct DispatchInstruction> { using MMA = - MMA_Traits, Int, integral_constant, + MMA_Traits, + Int, integral_constant, integral_constant, integral_constant, integral_constant>; }; -template -struct DispatchInstruction> { - using MMA = MMA_Traits, Int, - integral_constant, - integral_constant, - integral_constant, - integral_constant>; -}; template -struct DispatchInstruction> { - using MMA = - MMA_Traits, Int, integral_constant, - integral_constant, - integral_constant, - integral_constant>; -}; -template -struct DispatchInstruction> { - using MMA = MMA_Traits, Int, - integral_constant, + using MMA = MMA_Traits, + Int, integral_constant, 
integral_constant, integral_constant, integral_constant>; }; template -struct DispatchInstruction> { using MMA = - MMA_Traits, Int, integral_constant, + MMA_Traits, + Int, integral_constant, integral_constant, integral_constant, integral_constant>; }; -template -struct DispatchInstruction> { - using MMA = MMA_Traits, Int, - integral_constant, - integral_constant, - integral_constant, - integral_constant>; -}; template class tmem_ld_32dp32bNx; - -template <> class tmem_ld_32dp32bNx { +class tmem_ld_32dp32bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { @@ -182,180 +180,9 @@ public: } } }; -template <> class tmem_ld_32dp32bNx { -public: - template - static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, - "N must be a power of 2 and lies between 1 ~ 128"); - - if constexpr (N == 1) { - asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x1.b32" - "{%0}," - "[%1];\n" - : "=r"(dst_ptr[0]) - : "r"(src_addr)); - } else if constexpr (N == 2) { - asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x2.b32" - "{%0, %1}," - "[%2];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) - : "r"(src_addr)); - } else if constexpr (N == 4) { - asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x4.b32" - "{%0, %1, %2, %3}," - "[%4];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]) - : "r"(src_addr)); - } else if constexpr (N == 8) { - asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x8.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "[%8];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) - : "r"(src_addr)); - } else if constexpr (N == 16) { - asm volatile( - "tcgen05.ld.sync.aligned.32x32b.pack::16b.x16.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " - "%14, %15}," - "[%16];\n" - : "=r"(dst_ptr[0]), 
"=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]) - : "r"(src_addr)); - } else if constexpr (N == 32) { - asm volatile( - "tcgen05.ld.sync.aligned.32x32b.pack::16b.x32.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " - "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " - "%26, %27, %28, %29, %30, %31}," - "[%32];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) - : "r"(src_addr)); - } else if constexpr (N == 64) { - asm volatile( - "tcgen05.ld.sync.aligned.32x32b.pack::16b.x64.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " - "%28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " - "%42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " - "%56, " - "%57, %58, %59, %60, %61, %62, %63}," - "[%64];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - 
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), - "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), - "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), - "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), - "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), - "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), - "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), - "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), - "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), - "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), - "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), - "=r"(dst_ptr[63]) - : "r"(src_addr)); - } else if constexpr (N == 128) { - asm volatile( - "tcgen05.ld.sync.aligned.32x32b.pack::16b.x128.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " - "%28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " - "%42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " - "%56, " - "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " - "%70, " - "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " - "%84, " - "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " - "%98, " - "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " - "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " - "%121, %122, %123, %124, %125, %126, %127}," - "[%128];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), 
"=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), - "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), - "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), - "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), - "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), - "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), - "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), - "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), - "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), - "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), - "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), - "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), - "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), - "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), - "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), - "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), - "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), - "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), - "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), - "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), - "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), - "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), - "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), - "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), - "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), - "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), - "=r"(dst_ptr[108]), 
"=r"(dst_ptr[109]), "=r"(dst_ptr[110]), - "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), - "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), - "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), - "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), - "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), - "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) - : "r"(src_addr)); - } else { - asm volatile("trap"); - } - } -}; // 16 data path lanes, 64-bit pattern, repeated N times -template class tmem_ld_16dp64bNx; -template <> class tmem_ld_16dp64bNx { +class tmem_ld_16dp64bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { @@ -525,43 +352,39 @@ public: } } }; -template <> class tmem_ld_16dp64bNx { + +// 16 data path lanes, 128-bit pattern, repeated N times +class tmem_ld_16dp128bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, - "N must be a power of 2 and lies between 1 ~ 128"); + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 64, + "N must be a power of 2 and lies between 1 ~ 64"); if constexpr (N == 1) { - asm volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x1.b32" - "{%0}," - "[%1];\n" - : "=r"(dst_ptr[0]) - : "r"(src_addr)); - } else if constexpr (N == 2) { - asm volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x2.b32" + asm volatile("tcgen05.ld.sync.aligned.16x128b.x1.b32" "{%0, %1}," "[%2];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) : "r"(src_addr)); - } else if constexpr (N == 4) { - asm volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x4.b32" + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.x2.b32" "{%0, %1, %2, %3}," "[%4];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), "=r"(dst_ptr[3]) : "r"(src_addr)); - } else if constexpr (N == 8) { - asm volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x8.b32" + } else if constexpr (N == 
4) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.x4.b32" "{%0, %1, %2, %3, %4, %5, %6, %7}," "[%8];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) : "r"(src_addr)); - } else if constexpr (N == 16) { + } else if constexpr (N == 8) { asm volatile( - "tcgen05.ld.sync.aligned.16x64b.pack::16b.x16.b32" + "tcgen05.ld.sync.aligned.16x128b.x8.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " "%14, %15}," "[%16];\n" @@ -572,9 +395,9 @@ public: "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), "=r"(dst_ptr[15]) : "r"(src_addr)); - } else if constexpr (N == 32) { + } else if constexpr (N == 16) { asm volatile( - "tcgen05.ld.sync.aligned.16x64b.pack::16b.x32.b32" + "tcgen05.ld.sync.aligned.16x128b.x16.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " "%26, %27, %28, %29, %30, %31}," @@ -591,9 +414,9 @@ public: "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) : "r"(src_addr)); - } else if constexpr (N == 64) { + } else if constexpr (N == 32) { asm volatile( - "tcgen05.ld.sync.aligned.16x64b.pack::16b.x64.b32" + "tcgen05.ld.sync.aligned.16x128b.x32.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " "%28, " @@ -626,9 +449,9 @@ public: "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), "=r"(dst_ptr[63]) : "r"(src_addr)); - } else if constexpr (N == 128) { + } else if constexpr (N == 64) { asm volatile( - "tcgen05.ld.sync.aligned.16x64b.pack::16b.x128.b32" + "tcgen05.ld.sync.aligned.16x128b.x64.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " "%28, " @@ -696,39 +519,32 @@ public: } }; -// 16 data path lanes, 128-bit pattern, repeated N times 
-template class tmem_ld_16dp128bNx; -template <> class tmem_ld_16dp128bNx { +// 16 data path lanes, 256-bit pattern, repeated N times +class tmem_ld_16dp256bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 64, - "N must be a power of 2 and lies between 1 ~ 64"); + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, + "N must be a power of 2 and lies between 1 ~ 32"); if constexpr (N == 1) { - asm volatile("tcgen05.ld.sync.aligned.16x128b.x1.b32" - "{%0, %1}," - "[%2];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) - : "r"(src_addr)); - } else if constexpr (N == 2) { - asm volatile("tcgen05.ld.sync.aligned.16x128b.x2.b32" + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32" "{%0, %1, %2, %3}," "[%4];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), "=r"(dst_ptr[3]) : "r"(src_addr)); - } else if constexpr (N == 4) { - asm volatile("tcgen05.ld.sync.aligned.16x128b.x4.b32" + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x2.b32" "{%0, %1, %2, %3, %4, %5, %6, %7}," "[%8];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) : "r"(src_addr)); - } else if constexpr (N == 8) { + } else if constexpr (N == 4) { asm volatile( - "tcgen05.ld.sync.aligned.16x128b.x8.b32" + "tcgen05.ld.sync.aligned.16x256b.x4.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " "%14, %15}," "[%16];\n" @@ -739,9 +555,9 @@ public: "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), "=r"(dst_ptr[15]) : "r"(src_addr)); - } else if constexpr (N == 16) { + } else if constexpr (N == 8) { asm volatile( - "tcgen05.ld.sync.aligned.16x128b.x16.b32" + "tcgen05.ld.sync.aligned.16x256b.x8.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " "%26, %27, %28, %29, %30, %31}," @@ 
-758,9 +574,9 @@ public: "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) : "r"(src_addr)); - } else if constexpr (N == 32) { + } else if constexpr (N == 16) { asm volatile( - "tcgen05.ld.sync.aligned.16x128b.x32.b32" + "tcgen05.ld.sync.aligned.16x256b.x16.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " "%28, " @@ -793,332 +609,7 @@ public: "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), "=r"(dst_ptr[63]) : "r"(src_addr)); - } else if constexpr (N == 64) { - asm volatile( - "tcgen05.ld.sync.aligned.16x128b.x64.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " - "%28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " - "%42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " - "%56, " - "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " - "%70, " - "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " - "%84, " - "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " - "%98, " - "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " - "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " - "%121, %122, %123, %124, %125, %126, %127}," - "[%128];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), 
"=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), - "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), - "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), - "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), - "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), - "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), - "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), - "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), - "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), - "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), - "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), - "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), - "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), - "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), - "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), - "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), - "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), - "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), - "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), - "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), - "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), - "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), - "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), - "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), - "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), - "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), - "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), - "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), - "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), - "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), - "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), - "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), - "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) - : "r"(src_addr)); - } 
else { - asm volatile("trap"); - } - } -}; -template <> class tmem_ld_16dp128bNx { -public: - template - static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 64, - "N must be a power of 2 and lies between 1 ~ 64"); - - if constexpr (N == 1) { - asm volatile("tcgen05.ld.sync.aligned.16x128b.pack::16b.x1.b32" - "{%0, %1}," - "[%2];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) - : "r"(src_addr)); - } else if constexpr (N == 2) { - asm volatile("tcgen05.ld.sync.aligned.16x128b.pack::16b.x2.b32" - "{%0, %1, %2, %3}," - "[%4];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]) - : "r"(src_addr)); - } else if constexpr (N == 4) { - asm volatile("tcgen05.ld.sync.aligned.16x128b.pack::16b.x4.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "[%8];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) - : "r"(src_addr)); - } else if constexpr (N == 8) { - asm volatile( - "tcgen05.ld.sync.aligned.16x128b.pack::16b.x8.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " - "%14, %15}," - "[%16];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]) - : "r"(src_addr)); - } else if constexpr (N == 16) { - asm volatile( - "tcgen05.ld.sync.aligned.16x128b.pack::16b.x16.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " - "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " - "%26, %27, %28, %29, %30, %31}," - "[%32];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - 
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) - : "r"(src_addr)); - } else if constexpr (N == 32) { - asm volatile( - "tcgen05.ld.sync.aligned.16x128b.pack::16b.x32.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " - "%28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " - "%42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " - "%56, " - "%57, %58, %59, %60, %61, %62, %63}," - "[%64];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), - "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), - "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), - "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), - "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), - "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), - "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), - "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), - 
"=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), - "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), - "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), - "=r"(dst_ptr[63]) - : "r"(src_addr)); - } else if constexpr (N == 64) { - asm volatile( - "tcgen05.ld.sync.aligned.16x128b.pack::16b.x64.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " - "%28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " - "%42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " - "%56, " - "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " - "%70, " - "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " - "%84, " - "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " - "%98, " - "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " - "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " - "%121, %122, %123, %124, %125, %126, %127}," - "[%128];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), - "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), - "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), - "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), - "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), - "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), 
"=r"(dst_ptr[47]), - "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), - "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), - "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), - "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), - "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), - "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), - "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), - "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), - "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), - "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), - "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), - "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), - "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), - "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), - "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), - "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), - "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), - "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), - "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), - "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), - "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), - "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), - "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), - "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), - "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), - "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), - "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) - : "r"(src_addr)); - } else { - asm volatile("trap"); - } - } -}; - -// 16 data path lanes, 256-bit pattern, repeated N times -template class tmem_ld_16dp256bNx; -template <> class tmem_ld_16dp256bNx { -public: - template - static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, - "N must be a power of 
2 and lies between 1 ~ 32"); - - if constexpr (N == 1) { - asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32" - "{%0, %1, %2, %3}," - "[%4];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]) - : "r"(src_addr)); - } else if constexpr (N == 2) { - asm volatile("tcgen05.ld.sync.aligned.16x256b.x2.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "[%8];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) - : "r"(src_addr)); - } else if constexpr (N == 4) { - asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x4.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " - "%14, %15}," - "[%16];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]) - : "r"(src_addr)); - } else if constexpr (N == 8) { - asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x8.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " - "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " - "%26, %27, %28, %29, %30, %31}," - "[%32];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) - : "r"(src_addr)); - } else if 
constexpr (N == 16) { - asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x16.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " - "%28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " - "%42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " - "%56, " - "%57, %58, %59, %60, %61, %62, %63}," - "[%64];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), - "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), - "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), - "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), - "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), - "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), - "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), - "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), - "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), - "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), - "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), - "=r"(dst_ptr[63]) - : "r"(src_addr)); - } else if constexpr (N == 32) { + } else if constexpr (N == 32) { asm volatile( "tcgen05.ld.sync.aligned.16x256b.x32.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " @@ -1187,193 +678,35 @@ public: } } }; -template <> class tmem_ld_16dp256bNx { -public: - 
template - static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, - "N must be a power of 2 and lies between 1 ~ 32"); - - if constexpr (N == 1) { - asm volatile("tcgen05.ld.sync.aligned.16x256b.pack::16b.x1.b32" - "{%0, %1, %2, %3}," - "[%4];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]) - : "r"(src_addr)); - } else if constexpr (N == 2) { - asm volatile("tcgen05.ld.sync.aligned.16x256b.pack::16b.x2.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "[%8];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) - : "r"(src_addr)); - } else if constexpr (N == 4) { - asm volatile( - "tcgen05.ld.sync.aligned.16x256b.pack::16b.x4.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " - "%14, %15}," - "[%16];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]) - : "r"(src_addr)); - } else if constexpr (N == 8) { - asm volatile( - "tcgen05.ld.sync.aligned.16x256b.pack::16b.x8.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " - "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " - "%26, %27, %28, %29, %30, %31}," - "[%32];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), 
"=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) - : "r"(src_addr)); - } else if constexpr (N == 16) { - asm volatile( - "tcgen05.ld.sync.aligned.16x256b.pack::16b.x16.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " - "%28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " - "%42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " - "%56, " - "%57, %58, %59, %60, %61, %62, %63}," - "[%64];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), - "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), - "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), - "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), - "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), - "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), - "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), - "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), - "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), - "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), - "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), - "=r"(dst_ptr[63]) - : "r"(src_addr)); - } else if constexpr (N == 32) { - asm volatile( - 
"tcgen05.ld.sync.aligned.16x256b.pack::16b.x32.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " - "%28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " - "%42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " - "%56, " - "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " - "%70, " - "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " - "%84, " - "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " - "%98, " - "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " - "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " - "%121, %122, %123, %124, %125, %126, %127}," - "[%128];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), - "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), - "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), - "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), - "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), - "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), - "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), - "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), - "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), - "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), - "=r"(dst_ptr[60]), 
"=r"(dst_ptr[61]), "=r"(dst_ptr[62]), - "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), - "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), - "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), - "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), - "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), - "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), - "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), - "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), - "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), - "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), - "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), - "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), - "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), - "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), - "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), - "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), - "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), - "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), - "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), - "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), - "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), - "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) - : "r"(src_addr)); - } else { - asm volatile("trap"); - } - } -}; // 32 data path lanes, 64-bit pattern, repeated N times // (conducted with 2x16dp64bNx) -template class tmem_ld_32dp64bNx { +class tmem_ld_32dp64bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - tmem_ld_16dp64bNx::copy(src_addr, dst_ptr); - tmem_ld_16dp64bNx::copy(src_addr + (16 << 16), dst_ptr + N); + tmem_ld_16dp64bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp64bNx::copy(src_addr + (16 << 16), dst_ptr + N); } }; // 32 data path lanes, 128-bit pattern, repeated N times -template class tmem_ld_32dp128bNx { +class tmem_ld_32dp128bNx { 
public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - tmem_ld_16dp128bNx::copy(src_addr, dst_ptr); - tmem_ld_16dp128bNx::copy(src_addr + (16 << 16), dst_ptr + N * 2); + tmem_ld_16dp128bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp128bNx::copy(src_addr + (16 << 16), dst_ptr + N * 2); } }; // 32 data path lanes, 256-bit pattern, repeated N times -template class tmem_ld_32dp256bNx { +class tmem_ld_32dp256bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - tmem_ld_16dp256bNx::copy(src_addr, dst_ptr); - tmem_ld_16dp256bNx::copy(src_addr + (16 << 16), dst_ptr + N * 4); + tmem_ld_16dp256bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp256bNx::copy(src_addr + (16 << 16), dst_ptr + N * 4); } }; diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index bbfeb157..8c546c63 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -45,10 +45,7 @@ class TensorCoreIntrinEmitter: "int8": "int8", "int32": "int32", "float8_e4m3": "e4m3", - "float8_e4m3fn": "e4m3", - "float8_e4m3fnuz": "e4m3", "float8_e5m2": "e5m2", - "float8_e5m2fnuz": "e5m2", } # Represent the thread binding in the form of (tx, warp_n, warp_m) diff --git a/tilelang/intrinsics/tcgen05_macro_generator.py b/tilelang/intrinsics/tcgen05_macro_generator.py index 966f4dc4..e53ff7cb 100644 --- a/tilelang/intrinsics/tcgen05_macro_generator.py +++ b/tilelang/intrinsics/tcgen05_macro_generator.py @@ -169,11 +169,12 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): accum_dtype_in_bits = DataType(accum_dtype).bits meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim) - if len(meta) != 5: + if len(meta) != 3: raise ValueError( f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, " f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}") - atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta) + 
atom_m, atom_n, atom_k = (int(x) for x in meta) + enable_ws = atom_m != 128 # by default, we utilize non-swizzle layout offset a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * @@ -381,10 +382,10 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): k = int(self.chunk) meta = self.get_tcgen5_mma_meta(m, n, k) - if len(meta) != 5: + if len(meta) != 3: raise ValueError(f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, " f"A dtype={self.a_dtype}, accum dtype={self.accum_dtype}") - atom_m, atom_n, _, _, _ = (int(x) for x in meta) + atom_m, atom_n, _ = (int(x) for x in meta) if m % atom_m != 0 or n % atom_n != 0: raise ValueError( diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py index 75607976..48b8e908 100644 --- a/tilelang/jit/adapter/wrapper.py +++ b/tilelang/jit/adapter/wrapper.py @@ -144,7 +144,6 @@ class TLCUDASourceWrapper: "float16": "half_t", "bfloat16": "bfloat16_t", "float8_e4m3": "fp8_e4_t", - "float8_e4m3fn": "fp8_e4_t", "float8_e5m2": "fp8_e5_t", "float64": "double", "int64": "int64_t", diff --git a/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/tileop/gemm/gemm_tcgen05.py index 1de9fe87..52c192e5 100644 --- a/tilelang/tileop/gemm/gemm_tcgen05.py +++ b/tilelang/tileop/gemm/gemm_tcgen05.py @@ -85,9 +85,6 @@ class GemmTCGEN5(GemmBase): raise ValueError(f"TCGEN5MMA currently only supports gemm_ss, got " f"A scope {self.A.scope()}, B scope {self.B.scope()}") - atom_m, atom_n, atom_k, enable_ws, enable_2cta = mma_emitter.get_tcgen5_mma_meta( - self.M, self.N, self.K) - if self.A.scope() not in {"shared", "shared.dyn", "shared.tmem"}: raise ValueError(f"Unsupported A scope for TCGEN5MMA: {self.A.scope()}") if self.B.scope() not in {"shared", "shared.dyn"}: @@ -106,7 +103,7 @@ class GemmTCGEN5(GemmBase): raise ValueError("TCGEN5MMA expects 2D coordinates for C buffer access") accum_dtype = str(self.C.dtype) - if accum_dtype not in ["float32", 'float16']: + if accum_dtype != "float32": raise 
ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}") A_shared = self.ARegion -- GitLab From fddcbbd665d2fc8eed0f629fbcb2521798068d66 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 24 Nov 2025 16:48:45 +0800 Subject: [PATCH 036/139] [CI]: Bump actions/checkout from 5 to 6 (#1319) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci.yml | 4 ++-- .github/workflows/dist.yml | 4 ++-- .github/workflows/pr-perfbench-bot.yml | 2 +- .github/workflows/publish-docs.yml | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f9fe3286..c33a25b6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -40,7 +40,7 @@ jobs: timeout-minutes: 30 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: recursive @@ -104,7 +104,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: recursive diff --git a/.github/workflows/dist.yml b/.github/workflows/dist.yml index 0ba3fbc3..ed63914c 100644 --- a/.github/workflows/dist.yml +++ b/.github/workflows/dist.yml @@ -52,7 +52,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 1 submodules: recursive @@ -122,7 +122,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 1 submodules: recursive diff --git a/.github/workflows/pr-perfbench-bot.yml b/.github/workflows/pr-perfbench-bot.yml index 37da4e3c..e6954bcc 100644 --- a/.github/workflows/pr-perfbench-bot.yml +++ b/.github/workflows/pr-perfbench-bot.yml @@ -33,7 +33,7 @@ jobs: runs-on: [self-hosted, nvidia] steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: 
actions/checkout@v6 with: ref: refs/pull/${{ github.event.issue.number }}/merge fetch-depth: 0 diff --git a/.github/workflows/publish-docs.yml b/.github/workflows/publish-docs.yml index 95330310..2197015b 100644 --- a/.github/workflows/publish-docs.yml +++ b/.github/workflows/publish-docs.yml @@ -25,7 +25,7 @@ jobs: runs-on: [self-hosted, nvidia] steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: recursive -- GitLab From 2a70fd3f9e93dee4e776a9891377340d8170cc5e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 24 Nov 2025 16:49:18 +0800 Subject: [PATCH 037/139] [CI]: Bump pypa/cibuildwheel from 3.2 to 3.3 (#1318) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/dist.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/dist.yml b/.github/workflows/dist.yml index ed63914c..ff230af4 100644 --- a/.github/workflows/dist.yml +++ b/.github/workflows/dist.yml @@ -160,7 +160,7 @@ jobs: fi - name: Build wheels - uses: pypa/cibuildwheel@v3.2 + uses: pypa/cibuildwheel@v3.3 with: package-dir: . 
output-dir: wheelhouse -- GitLab From 01d207fa1494a5c46b2cc44d0682ce0544271418 Mon Sep 17 00:00:00 2001 From: Chaofan Lin Date: Mon, 24 Nov 2025 18:32:00 +0800 Subject: [PATCH 038/139] [Installation] Fix building using customized TVM path (#1326) --- cmake/load_tvm.cmake | 5 ++++- docs/get_started/Installation.md | 9 +++++---- tilelang/env.py | 6 +++--- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/cmake/load_tvm.cmake b/cmake/load_tvm.cmake index f013c3ba..cb21be95 100644 --- a/cmake/load_tvm.cmake +++ b/cmake/load_tvm.cmake @@ -3,12 +3,15 @@ set(TVM_BUILD_FROM_SOURCE TRUE) set(TVM_SOURCE ${CMAKE_SOURCE_DIR}/3rdparty/tvm) -if(DEFINED $ENV{TVM_ROOT}) +if(DEFINED ENV{TVM_ROOT}) if(EXISTS $ENV{TVM_ROOT}/cmake/config.cmake) set(TVM_SOURCE $ENV{TVM_ROOT}) + message(STATUS "Using TVM_ROOT from environment variable: ${TVM_SOURCE}") endif() endif() +message(STATUS "Using TVM source: ${TVM_SOURCE}") + set(TVM_INCLUDES ${TVM_SOURCE}/include ${TVM_SOURCE}/src diff --git a/docs/get_started/Installation.md b/docs/get_started/Installation.md index be0d794e..585a0029 100644 --- a/docs/get_started/Installation.md +++ b/docs/get_started/Installation.md @@ -93,14 +93,16 @@ Some useful CMake options you can toggle while configuring: (using-existing-tvm)= -### Building with Existing TVM Installation +### Building with Customized TVM Path -If you already have a compatible TVM installation, use the `TVM_ROOT` environment variable to specify the location of your existing TVM repository when building tilelang: +If you already have a TVM codebase, use the `TVM_ROOT` environment variable to specify the location of your existing TVM repository when building tilelang: ```bash TVM_ROOT= pip install . -v ``` +> **Note**: This will still rebuild the TVM-related libraries (stored in `TL_LIBS`). And this method often leads to some path issues. Check `env.py` to see some environment variables which are not set properly. 
+ (install-using-docker)= ## Install Using Docker @@ -197,8 +199,7 @@ Set `NO_TOOLCHAIN_VERSION=ON` to disable this. ### Run-time environment variables - -TODO +Please refer to the `env.py` file for a full list of supported run-time environment variables. ## Other Tips diff --git a/tilelang/env.py b/tilelang/env.py index b98bbf98..39d9e722 100644 --- a/tilelang/env.py +++ b/tilelang/env.py @@ -314,9 +314,9 @@ else: if tvm_path not in sys.path: prepend_pythonpath(tvm_path) env.TVM_IMPORT_PYTHON_PATH = tvm_path - - if os.environ.get("TVM_LIBRARY_PATH") is None: - os.environ['TVM_LIBRARY_PATH'] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS) +# By default, the built TVM-related libraries are stored in TL_LIBS. +if os.environ.get("TVM_LIBRARY_PATH") is None: + os.environ['TVM_LIBRARY_PATH'] = env.TVM_LIBRARY_PATH = os.pathsep.join(TL_LIBS) # Initialize CUTLASS paths if os.environ.get("TL_CUTLASS_PATH", None) is None: -- GitLab From 6c2162a9fdcd1e754faea9944da033c3199b08c1 Mon Sep 17 00:00:00 2001 From: Yichen Yan Date: Mon, 24 Nov 2025 19:07:51 +0800 Subject: [PATCH 039/139] [Release] Allow developer with write permission to trigger wheel release (#1322) --- .github/workflows/dist.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/dist.yml b/.github/workflows/dist.yml index ff230af4..73c08936 100644 --- a/.github/workflows/dist.yml +++ b/.github/workflows/dist.yml @@ -1,5 +1,6 @@ name: Dist on: + workflow_dispatch: schedule: # gemini said this is 6:00 china time - cron: "0 22 * * *" -- GitLab From caa6dd3f02885960a75f299f73a94f67e0817477 Mon Sep 17 00:00:00 2001 From: Tong WU <109033598+Rachmanino@users.noreply.github.com> Date: Mon, 24 Nov 2025 19:38:14 +0800 Subject: [PATCH 040/139] [Feat] Support warp reduce (#1316) * [Feat] Support warp reduce * lint * add test * lint --- src/op/builtin.cc | 25 ++++++ src/op/builtin.h | 25 ++++++ src/target/codegen_cuda.cc | 10 +++ src/tl_templates/cuda/reduce.h | 31 +++++++ 
.../test_tilelang_language_warp_reduce.py | 83 +++++++++++++++++++ tilelang/language/__init__.py | 5 ++ tilelang/language/reduce.py | 80 ++++++++++++++++++ 7 files changed, 259 insertions(+) create mode 100644 testing/python/language/test_tilelang_language_warp_reduce.py diff --git a/src/op/builtin.cc b/src/op/builtin.cc index e7e86f2f..ced86cfa 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -341,5 +341,30 @@ TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(warp_reduce_sum) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(warp_reduce_max) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(warp_reduce_min) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(warp_reduce_bitand) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(warp_reduce_bitor) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + } // namespace tl } // namespace tvm diff --git a/src/op/builtin.h b/src/op/builtin.h index f5c7d9ed..7ae638f1 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -571,6 +571,31 @@ TVM_DLL const Op &device_assert(); */ TVM_DLL const Op &device_assert_with_msg(); +/*! + * \brief tilelang intrinsic for warp reduction sum. + */ +TVM_DLL const Op &warp_reduce_sum(); + +/*! + * \brief tilelang intrinsic for warp reduction max. + */ +TVM_DLL const Op &warp_reduce_max(); + +/*! + * \brief tilelang intrinsic for warp reduction min. + */ +TVM_DLL const Op &warp_reduce_min(); + +/*! + * \brief tilelang intrinsic for warp reduction bitand. + */ +TVM_DLL const Op &warp_reduce_bitand(); + +/*! + * \brief tilelang intrinsic for warp reduction bitor. 
+ */ +TVM_DLL const Op &warp_reduce_bitor(); + } // namespace tl } // namespace tvm diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index dda96925..99512b8b 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -2609,6 +2609,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { std::string func_name = math_func(op->dtype, "fdiv", rounding_mode); os << func_name << "(" << PrintExpr(op->args[0]) << ", " << PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::warp_reduce_sum())) { + os << "tl::warp_reduce_sum(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::warp_reduce_max())) { + os << "tl::warp_reduce_max(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::warp_reduce_min())) { + os << "tl::warp_reduce_min(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::warp_reduce_bitand())) { + os << "tl::warp_reduce_bitand(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::warp_reduce_bitor())) { + os << "tl::warp_reduce_bitor(" << PrintExpr(op->args[0]) << ")"; } else { CodeGenC::VisitExpr_(op, os); } diff --git a/src/tl_templates/cuda/reduce.h b/src/tl_templates/cuda/reduce.h index a083c711..45824264 100644 --- a/src/tl_templates/cuda/reduce.h +++ b/src/tl_templates/cuda/reduce.h @@ -250,4 +250,35 @@ template struct CumSum2D { } }; +template +TL_DEVICE T warp_reduce(T value, ReduceOp op) { + constexpr uint32_t mask = 0xffffffff; + value = op(value, __shfl_xor_sync(mask, value, 16)); + value = op(value, __shfl_xor_sync(mask, value, 8)); + value = op(value, __shfl_xor_sync(mask, value, 4)); + value = op(value, __shfl_xor_sync(mask, value, 2)); + value = op(value, __shfl_xor_sync(mask, value, 1)); + return value; +} + +template TL_DEVICE T warp_reduce_sum(T value) { + return warp_reduce(value, SumOp()); +} + +template TL_DEVICE T warp_reduce_max(T value) { + return warp_reduce(value, MaxOp()); +} + +template TL_DEVICE T 
warp_reduce_min(T value) { + return warp_reduce(value, MinOp()); +} + +template TL_DEVICE T warp_reduce_bitand(T value) { + return warp_reduce(value, BitAndOp()); +} + +template TL_DEVICE T warp_reduce_bitor(T value) { + return warp_reduce(value, BitOrOp()); +} + } // namespace tl diff --git a/testing/python/language/test_tilelang_language_warp_reduce.py b/testing/python/language/test_tilelang_language_warp_reduce.py new file mode 100644 index 00000000..681b2347 --- /dev/null +++ b/testing/python/language/test_tilelang_language_warp_reduce.py @@ -0,0 +1,83 @@ +import torch + +import tilelang +import tilelang.testing +import tilelang.language as T + + +@tilelang.jit +def get_kernel(reduce_op: str, dtype: str): + + assert reduce_op in ["sum", "max", "min", "bitand", "bitor"] + + @T.prim_func + def main(x: T.Tensor((32), dtype)): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding(0) + local_val = T.alloc_local([1], dtype) + local_val[0] = x[tx] + reduced_val = T.alloc_local([1], dtype) + if reduce_op == "sum": + reduced_val[0] = T.warp_reduce_sum(local_val[0]) + elif reduce_op == "max": + reduced_val[0] = T.warp_reduce_max(local_val[0]) + elif reduce_op == "min": + reduced_val[0] = T.warp_reduce_min(local_val[0]) + elif reduce_op == "bitand": + reduced_val[0] = T.warp_reduce_bitand(local_val[0]) + elif reduce_op == "bitor": + reduced_val[0] = T.warp_reduce_bitor(local_val[0]) + x[tx] = reduced_val[0] + + return main + + +def test_warp_reduce_sum(): + a = torch.randn((32,), dtype=torch.float32, device='cuda') + kernel = get_kernel('sum', 'float32') + ref = torch.full_like(a, a.sum()) + kernel(a) + torch.testing.assert_close(a, ref) + + +def test_warp_reduce_max(): + a = torch.randn((32,), dtype=torch.float32, device='cuda') + kernel = get_kernel("max", 'float32') + print(kernel.get_kernel_source()) + ref = torch.full_like(a, a.max()) + kernel(a) + torch.testing.assert_close(a, ref) + + +def test_warp_reduce_min(): + a = torch.randn((32,), dtype=torch.float32, 
device='cuda') + kernel = get_kernel("min", 'float32') + ref = torch.full_like(a, a.min()) + kernel(a) + torch.testing.assert_close(a, ref) + + +def test_warp_reduce_bitand(): + a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda') + kernel = get_kernel("bitand", 'int32') + ref_val = a[0] + for i in range(1, a.shape[0]): + ref_val = ref_val & a[i] + ref = torch.full_like(a, ref_val) + kernel(a) + torch.testing.assert_close(a, ref) + + +def test_warp_reduce_bitor(): + a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda') + kernel = get_kernel("bitor", 'int32') + ref_val = a[0] + for i in range(1, a.shape[0]): + ref_val = ref_val | a[i] + ref = torch.full_like(a, ref_val) + kernel(a) + torch.testing.assert_close(a, ref) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 95488bdf..75d8d0b4 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -65,6 +65,11 @@ from .reduce import ( reduce_bitxor, # noqa: F401 cumsum, # noqa: F401 finalize_reducer, # noqa: F401 + warp_reduce_sum, # noqa: F401 + warp_reduce_max, # noqa: F401 + warp_reduce_min, # noqa: F401 + warp_reduce_bitand, # noqa: F401 + warp_reduce_bitor, # noqa: F401 ) from .print import print, device_assert # noqa: F401 from .customize import ( diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py index 09289559..23bb6d05 100644 --- a/tilelang/language/reduce.py +++ b/tilelang/language/reduce.py @@ -325,3 +325,83 @@ def finalize_reducer(reducer: tir.Buffer): tir.op.Op.get("tl.finalize_reducer"), reducer.access_ptr("w"), ) + + +def warp_reduce_sum(value: tir.PrimExpr): + """Perform warp reduction sum on a register value. + + This function reduces a value across all threads in a warp using shuffle operations. + Each thread provides a register `value`, and after the reduction, all threads + will have the sum of all values across the warp. 
+ + Args: + value (tir.PrimExpr): The input register value to reduce + + Returns: + tir.PrimExpr: The reduced sum value (same on all threads in the warp) + """ + return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_sum"), value) + + +def warp_reduce_max(value: tir.PrimExpr): + """Perform warp reduction max on a register value. + + This function reduces a value across all threads in a warp using shuffle operations. + Each thread provides a register `value`, and after the reduction, all threads + will have the max of all values across the warp. + + Args: + value (tir.PrimExpr): The input register value to reduce + + Returns: + tir.PrimExpr: The reduced max value (same on all threads in the warp) + """ + return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_max"), value) + + +def warp_reduce_min(value: tir.PrimExpr): + """Perform warp reduction min on a register value. + + This function reduces a value across all threads in a warp using shuffle operations. + Each thread provides a register `value`, and after the reduction, all threads + will have the min of all values across the warp. + + Args: + value (tir.PrimExpr): The input register value to reduce + + Returns: + tir.PrimExpr: The reduced min value (same on all threads in the warp) + """ + return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_min"), value) + + +def warp_reduce_bitand(value: tir.PrimExpr): + """Perform warp reduction bitwise-and on a register value. + + This function reduces a value across all threads in a warp using shuffle operations. + Each thread provides a register `value`, and after the reduction, all threads + will have the bitwise-and of all values across the warp. 
+ + Args: + value (tir.PrimExpr): The input register value to reduce + + Returns: + tir.PrimExpr: The reduced bitwise-and value (same on all threads in the warp) + """ + return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_bitand"), value) + + +def warp_reduce_bitor(value: tir.PrimExpr): + """Perform warp reduction bitwise-or on a register value. + + This function reduces a value across all threads in a warp using shuffle operations. + Each thread provides a register `value`, and after the reduction, all threads + will have the bitwise-or of all values across the warp. + + Args: + value (tir.PrimExpr): The input register value to reduce + + Returns: + tir.PrimExpr: The reduced bitwise-or value (same on all threads in the warp) + """ + return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_bitor"), value) -- GitLab From c30df2a1c58bc6296e2a6027b4ebacf9f1b82202 Mon Sep 17 00:00:00 2001 From: Wenhao Xie Date: Tue, 25 Nov 2025 01:08:35 +0800 Subject: [PATCH 041/139] [Enhancement] Support more dtype in `T.print` (#1329) * [Enhancement] Support more dtype in `T.print` * upd * upd --- src/tl_templates/cuda/debug.h | 353 +++++------------- .../python/debug/test_tilelang_debug_print.py | 21 +- 2 files changed, 107 insertions(+), 267 deletions(-) diff --git a/src/tl_templates/cuda/debug.h b/src/tl_templates/cuda/debug.h index 2724a814..020cb1f1 100644 --- a/src/tl_templates/cuda/debug.h +++ b/src/tl_templates/cuda/debug.h @@ -5,282 +5,107 @@ #endif #include "common.h" - #ifndef __CUDACC_RTC__ +#include #include #endif -// Template declaration for device-side debug printing (variable only) -template __device__ void debug_print_var(const char *msg, T var); - -// Overload for pointer type (supports any cv-qualified T*) -template __device__ void debug_print_var(const char *msg, T *var) { - printf( - "msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=pointer " - "value=%p\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, 
- threadIdx.z, var); -} - -// Specialization for signed char type -template <> -__device__ void debug_print_var(const char *msg, signed char var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=signed " - "char " - "value=%d\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, var); -} - -// Specialization for plain char type -template <> __device__ void debug_print_var(const char *msg, char var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=char " - "value=%d\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, (int)var); -} - -// Specialization for unsigned char type -template <> -__device__ void debug_print_var(const char *msg, - unsigned char var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " - "dtype=unsigned char " - "value=%d\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, var); -} - -// Specialization for integer type -template <> __device__ void debug_print_var(const char *msg, int var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int " - "value=%d\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, var); -} - -// Specialization for unsigned integer type -template <> -__device__ void debug_print_var(const char *msg, - unsigned int var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int " - "value=%u\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, var); -} - -// Specialization for bool type -template <> __device__ void debug_print_var(const char *msg, bool var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=bool " - "value=%s\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, var ? 
"true" : "false"); -} - -// Specialization for float type -template <> __device__ void debug_print_var(const char *msg, float var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=float " - "value=%f\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, var); -} - -// Specialization for half type -template <> __device__ void debug_print_var(const char *msg, half var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half " - "value=%f\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, (float)var); -} - -// Specialization for half_t type -template <> -__device__ void debug_print_var(const char *msg, half_t var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half_t " - "value=%f\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, (float)var); -} +template struct PrintTraits { + static __device__ void print_var(const char *msg, T val) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " + "dtype=unknown value=%p\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, (const void *)&val); + } -// Specialization for bfloat16_t type -template <> -__device__ void debug_print_var(const char *msg, bfloat16_t var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " - "dtype=bfloat16_t value=%f\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, (float)var); -} + static __device__ void print_buffer(const char *msg, const char *buf_name, + int index, T val) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=unknown value=%p\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, (const void *)&val); + } +}; + +#define DEFINE_PRINT_TRAIT(TYPE, NAME, FORMAT, CAST_TYPE) \ + template <> struct 
PrintTraits { \ + static __device__ void print_var(const char *msg, TYPE val) { \ + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " \ + "dtype=" NAME " value=" FORMAT "\n", \ + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, \ + threadIdx.y, threadIdx.z, (CAST_TYPE)val); \ + } \ + static __device__ void print_buffer(const char *msg, const char *buf_name, \ + int index, TYPE val) { \ + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " \ + "buffer=%s, index=%d, dtype=" NAME " value=" FORMAT "\n", \ + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, \ + threadIdx.y, threadIdx.z, buf_name, index, (CAST_TYPE)val); \ + } \ + } -// Specialization for double type -template <> -__device__ void debug_print_var(const char *msg, double var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=double " - "value=%lf\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, var); -} +DEFINE_PRINT_TRAIT(char, "char", "%d", int); +DEFINE_PRINT_TRAIT(signed char, "signed char", "%d", int); +DEFINE_PRINT_TRAIT(unsigned char, "unsigned char", "%u", unsigned int); +DEFINE_PRINT_TRAIT(short, "short", "%d", int); +DEFINE_PRINT_TRAIT(unsigned short, "unsigned short", "%u", unsigned int); +DEFINE_PRINT_TRAIT(int, "int", "%d", int); +DEFINE_PRINT_TRAIT(unsigned int, "uint", "%u", unsigned int); +DEFINE_PRINT_TRAIT(long, "long", "%ld", long); +DEFINE_PRINT_TRAIT(unsigned long, "ulong", "%lu", unsigned long); +DEFINE_PRINT_TRAIT(long long, "long long", "%lld", long long); + +DEFINE_PRINT_TRAIT(float, "float", "%f", float); +DEFINE_PRINT_TRAIT(double, "double", "%lf", double); +DEFINE_PRINT_TRAIT(half, "half", "%f", float); +DEFINE_PRINT_TRAIT(half_t, "half_t", "%f", float); +DEFINE_PRINT_TRAIT(bfloat16_t, "bfloat16_t", "%f", float); #if __CUDA_ARCH_LIST__ >= 890 -// Specialization for fp8_e4_t type -template <> -__device__ void debug_print_var(const char *msg, fp8_e4_t var) { - printf( - "msg='%s' 
BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=fp8_e4_t " - "value=%f\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, (float)var); -} - -// Specialization for fp8_e5_t type -template <> -__device__ void debug_print_var(const char *msg, fp8_e5_t var) { - printf( - "msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=fp8_e5_t " - "value=%f\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, (float)var); -} - +DEFINE_PRINT_TRAIT(fp8_e4_t, "fp8_e4_t", "%f", float); +DEFINE_PRINT_TRAIT(fp8_e5_t, "fp8_e5_t", "%f", float); #endif -// Template declaration for device-side debug printing (buffer only) -template -__device__ void debug_print_buffer_value(const char *msg, const char *buf_name, - int index, T var); - -// Specialization for signed char type -template <> -__device__ void -debug_print_buffer_value(const char *msg, const char *buf_name, - int index, signed char var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " - "index=%d, dtype=signed char value=%d\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, buf_name, index, var); -} - -// Specialization for unsigned char type -template <> -__device__ void -debug_print_buffer_value(const char *msg, const char *buf_name, - int index, unsigned char var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " - "index=%d, dtype=char value=%d\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, buf_name, index, var); -} - -// Specialization for integer type -template <> -__device__ void debug_print_buffer_value(const char *msg, - const char *buf_name, int index, - int var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " - "index=%d, dtype=int value=%d\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, buf_name, index, var); -} - -// 
Specialization for unsigned integer type -template <> -__device__ void -debug_print_buffer_value(const char *msg, const char *buf_name, - int index, unsigned int var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " - "index=%d, dtype=int value=%u\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, buf_name, index, var); -} - -// Specialization for float type -template <> -__device__ void debug_print_buffer_value(const char *msg, - const char *buf_name, int index, - float var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " - "index=%d, dtype=float value=%f\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, buf_name, index, var); -} - -// Specialization for half type -template <> -__device__ void debug_print_buffer_value(const char *msg, - const char *buf_name, int index, - half var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " - "index=%d, dtype=half value=%f\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, buf_name, index, (float)var); -} - -// Specialization for half_t type -template <> -__device__ void debug_print_buffer_value(const char *msg, - const char *buf_name, - int index, half_t var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " - "index=%d, dtype=half_t value=%f\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, buf_name, index, (float)var); -} - -// Specialization for bfloat16_t type -template <> -__device__ void -debug_print_buffer_value(const char *msg, const char *buf_name, - int index, bfloat16_t var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " - "index=%d, dtype=bfloat16_t value=%f\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, buf_name, index, (float)var); -} - -// Specialization for 
double type -template <> -__device__ void debug_print_buffer_value(const char *msg, - const char *buf_name, - int index, double var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " - "index=%d, dtype=double value=%lf\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, buf_name, index, var); -} - -// Specialization for fp8_e4_t type -#if __CUDA_ARCH_LIST__ >= 890 -template <> -__device__ void debug_print_buffer_value(const char *msg, - const char *buf_name, - int index, fp8_e4_t var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " - "index=%d, dtype=fp8_e4_t value=%f\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, buf_name, index, (float)var); -} +template <> struct PrintTraits { + static __device__ void print_var(const char *msg, bool val) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=bool " + "value=%s\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, val ? "true" : "false"); + } + static __device__ void print_buffer(const char *msg, const char *buf_name, + int index, bool val) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=bool value=%s\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, val ? 
"true" : "false"); + } +}; + +template struct PrintTraits { + static __device__ void print_var(const char *msg, T *val) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " + "dtype=pointer value=%p\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, (void *)val); + } + static __device__ void print_buffer(const char *msg, const char *buf_name, + int index, T *val) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=pointer value=%p\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, (void *)val); + } +}; -// Specialization for fp8_e5_t type -template <> -__device__ void debug_print_buffer_value(const char *msg, - const char *buf_name, - int index, fp8_e5_t var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " - "index=%d, dtype=fp8_e5_t value=%f\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, buf_name, index, (float)var); +template __device__ void debug_print_var(const char *msg, T var) { + PrintTraits::print_var(msg, var); } -#endif - -// Specialization for int16 type -template <> -__device__ void debug_print_buffer_value(const char *msg, - const char *buf_name, - int index, int16_t var) { - printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " - "index=%d, dtype=int16_t value=%d\n", - msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, - threadIdx.z, buf_name, index, (int32_t)var); +template +__device__ void debug_print_buffer_value(const char *msg, const char *buf_name, + int index, T var) { + PrintTraits::print_buffer(msg, buf_name, index, var); } TL_DEVICE void device_assert(bool cond) { assert(cond); } @@ -290,4 +115,4 @@ TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) { printf("Device assert failed: %s\n", msg); assert(0); } -} +} \ No newline at end of file diff --git 
a/testing/python/debug/test_tilelang_debug_print.py b/testing/python/debug/test_tilelang_debug_print.py index fcfae4ed..a1aa42ed 100644 --- a/testing/python/debug/test_tilelang_debug_print.py +++ b/testing/python/debug/test_tilelang_debug_print.py @@ -19,9 +19,24 @@ def debug_print_buffer(M=16, N=16, dtype="float16"): def test_debug_print_buffer(): - debug_print_buffer(16, 16, dtype="float") - debug_print_buffer(16, 16, dtype="float16") - debug_print_buffer(16, 16, dtype="uint8") + debug_print_buffer(dtype='bool') + debug_print_buffer(dtype='int8') + debug_print_buffer(dtype='int16') + debug_print_buffer(dtype='int32') + debug_print_buffer(dtype='int64') + debug_print_buffer(dtype='uint8') + debug_print_buffer(dtype='uint16') + debug_print_buffer(dtype='uint32') + debug_print_buffer(dtype='uint64') + debug_print_buffer(dtype='float16') + debug_print_buffer(dtype='float32') + debug_print_buffer(dtype='float64') + debug_print_buffer(dtype='bfloat16') + debug_print_buffer(dtype='float8_e4m3') + debug_print_buffer(dtype='float8_e4m3fn') + debug_print_buffer(dtype='float8_e4m3fnuz') + debug_print_buffer(dtype='float8_e5m2') + debug_print_buffer(dtype='float8_e5m2fnuz') def debug_print_buffer_conditional(M=16, N=16): -- GitLab From 9dda774affbc13bbb142d5f59c91a6cb8aa88d39 Mon Sep 17 00:00:00 2001 From: Chaofan Lin Date: Tue, 25 Nov 2025 01:36:17 +0800 Subject: [PATCH 042/139] [BugFix] Use BufferRegion in tl.cumsum to infer buffer shape (#1321) * [BugFix] Use BufferRegion in tl.cumsum to infer buffer shape * remove debug lines * remove rubbish * Fix decorator syntax for atomic_different_memory_orders_program --------- Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> --- src/op/reduce.cc | 91 +++++++++++++++++-- src/op/reduce.h | 8 +- .../python/issue/test_tilelang_issue_1001.py | 33 +++++++ .../test_tilelang_language_atomic_add.py | 2 +- tilelang/analysis/__init__.py | 1 + tilelang/analysis/ast_printer.py | 23 +++++ tilelang/engine/phase.py | 3 + 
tilelang/language/reduce.py | 8 +- 8 files changed, 155 insertions(+), 14 deletions(-) create mode 100644 testing/python/issue/test_tilelang_issue_1001.py create mode 100644 tilelang/analysis/ast_printer.py diff --git a/src/op/reduce.cc b/src/op/reduce.cc index 05dad48f..b6dbe865 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -16,6 +16,7 @@ #include "../transform/loop_partition.h" #include "region.h" #include "tir/transforms/ir_utils.h" +#include "tvm/tir/stmt.h" namespace tvm { namespace tl { @@ -57,12 +58,65 @@ static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg, RegionOp region(call->args, vmap); return BufferRegion(region->GetBuffer(), region->GetRanges()); } + // builtin.tvm_access_ptr(...) — map var to Buffer and take full region + if (call->op.same_as(builtin::tvm_access_ptr())) { + Var var = Downcast(call->args[1]); + Buffer buf = vmap[var]; + Array ranges; + for (PrimExpr extent : buf->shape) { + ranges.push_back(Range(IntImm(extent->dtype, 0), extent)); + } + return BufferRegion(buf, ranges); + } } LOG(FATAL) << "Unsupported argument for BufferRegion in reduce: " << arg; throw; // Unreachable } +// Build a tvm_access_ptr(handle) to the start of the 2D tile within a +// BufferRegion. Offset is computed from all but the last two dimensions; extent +// is the product of the last two extents. rw_mask: 1=read, 2=write, +// 3=readwrite. +static PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion, + int rw_mask) { + Buffer buf = region->buffer; + int ndim = static_cast(buf->shape.size()); + ICHECK(ndim == 1 || ndim == 2) << "Cumsum expects buffers with 1 or 2 dims"; + + PrimExpr offset, extent; + if (ndim == 1) { + // Simple 1D region: offset and extent come from the single axis. 
+ auto axis = region->region[0]; + offset = axis->min; + extent = axis->extent; + } else { + // Compute row-major strides for ndim >= 2 + std::vector strides(ndim); + PrimExpr one = make_const(buf->shape[0].dtype(), 1); + PrimExpr cur = one; + for (int i = ndim - 1; i >= 0; --i) { + strides[i] = cur; + cur = cur * buf->shape[i]; + } + // Offset: sum_{i in [0..ndim-3]} min_i * stride_i + offset = make_const(buf->shape[0].dtype(), 0); + for (int i = 0; i < ndim - 2; ++i) { + offset = offset + region->region[i]->min * strides[i]; + } + + // Extent: last two extents product (elements) + extent = + region->region[ndim - 2]->extent * region->region[ndim - 1]->extent; + } + + // ptype and return handle + PrimExpr ptype = tir::TypeAnnotation(buf->dtype); + Array acc_args{ptype, buf->data, offset, extent, + IntImm(DataType::Int(32), rw_mask)}; + return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); +} + ReduceOp::ReduceOp(Array args, BufferMap vmap) { ObjectPtr node = tvm::ffi::make_object(); // Accept BufferRegion/BufferLoad/tl.region for src/dst @@ -231,6 +285,7 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto dst_scope = this->dst.scope(); if (src_scope == "local.fragment" && dst_scope == "local.fragment") { + Buffer src_buffer = get_buffer(this->src); Buffer dst_buffer = get_buffer(this->dst); Fragment src_layout = T.layout_map[this->src].as().value(); @@ -518,6 +573,16 @@ TIR_REGISTER_TL_OP(ReduceOp, reduce) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +// Normalize "Buffer" to BufferRegion. Use the shape of the buffer as the +// ranges. 
+static BufferRegion ConvertBufferToBufferRegion(const Buffer &buf) { + Array ranges; + for (PrimExpr extent : buf->shape) { + ranges.push_back(Range(IntImm(extent->dtype, 0), extent)); + } + return BufferRegion(buf, ranges); +} + CumSumOp::CumSumOp(Array args, BufferMap vmap) { /// CumSum constructor arguments: /// - src: input buffer @@ -526,11 +591,19 @@ CumSumOp::CumSumOp(Array args, BufferMap vmap) { /// - reverse: whether to cumsum in reverse order CHECK_EQ(args.size(), 4); ObjectPtr node = tvm::ffi::make_object(); - node->src = vmap[GetVarFromAccessPtr(args[0])]; - node->dst = vmap[GetVarFromAccessPtr(args[1])]; + // node->src = vmap[GetVarFromAccessPtr(args[0])]; + // node->dst = vmap[GetVarFromAccessPtr(args[1])]; + node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap); + node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap); + node->src = node->srcRegion_->buffer; + node->dst = node->dstRegion_->buffer; node->dim = args[2].as().value()->value; node->reverse = args[3].as().value(); - CHECK_LT(node->dim, static_cast(node->src->shape.size())); + CHECK_LT(node->dim, static_cast(node->src->shape.size())) + << "The dim of cumsum should be less than the number of dimensions. Got " + "dim=" + << node->dim << ", but src has " << node->src->shape.size() << " dims."; + data_ = std::move(node); } @@ -546,18 +619,22 @@ Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto threads = T.thread_bounds->extent; Array args; int ndim = static_cast(src->shape.size()); + + // Build access pointers from regions locally + PrimExpr srcPtr = MakeAccessPtrFromRegion(srcRegion_, 1); + PrimExpr dstPtr = MakeAccessPtrFromRegion(dstRegion_, 2); + if (ndim == 1) { ICHECK_EQ(dim, 0) << "Cumulative sum over a 1D buffer only supports dim " "= 0."; ss << "tl::CumSum1D<" << threads << ", " << (reverse ? 
"true" : "false") << ">::run"; - args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3), - src->shape[0]}; + args = {StringImm(ss.str()), srcPtr, dstPtr, src->shape[0]}; } else if (ndim == 2) { ss << "tl::CumSum2D<" << threads << ", " << dim << ", " << (reverse ? "true" : "false") << ">::run"; - args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3), - src->shape[0], src->shape[1]}; + args = {StringImm(ss.str()), srcPtr, dstPtr, src->shape[0], + src->shape[1]}; } else { LOG(FATAL) << "CumSum currently supports only 1D or 2D buffers, got " << ndim << "D."; diff --git a/src/op/reduce.h b/src/op/reduce.h index 3b124a4d..eb0599eb 100644 --- a/src/op/reduce.h +++ b/src/op/reduce.h @@ -133,8 +133,10 @@ public: class CumSumOpNode : public TileOperatorNode { public: tir::Buffer src, dst; ///< Source and destination buffers - int dim; ///< Dimension along which to compute cumulative sum - bool reverse; ///< Whether to compute in reverse order + // Optional: keep the original regions used to construct this op + BufferRegion srcRegion_, dstRegion_; + int dim; ///< Dimension along which to compute cumulative sum + bool reverse; ///< Whether to compute in reverse order TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.CumSumOp", CumSumOpNode, TileOperatorNode); @@ -143,6 +145,8 @@ public: refl::ObjectDef() .def_ro("src", &CumSumOpNode::src) .def_ro("dst", &CumSumOpNode::dst) + .def_ro("srcRegion", &CumSumOpNode::srcRegion_) + .def_ro("dstRegion", &CumSumOpNode::dstRegion_) .def_ro("dim", &CumSumOpNode::dim) .def_ro("reverse", &CumSumOpNode::reverse); } diff --git a/testing/python/issue/test_tilelang_issue_1001.py b/testing/python/issue/test_tilelang_issue_1001.py new file mode 100644 index 00000000..77d8cc1f --- /dev/null +++ b/testing/python/issue/test_tilelang_issue_1001.py @@ -0,0 +1,33 @@ +import torch +import tilelang +import tilelang.testing +from tilelang import language as T + + +@tilelang.jit( + pass_configs={ + 
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + },) +def _cumsum_view_infer_layout(hidden): + num_tokens = T.dynamic('num_tokens') + + @T.prim_func + def buggy_kernel(x: T.Tensor[(num_tokens, hidden), 'float']): + with T.Kernel(num_tokens, threads=128) as pid: + smem = T.alloc_shared((hidden,), dtype='float') + T.copy(x[pid, :], smem) + T.cumsum(T.view(smem, (1, hidden)), dim=1) + + return buggy_kernel + + +def test_cumsum_view_infer_layout(): + hidden = 128 + x = torch.randn(1, hidden, device='cuda', dtype=torch.float) + kernel = _cumsum_view_infer_layout(hidden) + kernel(x) + + +if __name__ == '__main__': + tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_atomic_add.py b/testing/python/language/test_tilelang_language_atomic_add.py index 2472c20f..b157966a 100644 --- a/testing/python/language/test_tilelang_language_atomic_add.py +++ b/testing/python/language/test_tilelang_language_atomic_add.py @@ -260,7 +260,7 @@ def test_atomic_addx2(): run_atomic_addx2(32, 64, 8, 16) -@tilelang.jit(debug_root_path="./testing/python/language") +@tilelang.jit def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float"): @T.prim_func diff --git a/tilelang/analysis/__init__.py b/tilelang/analysis/__init__.py index b72fc2ba..6e5ee5d6 100644 --- a/tilelang/analysis/__init__.py +++ b/tilelang/analysis/__init__.py @@ -1,3 +1,4 @@ """Tilelang IR analysis & visitors.""" +from .ast_printer import ASTPrinter # noqa: F401 from .nested_loop_checker import NestedLoopChecker # noqa: F401 diff --git a/tilelang/analysis/ast_printer.py b/tilelang/analysis/ast_printer.py new file mode 100644 index 00000000..c54ec5cf --- /dev/null +++ b/tilelang/analysis/ast_printer.py @@ -0,0 +1,23 @@ +from tvm import tir +from tvm.tir import PrimFunc +from tvm.tir.transform import prim_func_pass +from tvm.tir.stmt_functor import ir_transform + + +def ASTPrinter(): + """ + Print the AST of a 
given tilelang module for debugging. + """ + + def pre_visit(statement: tir.Stmt) -> None: + """ + Pre-order visitor to print all visited statements. + """ + + print(f"Visiting statement: {type(statement)}") + + def pass_fn(func: PrimFunc, mod, ctx) -> PrimFunc: + new_body = ir_transform(func.body, pre_visit, None) + return func.with_body(new_body) + + return prim_func_pass(pass_fn, opt_level=0) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 35c16a43..f686ba1f 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -74,6 +74,9 @@ def PreLowerSemanticCheck(mod: IRModule) -> None: Note: This is a validation-only pipeline of passes and does not modify or return the module. """ + # Debug + # tilelang.analysis.ASTPrinter()(mod) + # Check if there are any invalid nested loops. tilelang.analysis.NestedLoopChecker()(mod) diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py index 23bb6d05..9d84e0b2 100644 --- a/tilelang/language/reduce.py +++ b/tilelang/language/reduce.py @@ -246,8 +246,8 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) - tir.call_intrin( "handle", tir.op.Op.get("tl.cumsum"), - cumsum_smem.access_ptr("r"), - cumsum_smem.access_ptr("w"), + buffer_to_tile_region(cumsum_smem, "r"), + buffer_to_tile_region(cumsum_smem, "w"), dim, reverse, ) @@ -300,8 +300,8 @@ def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse return tir.call_intrin( "handle", tir.op.Op.get("tl.cumsum"), - src.access_ptr("r"), - dst.access_ptr("w"), + buffer_to_tile_region(src, "r"), + buffer_to_tile_region(dst, "w"), dim, reverse, ) -- GitLab From b02068546bd4f83beb3adea8771e91caa5022b35 Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Tue, 25 Nov 2025 11:25:04 +0800 Subject: [PATCH 043/139] [Fix] fix wrong uint narrowing bug in tvm in #1310 (#1320) --- 3rdparty/tvm | 2 +- tilelang/language/allocate.py | 3 ++- 2 files changed, 3 
insertions(+), 2 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index cd2b2b60..3354ada7 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit cd2b2b6013d155b5822300b0a0740fa65320dd9e +Subproject commit 3354ada79dd428e383102020814fa9c37638e752 diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index f0784e86..da1ca837 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -22,6 +22,7 @@ from tvm.tir import PrimExpr from tvm.script.parser.tir import block_attr from tvm.tir.buffer import Buffer from tvm.tir.expr import FloatImm, IntImm +from .v2.dtypes import dtype as tl_dtype def alloc_shared(shape, dtype, scope="shared.dyn"): @@ -135,7 +136,7 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None): buffer = T.alloc_buffer([1], dtype, scope=parsed_scope) if parsed_init is not None: if isinstance(parsed_init, (int, float, IntImm, FloatImm)): - block_attr({"tl.local_var_init": {buffer.data: parsed_init}}) + block_attr({"tl.local_var_init": {buffer.data: tl_dtype(dtype)(parsed_init)}}) else: T.buffer_store(buffer, parsed_init, 0) return buffer -- GitLab From 71b73e185aa2b72f3fabdae7382f9b0451034389 Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Tue, 25 Nov 2025 12:32:48 +0800 Subject: [PATCH 044/139] [Refactor] Disable strided buffer load inside tvm (#1301) (#1332) --- 3rdparty/tvm | 2 +- .../test_tilelang_language_frontend_v2.py | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 3354ada7..e3af4000 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 3354ada79dd428e383102020814fa9c37638e752 +Subproject commit e3af400013551755a8df668ba77b530735931ade diff --git a/testing/python/language/test_tilelang_language_frontend_v2.py b/testing/python/language/test_tilelang_language_frontend_v2.py index 349f3caf..299a4127 100644 --- 
a/testing/python/language/test_tilelang_language_frontend_v2.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -427,7 +427,7 @@ def test_var_macro(): pass -def frame_inside_macro(): +def test_frame_inside_macro(): @tilelang.jit def get_sample_kernel(): @@ -453,5 +453,18 @@ def frame_inside_macro(): kernel = get_sample_kernel() # noqa: F841 +def test_buffer_slice_step(): + try: + + @T.prim_func + def prim_buffer_slice_step(A: T.Buffer((10,), T.int32), B: T.Buffer((5,), T.int32)): + with T.Kernel(1): + B[0:5:2] = A[0:10:2] + + raise AssertionError("Expect to report an error, buffer slice with step is not supported") + except RuntimeError: + pass + + if __name__ == '__main__': tilelang.testing.main() -- GitLab From 2f34840fc40ee74c9ab8f3b019983398e5610315 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Tue, 25 Nov 2025 12:35:08 +0800 Subject: [PATCH 045/139] [Refactor] Moving `NormalizeToBufferRegion` and `MakeAccessPtrFromRegion` to utils (#1333) * Refactor GEMM and Reduce operations by moving NormalizeToBufferRegion and MakeAccessPtrFromRegion to utils.{h,cc} for better code organization and reuse. * lint fix --- src/op/gemm.cc | 97 ++++-------------------------------------- src/op/gemm_py.cc | 88 ++------------------------------------ src/op/reduce.cc | 95 ++--------------------------------------- src/op/utils.cc | 105 ++++++++++++++++++++++++++++++++++++++++++++++ src/op/utils.h | 35 ++++++++++++++++ 5 files changed, 155 insertions(+), 265 deletions(-) create mode 100644 src/op/utils.cc create mode 100644 src/op/utils.h diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 48e6cdf6..cece1e6f 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -14,6 +14,7 @@ #include "../target/utils.h" #include "region.h" #include "tcgen5_meta.h" +#include "utils.h" namespace tvm { namespace tl { @@ -48,92 +49,9 @@ using namespace tir; * fails with an ICHECK (runtime assertion). 
No other validation is * performed here. */ -// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region) -// to BufferRegion -static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg, - const BufferMap &vmap) { - // Case 1: Already a BufferRegion - if (arg->IsInstance()) { - return Downcast(arg); - } - - // Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else - // extent=1) - if (const auto *load = arg.as()) { - Array ranges; - for (const PrimExpr &index : load->indices) { - if (const auto *ramp = index.as()) { - ICHECK(ramp->stride.as()) << "Ramp stride must be IntImm"; - ICHECK_EQ(ramp->stride.as()->value, 1) - << "Only stride-1 Ramp is supported in GEMM region conversion"; - ICHECK(ramp->lanes.as()) - << "Scalable vector lanes not supported in GEMM region conversion"; - ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); - } else { - ranges.push_back(Range::FromMinExtent(index, 1)); - } - } - return BufferRegion(load->buffer, ranges); - } +// NormalizeToBufferRegion moved to src/op/utils.{h,cc} - // Case 3: Call nodes - if (const auto *call = arg.as()) { - // tl.region(...) — reconstruct via RegionOp - if (call->op.same_as(RegionOp::Get())) { - RegionOp region(call->args, vmap); - return BufferRegion(region->GetBuffer(), region->GetRanges()); - } - // builtin.tvm_access_ptr(...) — map var to Buffer and take full region - if (call->op.same_as(builtin::tvm_access_ptr())) { - Var var = Downcast(call->args[1]); - Buffer buf = vmap[var]; - Array ranges; - for (PrimExpr extent : buf->shape) { - ranges.push_back(Range(IntImm(extent->dtype, 0), extent)); - } - return BufferRegion(buf, ranges); - } - } - - LOG(FATAL) << "Unsupported GEMM argument for BufferRegion: " << arg; - throw; // Unreachable, keeps compiler happy -} - -// Build a tvm_access_ptr(handle) to the start of the 2D tile within a -// BufferRegion. 
Offset is computed from all but the last two dimensions; extent -// is the product of the last two extents. rw_mask: 1=read, 2=write, -// 3=readwrite. -static PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion, - int rw_mask) { - Buffer buf = region->buffer; - int ndim = static_cast(buf->shape.size()); - ICHECK(ndim >= 2) << "GEMM expects buffers with at least 2 dims"; - - // Compute row-major strides - std::vector strides(ndim); - PrimExpr one = make_const(buf->shape[0].dtype(), 1); - PrimExpr cur = one; - for (int i = ndim - 1; i >= 0; --i) { - strides[i] = cur; - cur = cur * buf->shape[i]; - } - - // Offset: sum_{i in [0..ndim-3]} min_i * stride_i - PrimExpr offset = make_const(buf->shape[0].dtype(), 0); - for (int i = 0; i < ndim - 2; ++i) { - offset = offset + region->region[i]->min * strides[i]; - } - - // Extent: last two extents product (elements) - PrimExpr extent = - region->region[ndim - 2]->extent * region->region[ndim - 1]->extent; - - // ptype and return handle - PrimExpr ptype = tir::TypeAnnotation(buf->dtype); - Array acc_args{ptype, buf->data, offset, extent, - IntImm(DataType::Int(32), rw_mask)}; - return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); -} +// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc} Gemm::Gemm(Array args, BufferMap vmap) { ObjectPtr node = tvm::ffi::make_object(); @@ -535,9 +453,12 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst); // Build access pointers from regions locally - PrimExpr Aptr = MakeAccessPtrFromRegion(aRegion_, /*r*/ 1); - PrimExpr Bptr = MakeAccessPtrFromRegion(bRegion_, /*r*/ 1); - PrimExpr Cptr = MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3); + PrimExpr Aptr = + MakeAccessPtrFromRegion(aRegion_, /*r*/ 1, /*require_2d*/ true); + PrimExpr Bptr = + MakeAccessPtrFromRegion(bRegion_, /*r*/ 1, /*require_2d*/ true); + PrimExpr Cptr = + MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3, 
/*require_2d*/ true); std::stringstream ss; std::string op_name; diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index ac506ee0..a6ddef64 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -14,98 +14,16 @@ #include "../target/utils.h" #include "region.h" #include "tcgen5_meta.h" +#include "utils.h" namespace tvm { namespace tl { using namespace tir; -// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region) -// to BufferRegion -static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg, - const BufferMap &vmap) { - // Case 1: Already a BufferRegion - if (arg->IsInstance()) { - return Downcast(arg); - } - - // Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else - // extent=1) - if (const auto *load = arg.as()) { - Array ranges; - for (const PrimExpr &index : load->indices) { - if (const auto *ramp = index.as()) { - ICHECK(ramp->stride.as()) << "Ramp stride must be IntImm"; - ICHECK_EQ(ramp->stride.as()->value, 1) - << "Only stride-1 Ramp is supported in GEMM region conversion"; - ICHECK(ramp->lanes.as()) - << "Scalable vector lanes not supported in GEMM region conversion"; - ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); - } else { - ranges.push_back(Range::FromMinExtent(index, 1)); - } - } - return BufferRegion(load->buffer, ranges); - } - - // Case 3: Call nodes - if (const auto *call = arg.as()) { - // tl.region(...) — reconstruct via RegionOp - if (call->op.same_as(RegionOp::Get())) { - RegionOp region(call->args, vmap); - return BufferRegion(region->GetBuffer(), region->GetRanges()); - } - // builtin.tvm_access_ptr(...) 
— map var to Buffer and take full region - if (call->op.same_as(builtin::tvm_access_ptr())) { - Var var = Downcast(call->args[1]); - Buffer buf = vmap.at(var); - Array ranges; - for (PrimExpr extent : buf->shape) { - ranges.push_back(Range(IntImm(extent->dtype, 0), extent)); - } - return BufferRegion(buf, ranges); - } - } +// NormalizeToBufferRegion moved to src/op/utils.{h,cc} - LOG(FATAL) << "Unsupported GEMM argument for BufferRegion: " << arg; - throw; // Unreachable, keeps compiler happy -} - -// Build a tvm_access_ptr(handle) to the start of the 2D tile within a -// BufferRegion. Offset is computed from all but the last two dimensions; extent -// is the product of the last two extents. rw_mask: 1=read, 2=write, -// 3=readwrite. -static PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion, - int rw_mask) { - Buffer buf = region->buffer; - int ndim = static_cast(buf->shape.size()); - ICHECK(ndim >= 2) << "GEMM expects buffers with at least 2 dims"; - - // Compute row-major strides - std::vector strides(ndim); - PrimExpr one = make_const(buf->shape[0].dtype(), 1); - PrimExpr cur = one; - for (int i = ndim - 1; i >= 0; --i) { - strides[i] = cur; - cur = cur * buf->shape[i]; - } - - // Offset: sum_{i in [0..ndim-3]} min_i * stride_i - PrimExpr offset = make_const(buf->shape[0].dtype(), 0); - for (int i = 0; i < ndim - 2; ++i) { - offset = offset + region->region[i]->min * strides[i]; - } - - // Extent: last two extents product (elements) - PrimExpr extent = - region->region[ndim - 2]->extent * region->region[ndim - 1]->extent; - - // ptype and return handle - PrimExpr ptype = tir::TypeAnnotation(buf->dtype); - Array acc_args{ptype, buf->data, offset, extent, - IntImm(DataType::Int(32), rw_mask)}; - return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); -} +// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc} /** * @brief Construct a Gemm operator from serialized TL arguments and a buffer diff --git a/src/op/reduce.cc b/src/op/reduce.cc index 
b6dbe865..c326f5ac 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -17,105 +17,16 @@ #include "region.h" #include "tir/transforms/ir_utils.h" #include "tvm/tir/stmt.h" +#include "utils.h" namespace tvm { namespace tl { using namespace tir; -// Normalize an argument (BufferRegion/BufferLoad/tl.region) -// to BufferRegion so Reduce can uniformly consume regions. -static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg, - const BufferMap &vmap) { - // Case 1: Already a BufferRegion - if (arg->IsInstance()) { - return Downcast(arg); - } - - // Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else - // extent=1) - if (const auto *load = arg.as()) { - Array ranges; - for (const PrimExpr &index : load->indices) { - if (const auto *ramp = index.as()) { - ICHECK(ramp->stride.as()) << "Ramp stride must be IntImm"; - ICHECK_EQ(ramp->stride.as()->value, 1) - << "Only stride-1 Ramp is supported in region conversion"; - ICHECK(ramp->lanes.as()) - << "Scalable vector lanes not supported in region conversion"; - ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); - } else { - ranges.push_back(Range::FromMinExtent(index, 1)); - } - } - return BufferRegion(load->buffer, ranges); - } - - // Case 3: Call nodes (only tl.region) - if (const auto *call = arg.as()) { - // tl.region(...) — reconstruct via RegionOp - if (call->op.same_as(RegionOp::Get())) { - RegionOp region(call->args, vmap); - return BufferRegion(region->GetBuffer(), region->GetRanges()); - } - // builtin.tvm_access_ptr(...) 
— map var to Buffer and take full region - if (call->op.same_as(builtin::tvm_access_ptr())) { - Var var = Downcast(call->args[1]); - Buffer buf = vmap[var]; - Array ranges; - for (PrimExpr extent : buf->shape) { - ranges.push_back(Range(IntImm(extent->dtype, 0), extent)); - } - return BufferRegion(buf, ranges); - } - } - - LOG(FATAL) << "Unsupported argument for BufferRegion in reduce: " << arg; - throw; // Unreachable -} - -// Build a tvm_access_ptr(handle) to the start of the 2D tile within a -// BufferRegion. Offset is computed from all but the last two dimensions; extent -// is the product of the last two extents. rw_mask: 1=read, 2=write, -// 3=readwrite. -static PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion, - int rw_mask) { - Buffer buf = region->buffer; - int ndim = static_cast(buf->shape.size()); - ICHECK(ndim == 1 || ndim == 2) << "Cumsum expects buffers with 1 or 2 dims"; - - PrimExpr offset, extent; - if (ndim == 1) { - // Simple 1D region: offset and extent come from the single axis. 
- auto axis = region->region[0]; - offset = axis->min; - extent = axis->extent; - } else { - // Compute row-major strides for ndim >= 2 - std::vector strides(ndim); - PrimExpr one = make_const(buf->shape[0].dtype(), 1); - PrimExpr cur = one; - for (int i = ndim - 1; i >= 0; --i) { - strides[i] = cur; - cur = cur * buf->shape[i]; - } - // Offset: sum_{i in [0..ndim-3]} min_i * stride_i - offset = make_const(buf->shape[0].dtype(), 0); - for (int i = 0; i < ndim - 2; ++i) { - offset = offset + region->region[i]->min * strides[i]; - } +// NormalizeToBufferRegion moved to src/op/utils.{h,cc} - // Extent: last two extents product (elements) - extent = - region->region[ndim - 2]->extent * region->region[ndim - 1]->extent; - } - - // ptype and return handle - PrimExpr ptype = tir::TypeAnnotation(buf->dtype); - Array acc_args{ptype, buf->data, offset, extent, - IntImm(DataType::Int(32), rw_mask)}; - return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); -} +// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc} ReduceOp::ReduceOp(Array args, BufferMap vmap) { ObjectPtr node = tvm::ffi::make_object(); diff --git a/src/op/utils.cc b/src/op/utils.cc new file mode 100644 index 00000000..59960b57 --- /dev/null +++ b/src/op/utils.cc @@ -0,0 +1,105 @@ +/*! + * \file tl/op/utils.cc + * \brief Common utilities implementation for TL ops. 
+ */ + +#include "utils.h" + +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +BufferRegion NormalizeToBufferRegion(const PrimExpr &arg, + const BufferMap &vmap) { + // Case 1: Already a BufferRegion + if (arg->IsInstance()) { + return Downcast(arg); + } + + // Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else + // extent=1) + if (const auto *load = arg.as()) { + Array ranges; + for (const PrimExpr &index : load->indices) { + if (const auto *ramp = index.as()) { + ICHECK(ramp->stride.as()) << "Ramp stride must be IntImm"; + ICHECK_EQ(ramp->stride.as()->value, 1) + << "Only stride-1 Ramp is supported in region conversion"; + ICHECK(ramp->lanes.as()) + << "Scalable vector lanes not supported in region conversion"; + ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); + } else { + ranges.push_back(Range::FromMinExtent(index, 1)); + } + } + return BufferRegion(load->buffer, ranges); + } + + // Case 3: Call nodes + if (const auto *call = arg.as()) { + // tl.region(...) — reconstruct via RegionOp + if (call->op.same_as(RegionOp::Get())) { + RegionOp region(call->args, vmap); + return BufferRegion(region->GetBuffer(), region->GetRanges()); + } + // builtin.tvm_access_ptr(...) 
— map var to Buffer and take full region + if (call->op.same_as(builtin::tvm_access_ptr())) { + Var var = Downcast(call->args[1]); + Buffer buf = vmap.at(var); + Array ranges; + for (PrimExpr extent : buf->shape) { + ranges.push_back(Range(IntImm(extent->dtype, 0), extent)); + } + return BufferRegion(buf, ranges); + } + } + + LOG(FATAL) << "Unsupported argument for BufferRegion: " << arg; + throw; // Unreachable +} + +PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion, int rw_mask, + bool require_2d) { + Buffer buf = region->buffer; + int ndim = static_cast(buf->shape.size()); + if (require_2d) { + ICHECK(ndim >= 2) << "Expect buffers with at least 2 dims"; + } + + PrimExpr offset, extent; + if (ndim == 1) { + // 1D: straightforward + auto axis = region->region[0]; + offset = axis->min; + extent = axis->extent; + } else { + // Compute row-major strides + std::vector strides(ndim); + PrimExpr one = make_const(buf->shape[0].dtype(), 1); + PrimExpr cur = one; + for (int i = ndim - 1; i >= 0; --i) { + strides[i] = cur; + cur = cur * buf->shape[i]; + } + // Offset: sum_{i in [0..ndim-3]} min_i * stride_i + offset = make_const(buf->shape[0].dtype(), 0); + for (int i = 0; i < ndim - 2; ++i) { + offset = offset + region->region[i]->min * strides[i]; + } + // Extent: last two extents product (elements) + extent = + region->region[ndim - 2]->extent * region->region[ndim - 1]->extent; + } + + // ptype and return handle + PrimExpr ptype = tir::TypeAnnotation(buf->dtype); + Array acc_args{ptype, buf->data, offset, extent, + IntImm(DataType::Int(32), rw_mask)}; + return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); +} + +} // namespace tl +} // namespace tvm diff --git a/src/op/utils.h b/src/op/utils.h new file mode 100644 index 00000000..9e7880ac --- /dev/null +++ b/src/op/utils.h @@ -0,0 +1,35 @@ +/*! + * \file tl/op/utils.h + * \brief Common utilities for TL ops. 
+ */ + +#ifndef TVM_TL_OP_UTILS_H_ +#define TVM_TL_OP_UTILS_H_ + +#include "./operator.h" +#include "region.h" +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +// Normalize an argument (BufferRegion/BufferLoad/tl.region/tvm_access_ptr) +// to BufferRegion so ops can uniformly consume regions. +TVM_DLL BufferRegion NormalizeToBufferRegion(const PrimExpr &arg, + const BufferMap &vmap); + +// Build a tvm_access_ptr(handle) from a BufferRegion. +// - If `require_2d` is true, checks buffer ndim >= 2. +// - For 1D regions (when allowed), offset=min, extent=extent. +// - For ndim >= 2, offset sums all but last two dims using row-major strides, +// extent is product of the last two extents. +TVM_DLL PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion, + int rw_mask, bool require_2d = false); + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_UTILS_H_ -- GitLab From 2ae4f1b7877a828da7d01cf88a2a45ad37850bfd Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Tue, 25 Nov 2025 14:07:52 +0800 Subject: [PATCH 046/139] [Fix] Fix bug copying from or to local buffer (#1304) (#1324) * [Fix] fix copy from or to local buffer (#1304) * fix lint error * minor fix testing script --- src/op/copy.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index 2584abce..82c903f8 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -851,8 +851,13 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, For vectorized_thread_loop; auto par_op = ParallelOp(transformed_loop); - if (is_cpu_target) { - vectorized_thread_loop = VectorizeLoop(transformed_loop, analyzer); + if (is_cpu_target || dst.scope() == "local" || src.scope() == "local") { + if (src.scope() == "local" && dst.scope() != "local") { + LOG(WARNING) << "Copy from local buffer `" << src->name << "` to " + << dst.scope() << " buffer `" << dst->name + << "` may cause conflicted write."; + } + 
vectorized_thread_loop = VectorizeLoop(transformed_loop); } else { std::vector levels = {InferLevel::kCommon, InferLevel::kStrict, InferLevel::kFree}; -- GitLab From e2b10c580b32cd31f384917d0ce31b7610f4e5e4 Mon Sep 17 00:00:00 2001 From: Chaofan Lin Date: Tue, 25 Nov 2025 20:22:15 +0800 Subject: [PATCH 047/139] [Language][UX] Semantic check for parallel fragment access (#1338) --- src/transform/layout_inference.cc | 8 +- .../test_tilelang_fragment_loop_checker.py | 162 ++++++++++++++++++ .../test_tilelang_nested_loop_checker.py} | 0 tilelang/analysis/__init__.py | 1 + tilelang/analysis/fragment_loop_checker.py | 100 +++++++++++ tilelang/analysis/nested_loop_checker.py | 6 +- tilelang/engine/phase.py | 3 + 7 files changed, 277 insertions(+), 3 deletions(-) create mode 100644 testing/python/analysis/test_tilelang_fragment_loop_checker.py rename testing/python/{language/test_tilelang_language_nested_loop.py => analysis/test_tilelang_nested_loop_checker.py} (100%) create mode 100644 tilelang/analysis/fragment_loop_checker.py diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index be98b284..873f70d0 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -821,7 +821,13 @@ private: int64_t frag_reg_num = 1; for (auto i : frag.value()->OutputShape()) { auto pci = as_const_int(i); - ICHECK(pci != nullptr); + ICHECK(pci != nullptr) + << "Can not use non-constant range to " + "iterate over a fragment/local " + "buffer. Non-constant shape expr is: " + << i + << ". 
This is possibly because you use symbolic shape when " + "accessing a fragment/local buffer."; frag_reg_num *= *pci; } reg_num += frag_reg_num; diff --git a/testing/python/analysis/test_tilelang_fragment_loop_checker.py b/testing/python/analysis/test_tilelang_fragment_loop_checker.py new file mode 100644 index 00000000..9073aebc --- /dev/null +++ b/testing/python/analysis/test_tilelang_fragment_loop_checker.py @@ -0,0 +1,162 @@ +import tilelang +import tilelang.language as T +import pytest + + +@tilelang.jit +def simple_invalid_loop(dtype: str = "bfloat16", + accum_dtype: str = "float32", + num_threads: int = 128): + A = T.dynamic("A") + + @T.prim_func + def main( + data: T.Tensor((128, A), dtype), # type: ignore + ): + with T.Kernel(128, threads=num_threads) as (tid,): + data_frag = T.alloc_fragment([128], accum_dtype) + + for i in T.Parallel(128): + if i < A: + data_frag[i] = data[tid, i] + + for i in T.Parallel(A): + data_frag[i] = 0 + + return main + + +@tilelang.jit +def nested_invalid_loop(dtype: str = "bfloat16", + accum_dtype: str = "float32", + num_threads: int = 128): + A = T.dynamic("A") + + @T.prim_func + def main( + data: T.Tensor((128, A), dtype), # type: ignore + ): + with T.Kernel(128, threads=num_threads) as (tid,): + data_frag = T.alloc_fragment([128], accum_dtype) + + for i in T.Parallel(128): + if i < A: + data_frag[i] = data[tid, i] + + for i in T.Parallel(A // 64): + for j in T.Parallel(64): + data_frag[i * 64 + j] = 0 + + return main + + +@tilelang.jit +def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16", + accum_dtype: str = "float32", + num_threads: int = 128): + A = T.dynamic("A") + + @T.prim_func + def main( + data: T.Tensor((128, A), dtype), # type: ignore + ): + with T.Kernel(128, threads=num_threads) as (tid,): + data_frag = T.alloc_fragment([128], accum_dtype) + + for i in T.Parallel(128): + if i < A: + data_frag[i] = data[tid, i] + + for i in T.Parallel(A): + data_frag[64 // 2 + i % 64] = 0 + + return main + + 
+@tilelang.jit +def valid_loop_not_use_loop_var(dtype: str = "bfloat16", + accum_dtype: str = "float32", + num_threads: int = 128): + A = T.dynamic("A") + + @T.prim_func + def main( + data: T.Tensor((128, A), dtype), # type: ignore + ): + with T.Kernel(128, threads=num_threads) as (tid,): + data_frag = T.alloc_fragment([128], accum_dtype) + + for i in T.Parallel(128): + if i < A: + data_frag[i] = data[tid, i] + + for i in T.Parallel(A): # noqa: B007 + for j in T.Parallel(64): + data_frag[j] = 0 # This is valid because we don't use i + + return main + + +@tilelang.jit +def valid_loop_not_frag(dtype: str = "bfloat16", + accum_dtype: str = "float32", + num_threads: int = 128): + A = T.dynamic("A") + + @T.prim_func + def main( + data: T.Tensor((128, A), dtype), # type: ignore + ): + with T.Kernel(128, threads=num_threads) as (tid,): + data_shared = T.alloc_shared([128], accum_dtype) + + for i in T.Parallel(128): + if i < A: + data_shared[i] = data[tid, i] + + for i in T.Parallel(A): + data_shared[i] = 0 # Valid because this is shared memory + + return main + + +@tilelang.jit +def valid_loop_serial(dtype: str = "bfloat16", + accum_dtype: str = "float32", + num_threads: int = 128): + A = T.dynamic("A") + + @T.prim_func + def main( + data: T.Tensor((128, A), dtype), # type: ignore + ): + with T.Kernel(128, threads=num_threads) as (tid,): + data_shared = T.alloc_shared([128], accum_dtype) + + for i in T.Parallel(128): + if i < A: + data_shared[i] = data[tid, i] + + for i in T.serial(A): + data_shared[i] = 0 # Valid because this is serial + + return main + + +def test_invalid_loop(): + with pytest.raises(ValueError): + simple_invalid_loop() + with pytest.raises(ValueError): + nested_invalid_loop() + with pytest.raises(ValueError): + invalid_loop_with_complex_dataflow() + + +def test_valid_loop(): + valid_loop_not_use_loop_var() + valid_loop_not_frag() + valid_loop_serial() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git 
a/testing/python/language/test_tilelang_language_nested_loop.py b/testing/python/analysis/test_tilelang_nested_loop_checker.py similarity index 100% rename from testing/python/language/test_tilelang_language_nested_loop.py rename to testing/python/analysis/test_tilelang_nested_loop_checker.py diff --git a/tilelang/analysis/__init__.py b/tilelang/analysis/__init__.py index 6e5ee5d6..33ccded6 100644 --- a/tilelang/analysis/__init__.py +++ b/tilelang/analysis/__init__.py @@ -2,3 +2,4 @@ from .ast_printer import ASTPrinter # noqa: F401 from .nested_loop_checker import NestedLoopChecker # noqa: F401 +from .fragment_loop_checker import FragmentLoopChecker # noqa: F401 diff --git a/tilelang/analysis/fragment_loop_checker.py b/tilelang/analysis/fragment_loop_checker.py new file mode 100644 index 00000000..3186b23e --- /dev/null +++ b/tilelang/analysis/fragment_loop_checker.py @@ -0,0 +1,100 @@ +from __future__ import annotations +from tvm import tir +from tvm.tir import (PyStmtExprVisitor, BufferStore, For, Var, PrimFunc, BufferLoad, IntImm) +from tvm.tir.transform import prim_func_pass +from tvm.tir.stmt_functor import post_order_visit + + +@tir.functor.visitor +class _LoopVarUseAnalyzer(PyStmtExprVisitor): + """Analyze whether a loop variable is used in the given expr.""" + + def __init__(self, var: Var) -> None: + super().__init__() + self.var = var + self.used = False + + def visit_var_(self, op: Var) -> None: + if op == self.var: + self.used = True + # Don't recursively visit children to avoid infinite recursion + + +def collect_local_buffer_accesses(statement) -> list[BufferLoad | BufferStore]: + """ + Collect local buffer accesses in the loop body. + + Args: + statement: The TIR statement to analyze + + Returns: + Tuple of buffer accesses in the loop body. 
+ """ + + buffer_accesses = [] + + def visit_buffer_access(node): + if isinstance(node, (BufferLoad, BufferStore)) and node.buffer.scope().startswith("local"): + buffer_accesses.append(node) + + post_order_visit(statement, visit_buffer_access) + + return buffer_accesses + + +@tir.functor.visitor +class _FragmentLoopCheckVisitor(PyStmtExprVisitor): + + def __init__(self) -> None: + super().__init__() + + def visit_for_(self, op: For) -> None: + if op.kind == tir.ForKind.PARALLEL: + # Fuse consecutive parallel loops + # Other nested cases are all invalid in TileLang. + loops = [op] + child = op.body + while isinstance(child, For) and child.kind == tir.ForKind.PARALLEL: + loops.append(child) + child = child.body + + loops_with_symbolic_ranges = [] + for loop in loops: + if not (isinstance(loop.min, IntImm) and isinstance(loop.extent, IntImm)): + loops_with_symbolic_ranges.append(loop) + + if len(loops_with_symbolic_ranges) > 0: + buffer_accesses = collect_local_buffer_accesses(child) + for loop in loops_with_symbolic_ranges: + for buffer_access in buffer_accesses: + indices = buffer_access.indices + analyzer = _LoopVarUseAnalyzer(loop.loop_var) + for index in indices: + analyzer.visit_expr(index) + if analyzer.used: + raise ValueError( + "[Tilelang Semantic Check] " + f"Loop variable {loop.loop_var} in a T.Parallel loop with symbolic range (min={loop.min}, extent={loop.extent}) is used to index " + "a local/fragment buffer, which is not allowed in Tilelang.") + + return + + self.visit_stmt(op.body) + + +def FragmentLoopChecker(): + """ + When using T.Parallel over a local/fragment buffer, there are several restrictions: + to ensure that the parallelization is valid. + + 1. The range of loop can not be symbolic. 
+ + Returns: + A prim_func_pass that applies the transformation + """ + + def pass_fn(func: PrimFunc, mod, ctx): + _FragmentLoopCheckVisitor().visit_stmt(func.body) + return func + + return prim_func_pass(pass_fn, opt_level=0) diff --git a/tilelang/analysis/nested_loop_checker.py b/tilelang/analysis/nested_loop_checker.py index 4b9741c3..7a0d94da 100644 --- a/tilelang/analysis/nested_loop_checker.py +++ b/tilelang/analysis/nested_loop_checker.py @@ -35,7 +35,8 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor): # Otherwise if self.in_parallel_context: - raise ValueError("Nested parallel loops are not allowed. " + raise ValueError("[Tilelang Semantic Check] " + "Nested parallel loops are not allowed. " "Please check your loop structure.") self.in_parallel_context = True self.visit_stmt(child) @@ -43,7 +44,8 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor): return elif is_pipelined_for(op): if self.in_parallel_context: - raise ValueError("Pipelined loop cannot be nested inside a parallel loop. " + raise ValueError("[Tilelang Semantic Check] " + "Pipelined loop cannot be nested inside a parallel loop. " "Please check your loop structure.") self.visit_stmt(op.body) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index f686ba1f..17d6e4aa 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -80,6 +80,9 @@ def PreLowerSemanticCheck(mod: IRModule) -> None: # Check if there are any invalid nested loops. tilelang.analysis.NestedLoopChecker()(mod) + # Check if there are any invalid symbolic T.Parallel + fragment access. 
+ tilelang.analysis.FragmentLoopChecker()(mod) + def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: # Bind the target device information to the module -- GitLab From f810f9767a53b140557daf5486e326c723b40a6a Mon Sep 17 00:00:00 2001 From: LJC00118 <77378439+LJC00118@users.noreply.github.com> Date: Wed, 26 Nov 2025 12:57:48 +0800 Subject: [PATCH 048/139] Add unit tests for T.assume (#1341) * Add test for T.assume * Add unit test for T.assume * Add unit test for T.assume * Add unit tests for T.assume * Remove debug print for kernel source Remove print statement for kernel source in tests. * Update test_tilelang_language_assume.py --------- Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> --- .../language/test_tilelang_language_assume.py | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 testing/python/language/test_tilelang_language_assume.py diff --git a/testing/python/language/test_tilelang_language_assume.py b/testing/python/language/test_tilelang_language_assume.py new file mode 100644 index 00000000..9c75a5ac --- /dev/null +++ b/testing/python/language/test_tilelang_language_assume.py @@ -0,0 +1,89 @@ +import tilelang +import tilelang.language as T +import tilelang.testing + + +def test_assume_remove_boundary_check(): + + @tilelang.jit + def kernel_with_assume(): + N = T.dynamic('N') + + @T.prim_func + def main(A: T.Tensor((N,), "float32"), l: T.int32, r: T.int32): + with T.Kernel(1, threads=32) as _: + for i in T.serial(r - l + 1): + T.assume(l + i >= 0 and l + i < N) + A[l + i] = 0 + + return main + + jit_kernel = kernel_with_assume() + source = jit_kernel.get_kernel_source() + + assert ("if (" not in source) + + +def test_assume_enable_vectorization(): + + @tilelang.jit + def kernel_vectorize(M): + N = T.dynamic('N') + vectorize_size = 4 + + @T.prim_func + def main( + A: T.Tensor((M, N), "float32"), + B: T.Tensor((M, N), "float32"), + ): + with T.Kernel(1, threads=32) as _: + tid = 
T.get_thread_binding() + + base_idx = tid * 4 + T.assume(N % vectorize_size == 0) + + for i in T.vectorized(vectorize_size): + T.assume(base_idx + i < N) + B[tid, base_idx + i] = A[tid, base_idx + i] + + return main + + jit_kernel = kernel_vectorize(128) + source = jit_kernel.get_kernel_source() + + assert ("float4" in source) and ("if (" not in source) + + +def test_assume_complex_indexing(): + + @tilelang.jit + def kernel_complex(): + M = T.dynamic('M') + N = T.dynamic('N') + + @T.prim_func + def main( + A: T.Tensor((M, N), "float32"), + B: T.Tensor((M, N), "float32"), + ): + with T.Kernel(1, threads=32) as _: + tid = T.get_thread_binding() + for j in T.serial(N): + i_src = T.min(j + 233, tid + 2) + j_src = j * T.ceildiv(j, i_src) * j - 1 + + T.assume(i_src >= 0 and i_src < M) + T.assume(j_src >= 0 and j_src < N) + + B[tid, j] = A[i_src, j_src] + + return main + + jit_kernel = kernel_complex() + source = jit_kernel.get_kernel_source() + + assert ("if (" not in source) + + +if __name__ == '__main__': + tilelang.testing.main() -- GitLab From fac0400680aa267efe01c663d0b92544c22471b5 Mon Sep 17 00:00:00 2001 From: ConvolutedDog Date: Wed, 26 Nov 2025 14:02:09 +0800 Subject: [PATCH 049/139] [Feat] Extend LegalizeNegativeIndex to support buffer store stmts (#1339) This commit enhances the LegalizeNegativeIndex transformation pass to handle both buffer load and store operations with negative indices and adds some test cases. 
--- src/support/ffi_aliases.h | 1 + src/transform/legalize_negative_index.cc | 214 +++++------ ...elang_transform_legalize_negative_index.py | 342 ++++++++++++++++++ 3 files changed, 453 insertions(+), 104 deletions(-) create mode 100644 testing/python/transform/test_tilelang_transform_legalize_negative_index.py diff --git a/src/support/ffi_aliases.h b/src/support/ffi_aliases.h index cbc6fb02..7dbe0b39 100644 --- a/src/support/ffi_aliases.h +++ b/src/support/ffi_aliases.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include diff --git a/src/transform/legalize_negative_index.cc b/src/transform/legalize_negative_index.cc index b502a6fb..f0df555e 100644 --- a/src/transform/legalize_negative_index.cc +++ b/src/transform/legalize_negative_index.cc @@ -1,6 +1,6 @@ /*! * \file legalize_negative_index.cc - * \brief Legalize negative indices in buffer load expressions. + * \brief Legalize negative indices in buffer load/store expressions. */ #include @@ -10,6 +10,7 @@ #include #include +#include #include #include "arith/ir_mutator_with_analyzer.h" @@ -23,47 +24,42 @@ using arith::IRVisitorWithAnalyzer; enum class IndexSignState { kNonNegative, kNegative, kUnknown }; +using BufferAccessVariant = + std::variant; +using LoadStore2StateMap = + std::unordered_map>; + class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer { public: - explicit NegativeIndexAnalyzer( - std::unordered_map> - *result) + explicit NegativeIndexAnalyzer(LoadStore2StateMap *result) : result_(result) {} - void VisitExpr_(const BufferLoadNode *op) final { - auto load = tvm::ffi::GetRef(op); +private: + std::vector ProcessIdx(const ffi::Array &indices, + ffi::String buffer_name) { std::vector states; - states.reserve(op->indices.size()); - bool needs_record = false; + states.reserve(indices.size()); - for (size_t i = 0; i < op->indices.size(); ++i) { - PrimExpr simplified = analyzer_.Simplify(op->indices[i]); + for (size_t i = 0; i < indices.size(); ++i) { + PrimExpr 
simplified = analyzer_.Simplify(indices[i]); + IndexSignState state = IndexSignState::kUnknown; // Handle scalar indices with the standard analyzer if (simplified.dtype().lanes() == 1) { - if (analyzer_.CanProve(simplified >= 0)) { - states.push_back(IndexSignState::kNonNegative); - continue; - } - if (analyzer_.CanProve(simplified < 0)) { - states.push_back(IndexSignState::kNegative); - needs_record = true; - continue; - } - states.push_back(IndexSignState::kUnknown); - needs_record = true; - DLOG(WARNING) - << "LegalizeNegativeIndex: cannot prove non-negative index " - << simplified << " for buffer " << load->buffer->name << " (axis " - << i << ")."; - continue; + if (analyzer_.CanProve(simplified >= 0)) + state = IndexSignState::kNonNegative; + else if (analyzer_.CanProve(simplified < 0)) + state = IndexSignState::kNegative; + else + DLOG(WARNING) + << "LegalizeNegativeIndex: cannot prove non-negative index " + << simplified << " for buffer " << buffer_name << " (axis " << i + << ", index " + indices[i]->Script() + ")."; } - // Vector indices: try to reason about non-negativity/negativity // Common patterns are Ramp(base, stride, lanes) and Broadcast(value, // lanes). 
- IndexSignState vec_state = IndexSignState::kUnknown; - if (const auto *ramp = simplified.as()) { + else if (const auto *ramp = simplified.as()) { // Compute a safe lower/upper bound for the vector lanes // lower_bound = base_min + min(0, stride_min) * (lanes - 1) // upper_bound = base_max + max(0, stride_max) * (lanes - 1) @@ -85,118 +81,129 @@ public: if (s_max > 0) upper += s_max * (lanes - 1); - if (lower >= 0) { - vec_state = IndexSignState::kNonNegative; - } else if (upper < 0) { - vec_state = IndexSignState::kNegative; - } else { - vec_state = IndexSignState::kUnknown; - } - } else if (const auto *bc = simplified.as()) { - auto v = analyzer_.Simplify(bc->value); - if (analyzer_.CanProve(v >= 0)) { - vec_state = IndexSignState::kNonNegative; - } else if (analyzer_.CanProve(v < 0)) { - vec_state = IndexSignState::kNegative; - } else { + if (lower >= 0) + state = IndexSignState::kNonNegative; + else if (upper < 0) + state = IndexSignState::kNegative; + else + DLOG(WARNING) + << "LegalizeNegativeIndex: cannot prove non-negative index " + << simplified << " for buffer " << buffer_name << " (axis " << i + << ", index " + indices[i]->Script() + ")."; + } else if (const auto *broadcast = simplified.as()) { + auto v = analyzer_.Simplify(broadcast->value); + if (analyzer_.CanProve(v >= 0)) + state = IndexSignState::kNonNegative; + else if (analyzer_.CanProve(v < 0)) + state = IndexSignState::kNegative; + else { // Try const bound if proof unavailable auto vb = analyzer_.const_int_bound(v); - if (vb->min_value >= 0) { - vec_state = IndexSignState::kNonNegative; - } else if (vb->max_value < 0) { - vec_state = IndexSignState::kNegative; - } else { - vec_state = IndexSignState::kUnknown; - } + if (vb->min_value >= 0) + state = IndexSignState::kNonNegative; + else if (vb->max_value < 0) + state = IndexSignState::kNegative; + else + DLOG(WARNING) + << "LegalizeNegativeIndex: cannot prove non-negative index " + << simplified << " for buffer " << buffer_name << " (axis " << 
i + << ", index " + indices[i]->Script() + ")."; } } + states.push_back(state); + } - if (vec_state == IndexSignState::kNonNegative) { - states.push_back(IndexSignState::kNonNegative); - continue; - } - if (vec_state == IndexSignState::kNegative) { - states.push_back(IndexSignState::kNegative); - needs_record = true; - continue; - } + return std::move(states); + } - states.push_back(IndexSignState::kUnknown); - needs_record = true; - DLOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative index " - << simplified << " for buffer " << load->buffer->name - << " (axis " << i << ")."; - } + bool NeedRecord(const std::vector &states) { + return std::any_of(states.begin(), states.end(), + [](const IndexSignState &state) { + return state == IndexSignState::kUnknown || + state == IndexSignState::kNegative; + }); + } + + void VisitExpr_(const BufferLoadNode *op) final { + std::vector states = + ProcessIdx(op->indices, op->buffer->name); - if (needs_record) { + if (NeedRecord(states)) (*result_)[op] = std::move(states); - } IRVisitorWithAnalyzer::VisitExpr_(op); } + void VisitStmt_(const BufferStoreNode *op) final { + std::vector states = + ProcessIdx(op->indices, op->buffer->name); + + if (NeedRecord(states)) + (*result_)[op] = std::move(states); + + IRVisitorWithAnalyzer::VisitStmt_(op); + } + private: - std::unordered_map> - *result_; + LoadStore2StateMap *result_; }; class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer { public: - static PrimFunc - Apply(PrimFunc func, - const std::unordered_map> &states) { + static PrimFunc Apply(PrimFunc func, const LoadStore2StateMap &states) { arith::Analyzer analyzer; NegativeIndexRewriter rewriter(&analyzer, states); - if (!func->body.defined()) { - return func; - } PrimFuncNode *func_node = func.CopyOnWrite(); func_node->body = rewriter.VisitStmt(func_node->body); return func; } private: - NegativeIndexRewriter( - arith::Analyzer *analyzer, - const std::unordered_map> &states) + 
NegativeIndexRewriter(arith::Analyzer *analyzer, + const LoadStore2StateMap &states) : arith::IRMutatorWithAnalyzer(analyzer), states_(states) {} + ffi::Array UpdateIdx(const ffi::Array &indices, + const ffi::Array &buffer_shape, + const std::vector &state_vec) { + ICHECK_EQ(state_vec.size(), indices.size()) + << "State vector size mismatch for buffer load/store indices (" + << indices << ")"; + ffi::Array new_indices = indices; + for (size_t i = 0; i < indices.size(); ++i) { + if (state_vec[i] != IndexSignState::kNegative) + continue; + new_indices.Set(i, analyzer_->Simplify(buffer_shape[i] + indices[i])); + } + return new_indices; + } + PrimExpr VisitExpr_(const BufferLoadNode *op) final { BufferLoad load = Downcast(arith::IRMutatorWithAnalyzer::VisitExpr_(op)); auto it = states_.find(op); - if (it == states_.end()) { + if (it == states_.end()) return load; - } - auto indices = load->indices; - bool changed = false; - - const auto &state_vector = it->second; - ICHECK_EQ(state_vector.size(), indices.size()) - << "State vector size mismatch for buffer load " << load->buffer->name; + auto indices = UpdateIdx(load->indices, load->buffer->shape, it->second); + return BufferLoad(load->buffer, indices, load->predicate); + } - for (size_t i = 0; i < indices.size(); ++i) { - if (state_vector[i] != IndexSignState::kNegative) { - continue; - } - PrimExpr extent = load->buffer->shape[i]; - indices.Set(i, analyzer_->Simplify(extent + indices[i])); - changed = true; - } + Stmt VisitStmt_(const BufferStoreNode *op) final { + BufferStore store = + Downcast(arith::IRMutatorWithAnalyzer::VisitStmt_(op)); - if (!changed) { - return load; - } + auto it = states_.find(op); + if (it == states_.end()) + return store; - return BufferLoad(load->buffer, indices); + auto indices = UpdateIdx(store->indices, store->buffer->shape, it->second); + return BufferStore(store->buffer, store->value, indices, store->predicate); } - const std::unordered_map> - &states_; +private: + const 
LoadStore2StateMap &states_; }; PrimFunc LegalizeNegativeIndex(PrimFunc func) { @@ -204,8 +211,7 @@ PrimFunc LegalizeNegativeIndex(PrimFunc func) { return func; } - std::unordered_map> - states; + LoadStore2StateMap states; NegativeIndexAnalyzer analyzer(&states); analyzer(func->body); if (states.empty()) { diff --git a/testing/python/transform/test_tilelang_transform_legalize_negative_index.py b/testing/python/transform/test_tilelang_transform_legalize_negative_index.py new file mode 100644 index 00000000..c5dd065a --- /dev/null +++ b/testing/python/transform/test_tilelang_transform_legalize_negative_index.py @@ -0,0 +1,342 @@ +from tilelang import tvm as tvm +import tilelang as tl +import tilelang.language as T +import tilelang.testing + + +def _check(original, expected): + """Helper function to verify structural equality after transformations""" + func = original + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tl.transform.LegalizeNegativeIndex()(mod) + expected = tvm.IRModule.from_expr(expected.with_attr("global_symbol", "main")) + tvm.ir.assert_structural_equal(mod["main"], expected["main"], True) + + +def test_buffer_load_negative_index_legalized(): + """ + Test that negative indices are legalized by adding buffer extent. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + value = A[-1] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + value = A[1023] # A[-1] becomes A[1023] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + _check(before, after) + + +def test_buffer_load_mixed_negative_positive_indices(): + """ + Test mixed negative and positive indices - only negative ones are legalized. 
+ """ + + @T.prim_func + def before(A: T.Tensor((1024, 512), "float32")): + value = A[-1, 10] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + @T.prim_func + def after(A: T.Tensor((1024, 512), "float32")): + value = A[1023, 10] # A[-1, 10] becomes A[1023, 10] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + _check(before, after) + + +def test_buffer_load_multiple_negative_indices(): + """ + Test multiple negative indices in different dimensions. + """ + + @T.prim_func + def before(A: T.Tensor((1024, 512, 256), "float32")): + value = A[-1, -2, -3] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + @T.prim_func + def after(A: T.Tensor((1024, 512, 256), "float32")): + value = A[1023, 510, 253] # -1+1024=1023, -2+512=510, -3+256=253 + B = T.alloc_buffer((1,), "float32") + B[0] = value + + _check(before, after) + + +def test_buffer_load_negative_index_in_expression(): + """ + Test negative index as part of a larger expression. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + B = T.alloc_buffer((1024,), "float32") + for i in T.serial(1, 1024): + value = A[-i] + B[-i] = value + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + B = T.alloc_buffer((1024,), "float32") + for i in T.serial(1, 1024): + value = A[1024 - i] + B[1024 - i] = value + + _check(before, after) + + +def test_buffer_load_non_negative_index_unchanged(): + """ + Test that non-negative indices remain unchanged. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + value = A[0] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + # No changes expected for non-negative indices + value = A[0] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + _check(before, after) + + +def test_buffer_load_unknown_sign_index_warning(): + """ + Test that indices with unknown sign trigger warnings but are processed. 
+ This test mainly checks that the pass doesn't crash on unknown signs. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + i = T.Var("i", "int32") + value = A[i] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + i = T.Var("i", "int32") + # Unknown sign indices should remain unchanged + value = A[i] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + _check(before, after) + + +def test_buffer_load_vector_index_negative_broadcast(): + """ + Test negative indices in vectorized operations (broadcast case). + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + vec = T.Broadcast(-1, 4) + value = A[vec] + B = T.alloc_buffer((4,), "float32") + B[T.Ramp(0, 1, 4)] = value + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + # vec is unused and can be delimed by Simplify. + vec = T.Broadcast(-1, 4) # noqa: F841 + value = A[T.Broadcast(1023, 4)] + B = T.alloc_buffer((4,), "float32") + B[T.Ramp(0, 1, 4)] = value + + _check(before, after) + + +def test_buffer_load_vector_index_negative_ramp(): + """ + Test negative indices in vectorized operations (ramp case). + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + vec = T.Ramp(-4, 1, 4) # indices: [-4, -3, -2, -1] + value = A[vec] + B = T.alloc_buffer((4,), "float32") + B[T.Ramp(0, 1, 4)] = value + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + # vec is unused and can be delimed by Simplify. + vec = T.Ramp(-4, 1, 4) # noqa: F841 + value = A[T.Ramp(1020, 1, 4)] + B = T.alloc_buffer((4,), "float32") + B[T.Ramp(0, 1, 4)] = value + + _check(before, after) + + +def test_buffer_load_nested_buffer_loads(): + """ + Test legalization with nested buffer load expressions. 
+ """ + + @T.prim_func + def before(A: T.Tensor((1024, 512), "float32")): + inner_val = A[-1, 10] + outer_val = A[inner_val.astype("int32"), -2] + B = T.alloc_buffer((1,), "float32") + B[0] = outer_val + + @T.prim_func + def after(A: T.Tensor((1024, 512), "float32")): + inner_val = A[1023, 10] + outer_val = A[inner_val.astype("int32"), 510] + B = T.alloc_buffer((1,), "float32") + B[0] = outer_val + + _check(before, after) + + +def test_buffer_store_negative_index(): + """ + Test negative indices in buffer store operations are legalized. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + A[-1] = 42.0 + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + A[1023] = 42.0 + + _check(before, after) + + +def test_buffer_store_mixed_negative_positive_indices(): + """ + Test mixed negative and positive indices in buffer store. + """ + + @T.prim_func + def before(A: T.Tensor((1024, 512), "float32")): + A[-1, 10] = 42.0 + + @T.prim_func + def after(A: T.Tensor((1024, 512), "float32")): + A[1023, 10] = 42.0 + + _check(before, after) + + +def test_buffer_store_multiple_negative_indices(): + """ + Test multiple negative indices in different dimensions for buffer store. + """ + + @T.prim_func + def before(A: T.Tensor((1024, 512, 256), "float32")): + A[-1, -2, -3] = 42.0 + + @T.prim_func + def after(A: T.Tensor((1024, 512, 256), "float32")): + A[1023, 510, 253] = 42.0 # -1+1024=1023, -2+512=510, -3+256=253 + + _check(before, after) + + +def test_buffer_store_negative_index_in_expression(): + """ + Test negative index as part of a larger expression in buffer store. 
+ """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + for i in T.serial(1, 1024): + A[-i] = i * 2.0 + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + for i in T.serial(1, 1024): + A[1024 - i] = i * 2.0 + + _check(before, after) + + +def test_buffer_store_vector_index_negative_broadcast(): + """ + Test negative indices in vectorized store operations (broadcast case). + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + vec = T.Broadcast(-1, 4) + values = T.Broadcast(42.0, 4) + A[vec] = values + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + # vec is unused and can be delimed by Simplify. + vec = T.Broadcast(-1, 4) # noqa: F841 + values = T.Broadcast(42.0, 4) + A[T.Broadcast(1023, 4)] = values + + _check(before, after) + + +def test_buffer_store_vector_index_negative_ramp(): + """ + Test negative indices in vectorized store operations (ramp case). + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + vec = T.Ramp(-4, 1, 4) # indices: [-4, -3, -2, -1] + values = T.Ramp(0.0, 1.0, 4) # values: [0.0, 1.0, 2.0, 3.0] + A[vec] = values + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + # vec is unused and can be delimed by Simplify. + vec = T.Ramp(-4, 1, 4) # noqa: F841 + values = T.Ramp(0.0, 1.0, 4) + A[T.Ramp(1020, 1, 4)] = values + + _check(before, after) + + +def test_buffer_store_nested_in_condition(): + """ + Test negative index buffer store within conditional statements. 
+ """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32"), flag: T.int32): + if flag > 0: + A[-1] = 42.0 + else: + A[-2] = 24.0 + + @T.prim_func + def after(A: T.Tensor((1024,), "float32"), flag: T.int32): + if flag > 0: + A[1023] = 42.0 + else: + A[1022] = 24.0 + + _check(before, after) + + +if __name__ == "__main__": + tilelang.testing.main() -- GitLab From f5d9da46788674b326ace0714c47ad36f39c1de8 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Wed, 26 Nov 2025 15:18:50 +0800 Subject: [PATCH 050/139] [Refactor] Phaseout vmap for Tile Operators (#1334) * Refactor GEMM and Reduce operations by moving NormalizeToBufferRegion and MakeAccessPtrFromRegion to utils.{h,cc} for better code organization and reuse. * lint fix * Refactor region handling by removing the RegionOp and updating NormalizeToBufferRegion to only accept BufferLoad and BufferRegion. This change improves code organization and simplifies the handling of memory regions across various operations. * fix * Refactor memory region handling by introducing `tl.region` calls across various operations, including GEMM and fill functions. This change enhances the consistency of region management and improves code organization by utilizing utility functions for buffer region conversions. * fix * fix * test fix * lint fix * Refactor GEMM operations to improve memory region handling by replacing `mbarPtr_` with `mbarRegion_` and updating related logic in both C++ and Python implementations. This change enhances the clarity and consistency of buffer region management. 
* fix * lint fix * fix * fix * test fix * lint fix * lint fix * minor fix * fix --------- Co-authored-by: Zhiwen Mo --- .../deepseek_mla/test_example_mla_decode.py | 1 - examples/gemv/example_gemv.py | 21 +-- examples/gemv/test_example_gemv.py | 4 +- src/op/atomic_add.cc | 27 ++-- src/op/atomic_add.h | 2 +- src/op/copy.cc | 127 +++++++++--------- src/op/copy.h | 38 +++--- src/op/fill.cc | 54 +------- src/op/fill.h | 2 +- src/op/finalize_reducer.cc | 11 +- src/op/finalize_reducer.h | 2 +- src/op/gemm.cc | 28 ++-- src/op/gemm.h | 4 +- src/op/gemm_py.cc | 22 ++- src/op/gemm_py.h | 9 +- src/op/gemm_sp.cc | 16 ++- src/op/gemm_sp.h | 7 +- src/op/operator.cc | 11 +- src/op/operator.h | 13 +- src/op/reduce.cc | 15 +-- src/op/reduce.h | 4 +- src/op/region.cc | 99 +++++--------- src/op/region.h | 99 +++++--------- src/op/utils.cc | 21 +-- src/op/utils.h | 6 +- src/transform/layout_inference.cc | 21 ++- src/transform/layout_reducer.cc | 34 ++++- src/transform/lower_tile_op.cc | 3 +- .../python/issue/test_tilelang_issue_830.py | 10 ++ tilelang/intrinsics/mfma_macro_generator.py | 40 +++++- tilelang/intrinsics/mma_macro_generator.py | 41 +++++- .../intrinsics/mma_sm70_macro_generator.py | 6 +- tilelang/language/atomic.py | 25 +--- tilelang/language/copy.py | 31 +---- tilelang/language/experimental/gemm_sp.py | 18 +-- tilelang/language/fill.py | 24 +--- tilelang/language/gemm.py | 39 +++--- tilelang/language/reduce.py | 28 ++-- tilelang/language/utils.py | 85 ++---------- tilelang/tileop/gemm/gemm_base.py | 4 + tilelang/tileop/gemm/gemm_tcgen05.py | 11 +- tilelang/utils/__init__.py | 1 + tilelang/utils/language.py | 73 ++++++---- 43 files changed, 535 insertions(+), 602 deletions(-) diff --git a/examples/deepseek_mla/test_example_mla_decode.py b/examples/deepseek_mla/test_example_mla_decode.py index 66a750f7..a269ea57 100644 --- a/examples/deepseek_mla/test_example_mla_decode.py +++ b/examples/deepseek_mla/test_example_mla_decode.py @@ -1,5 +1,4 @@ import tilelang.testing - 
import example_mla_decode diff --git a/examples/gemv/example_gemv.py b/examples/gemv/example_gemv.py index 4e43dcd9..58e0114b 100644 --- a/examples/gemv/example_gemv.py +++ b/examples/gemv/example_gemv.py @@ -334,14 +334,14 @@ def get_autotuned_kernel( return main -def check_correctness_and_bench(kernel, N, K, bench_ref=True): +def check_correctness_and_bench(kernel, N, K, do_bench=True): profiler = kernel.get_profiler() profiler.assert_allclose(lambda x, y: x @ y.T, atol=1e-2, rtol=1e-2) - if bench_ref: + if do_bench: latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=50) print(f"Torch Latency: {latency} ms") - latency = profiler.do_bench(kernel, warmup=50) - print(f"TileLang Latency: {latency} ms\n") + latency = profiler.do_bench(kernel, warmup=50) + print(f"TileLang Latency: {latency} ms\n") def main(do_bench: bool = True): @@ -350,12 +350,13 @@ def main(do_bench: bool = True): parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") args, _ = parser.parse_known_args() N, K = args.n, args.k - check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K) - check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K) - check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K) - check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K) - check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K) - check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K) + check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K, do_bench=do_bench) + check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K, do_bench=do_bench) + check_correctness_and_bench( + gemv_alloc_reducer(N, K, block_M=128, 
block_N=128), N, K, do_bench=do_bench) print("Test passed!") diff --git a/examples/gemv/test_example_gemv.py b/examples/gemv/test_example_gemv.py index 3881ca76..323337a7 100644 --- a/examples/gemv/test_example_gemv.py +++ b/examples/gemv/test_example_gemv.py @@ -1,5 +1,3 @@ -import tilelang.testing - import example_gemv @@ -8,4 +6,4 @@ def test_example_gemv(): if __name__ == "__main__": - tilelang.testing.main() + test_example_gemv() diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index 57e0d8b7..1a49b770 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -5,7 +5,7 @@ */ #include "./atomic_add.h" -#include "./region.h" +#include "utils.h" #include #include #include @@ -26,32 +26,27 @@ using namespace tir; * @brief Construct an AtomicAdd operator from call arguments and a buffer map. * * Builds the internal AtomicAddNode, extracts the source and destination - * regions and their backing Buffers from the first two call-style expressions - * in `args` (via RegionOp), and stores them along with their ranges. If a third - * argument is provided, it is interpreted as an integer immediate and stored as - * the node's coalesced width. + * regions and their backing Buffers from the first two region-style expressions + * in `args` (BufferLoad/BufferRegion), and stores them along with their + * ranges. If a third argument is provided, it is interpreted as an integer + * immediate and stored as the node's coalesced width. * * @param args Call-style PrimExprs where: * - args[0] is the source region call, * - args[1] is the destination region call, * - args[2] (optional) is an IntImm specifying coalesced width. - * @param vmap Mapping from buffers used by RegionOp to concrete Buffer objects. - * * Notes: - * - The constructor checks that args[0] and args[1] are CallNodes. + * - The constructor checks that args[0] and args[1] are region-compatible. * - The constructed node is stored in this->data_. 
*/ -AtomicAdd::AtomicAdd(Array args, BufferMap vmap) { +AtomicAdd::AtomicAdd(Array args) { ObjectPtr node = tvm::ffi::make_object(); Array rgs[2]; Buffer bf[2]; for (int i = 0; i < 2; i++) { - auto expr = args[i]; - auto call = expr.as(); - ICHECK(call); - auto region = RegionOp(call->args, vmap); - rgs[i] = region->GetRanges(); - bf[i] = region->GetBuffer(); + auto region = NormalizeToBufferRegion(args[i]); + rgs[i] = region->region; + bf[i] = region->buffer; } std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); @@ -552,4 +547,4 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd) TVM_FFI_STATIC_INIT_BLOCK() { AtomicAddNode::RegisterReflection(); } } // namespace tl -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/op/atomic_add.h b/src/op/atomic_add.h index f3aaacdb..c6beb70e 100644 --- a/src/op/atomic_add.h +++ b/src/op/atomic_add.h @@ -65,7 +65,7 @@ class AtomicAdd : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicAdd, TileOperator, AtomicAddNode); - TVM_DLL AtomicAdd(Array args, BufferMap vmap); + TVM_DLL AtomicAdd(Array args); static const Op &Get(); }; diff --git a/src/op/copy.cc b/src/op/copy.cc index 82c903f8..9b93fea1 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -16,7 +16,7 @@ #include "../transform/common/loop_parallel_transform_utils.h" #include "../transform/loop_partition.h" #include "../transform/loop_vectorize.h" -#include "region.h" +#include "utils.h" #include "../target/cuda.h" #include "../target/utils.h" @@ -110,36 +110,32 @@ template static Array ReverseArray(Array array) { /*! * \brief Construct a Copy operator node from call arguments and a buffer map. * - * This constructor parses the first two entries of `args` as Call nodes - * describing source and destination Regions (via RegionOp), extracts their - * Buffers and Ranges, and stores them on the newly created CopyNode. 
It also + * This constructor parses the first two entries of `args` as regions + * (BufferLoad/BufferRegion), extracts their Buffers and Ranges, and stores + * them on the newly created CopyNode. It also * reads optional arguments: * - args[2] (IntImm): coalesced width (stored only if > 0), * - args[3] (Bool): disable TMA lowering flag, * - args[4] (IntImm): eviction policy. * * Preconditions: - * - `args` must contain at least two Call-compatible PrimExpr entries - * describing regions; an ICHECK will fail if they are not CallNodes. + * - `args` must contain at least two region-compatible PrimExpr entries + * (BufferLoad/BufferRegion); ICHECK will fail otherwise. * * @param args Array of PrimExpr where: * - args[0] is the source Region call, * - args[1] is the destination Region call, * - optional args[2..4] are coalesced width, disable_tma, and eviction * policy. - * @param vmap BufferMap used to resolve RegionOp buffers and ranges. */ -Copy::Copy(Array args, BufferMap vmap) { +Copy::Copy(Array args) { ObjectPtr node = tvm::ffi::make_object(); Array rgs[2]; Buffer bf[2]; for (int i = 0; i < 2; i++) { - auto expr = args[i]; - auto call = expr.as(); - ICHECK(call); - auto region = RegionOp(call->args, vmap); - rgs[i] = region->GetRanges(); - bf[i] = region->GetBuffer(); + auto region = NormalizeToBufferRegion(args[i]); + rgs[i] = region->region; + bf[i] = region->buffer; } std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); @@ -250,6 +246,7 @@ PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer, const Array &ivs, Array extents, int src_dst) const { Array ranges = src_dst == 0 ? 
src_range : dst_range; + Array cond_list; ICHECK(extents.size() == ranges.size()) << extents << " " << ranges; size_t idx = 0; @@ -302,7 +299,6 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { for (const auto &iv : loop_vars) analyzer->Bind(iv->var, iv->dom); - ICHECK(loop_vars.size() <= src_range.size()) << "loop_vars.size() = " << loop_vars.size() << ", src_range.size() = " << src_range.size() << ", src = " << src->name @@ -1729,20 +1725,21 @@ Array TMADesc::EncodeCallArgs() const { * GPU intrinsics. * * @param args Array of PrimExpr TL-call arguments (see list above). - * @param vmap Mapping from original buffer variables to actual Buffer objects. */ -Conv2DIm2ColOp::Conv2DIm2ColOp(Array args, BufferMap vmap) { +Conv2DIm2ColOp::Conv2DIm2ColOp(Array args) { ObjectPtr node = tvm::ffi::make_object(); - node->src = vmap[GetVarFromAccessPtr(args[0])]; - node->dst = vmap[GetVarFromAccessPtr(args[1])]; - node->nhw_step = args[2]; - node->c_step = args[3]; - node->kernel = args[4].as().value()->value; - node->stride = args[5].as().value()->value; - node->dilation = args[6].as().value()->value; - node->padding = args[7].as().value()->value; - node->eviction_policy = args[8].as().value()->value; + node->srcRegion_ = NormalizeToBufferRegion(args[0]); + node->dstRegion_ = NormalizeToBufferRegion(args[1]); + node->src_ = node->srcRegion_->buffer; + node->dst_ = node->dstRegion_->buffer; + node->nhw_step_ = args[2]; + node->c_step_ = args[3]; + node->kernel_ = args[4].as().value()->value; + node->stride_ = args[5].as().value()->value; + node->dilation_ = args[6].as().value()->value; + node->padding_ = args[7].as().value()->value; + node->eviction_policy_ = args[8].as().value()->value; data_ = std::move(node); } @@ -1793,24 +1790,24 @@ TileOperator Conv2DIm2ColOpNode::Clone() const { Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ICHECK(TargetIsHopper(T.target)); - ICHECK(src.scope() == "global" && - (dst.scope() == 
"shared.dyn" || dst.scope() == "shared")); - ICHECK(src->shape.size() == 4); - ICHECK(dst->shape.size() == 2); - ICHECK(src->dtype == dst->dtype); + ICHECK(src_.scope() == "global" && + (dst_.scope() == "shared.dyn" || dst_.scope() == "shared")); + ICHECK(src_->shape.size() == 4); + ICHECK(dst_->shape.size() == 2); + ICHECK(src_->dtype == dst_->dtype); Layout shared_layout; - if (T.layout_map.count(dst)) { - shared_layout = T.layout_map[dst]; + if (T.layout_map.count(dst_)) { + shared_layout = T.layout_map[dst_]; } TMAIm2ColDesc desc; - desc.rank = src->shape.size(); - desc.data_type = to_CUtensorMapDataType(src->dtype); - desc.global_addr = src->data; - desc.global_shape = ReverseArray(src->shape); + desc.rank = src_->shape.size(); + desc.data_type = to_CUtensorMapDataType(src_->dtype); + desc.global_addr = src_->data; + desc.global_shape = ReverseArray(src_->shape); - if (!src->strides.empty()) { - desc.global_stride = ReverseArray(src->strides); + if (!src_->strides.empty()) { + desc.global_stride = ReverseArray(src_->strides); } else { // Create stride from shape PrimExpr stride = 1; @@ -1824,13 +1821,13 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, ICHECK(is_one(desc.global_stride[0])) << desc.global_stride; // Make global stride in bytes desc.global_stride = desc.global_stride.Map([&](PrimExpr e) { - return cast(DataType::Int(64), e) * src->dtype.bytes(); + return cast(DataType::Int(64), e) * src_->dtype.bytes(); }); - desc.elem_stride = {1, stride, stride, 1}; - desc.lower_corner = {-padding, -padding}; - desc.upper_corner = {-padding, -padding}; - desc.smem_box_pixel = Downcast(dst->shape[0])->value; - desc.smem_box_channel = Downcast(dst->shape[1])->value; + desc.elem_stride = {1, stride_, stride_, 1}; + desc.lower_corner = {-padding_, -padding_}; + desc.upper_corner = {-padding_, -padding_}; + desc.smem_box_pixel = Downcast(dst_->shape[0])->value; + desc.smem_box_channel = Downcast(dst_->shape[1])->value; desc.l2_promotion = 
static_cast(CU_TENSOR_MAP_L2_PROMOTION_L2_128B); desc.oob_fill = static_cast(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); desc.interleave = static_cast(CU_TENSOR_MAP_INTERLEAVE_NONE); @@ -1844,15 +1841,15 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, if (StructuralEqual()(shared_layout, makeQuarterBankSwizzleLayout(*stride, *continuous, - dst->dtype.bits()))) { + dst_->dtype.bits()))) { desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_32B); } else if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout( *stride, *continuous, - dst->dtype.bits()))) { + dst_->dtype.bits()))) { desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_64B); } else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout( *stride, *continuous, - dst->dtype.bits()))) { + dst_->dtype.bits()))) { desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_128B); } else { ICHECK(0) << "Cannot detect TMA layout."; @@ -1871,43 +1868,43 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, << "Currently can only support divisible channel case"; global_coords.push_back( - FloorMod(c_step * desc.smem_box_channel, desc.global_shape[0])); + FloorMod(c_step_ * desc.smem_box_channel, desc.global_shape[0])); image_offset.push_back( - dilation * - FloorMod(FloorDiv(c_step * desc.smem_box_channel, desc.global_shape[0]), - kernel)); - image_offset.push_back(dilation * FloorDiv(c_step * desc.smem_box_channel, - desc.global_shape[0] * kernel)); + dilation_ * + FloorMod(FloorDiv(c_step_ * desc.smem_box_channel, desc.global_shape[0]), + kernel_)); + image_offset.push_back(dilation_ * FloorDiv(c_step_ * desc.smem_box_channel, + desc.global_shape[0] * kernel_)); PrimExpr h_dim = - FloorDiv(src->shape[1] + 2 * padding - (kernel - 1) * dilation - 1, - stride) + + FloorDiv(src_->shape[1] + 2 * padding_ - (kernel_ - 1) * dilation_ - 1, + stride_) + 1; PrimExpr w_dim = - FloorDiv(src->shape[2] + 2 * padding - (kernel - 1) * dilation - 1, - stride) + + FloorDiv(src_->shape[2] + 2 * padding_ - (kernel_ - 1) * dilation_ - 
1, + stride_) + 1; global_coords.push_back( - stride * FloorMod(nhw_step * desc.smem_box_pixel, w_dim) - padding); + stride_ * FloorMod(nhw_step_ * desc.smem_box_pixel, w_dim) - padding_); global_coords.push_back( - stride * - FloorMod(FloorDiv(nhw_step * desc.smem_box_pixel, w_dim), h_dim) - - padding); + stride_ * + FloorMod(FloorDiv(nhw_step_ * desc.smem_box_pixel, w_dim), h_dim) - + padding_); global_coords.push_back( - FloorDiv(nhw_step * desc.smem_box_pixel, w_dim * h_dim)); + FloorDiv(nhw_step_ * desc.smem_box_pixel, w_dim * h_dim)); Array args; args.reserve(desc.rank * 2 + 2); args.push_back(create_desc); args.push_back(0); // mbar placeholder - auto dst_buffer = T.buffer_remap.count(dst) ? T.buffer_remap[dst] : dst; + auto dst_buffer = T.buffer_remap.count(dst_) ? T.buffer_remap[dst_] : dst_; auto shared_addr = dst_buffer.access_ptr(2); args.push_back(shared_addr); for (auto coord : global_coords) args.push_back(coord); for (auto offset : image_offset) args.push_back(offset); - args.push_back(this->eviction_policy); + args.push_back(this->eviction_policy_); Stmt tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), Evaluate(Call(DataType::Handle(), tma_load_im2col(), args))); diff --git a/src/op/copy.h b/src/op/copy.h index ef46b9ed..b08f5768 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -280,7 +280,7 @@ public: * \param args Expression arguments for the copy. * \param vmap Buffer variable mapping. */ - TVM_DLL Copy(Array args, BufferMap vmap); + TVM_DLL Copy(Array args); /*! * \brief Get the TVM Op handle corresponding to this Copy op. 
@@ -296,14 +296,16 @@ public: */ class Conv2DIm2ColOpNode : public TileOperatorNode { public: - Buffer src, dst; // Source (input feature map) and destination (im2col matrix) - int stride; // Stride for convolution - int padding; // Padding amount - int dilation; // Dilation factor - int kernel; // Kernel size - int eviction_policy; // Cache eviction policy - PrimExpr nhw_step; // Step size in NHW dimensions - PrimExpr c_step; // Step size in channel dimension + BufferRegion srcRegion_, dstRegion_; + Buffer src_, + dst_; // Source (input feature map) and destination (im2col matrix) + int stride_; // Stride for convolution + int padding_; // Padding amount + int dilation_; // Dilation factor + int kernel_; // Kernel size + int eviction_policy_; // Cache eviction policy + PrimExpr nhw_step_; // Step size in NHW dimensions + PrimExpr c_step_; // Step size in channel dimension TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Conv2DIm2Col", Conv2DIm2ColOpNode, TileOperatorNode); @@ -311,13 +313,15 @@ public: static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("src", &Conv2DIm2ColOpNode::src) - .def_ro("dst", &Conv2DIm2ColOpNode::dst) - .def_ro("stride", &Conv2DIm2ColOpNode::stride) - .def_ro("padding", &Conv2DIm2ColOpNode::padding) - .def_ro("dilation", &Conv2DIm2ColOpNode::dilation) - .def_ro("kernel", &Conv2DIm2ColOpNode::kernel) - .def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy); + .def_ro("srcRegion", &Conv2DIm2ColOpNode::srcRegion_) + .def_ro("dstRegion", &Conv2DIm2ColOpNode::dstRegion_) + .def_ro("src", &Conv2DIm2ColOpNode::src_) + .def_ro("dst", &Conv2DIm2ColOpNode::dst_) + .def_ro("stride", &Conv2DIm2ColOpNode::stride_) + .def_ro("padding", &Conv2DIm2ColOpNode::padding_) + .def_ro("dilation", &Conv2DIm2ColOpNode::dilation_) + .def_ro("kernel", &Conv2DIm2ColOpNode::kernel_) + .def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy_); } /*! 
@@ -342,7 +346,7 @@ class Conv2DIm2ColOp : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Conv2DIm2ColOp, TileOperator, Conv2DIm2ColOpNode); - TVM_DLL Conv2DIm2ColOp(Array args, BufferMap vmap); + TVM_DLL Conv2DIm2ColOp(Array args); static const Op &Get(); }; diff --git a/src/op/fill.cc b/src/op/fill.cc index 93b3bca0..5a773768 100644 --- a/src/op/fill.cc +++ b/src/op/fill.cc @@ -17,7 +17,7 @@ #include "../transform/loop_partition.h" #include "../transform/loop_vectorize.h" #include "builtin.h" -#include "region.h" +#include "utils.h" namespace tvm { namespace tl { @@ -52,62 +52,18 @@ using namespace tir; * value]. * - args[0]: destination access (BufferLoad or pointer expression). * - args[1]: value to fill (scalar or vector). - * @param vmap Mapping from buffer variables to Buffer objects; used to resolve - * the destination when args[0] is not a BufferLoad. * * Notes: * - The constructor enforces constraints (e.g., stride == 1 ramps, constant * lanes) and will terminate (via CHECK/ICHECK) if inputs are unsupported or out * of bounds. 
*/ -Fill::Fill(Array args, BufferMap vmap) { +Fill::Fill(Array args) { ObjectPtr node = tvm::ffi::make_object(); - // Case 1: Region descriptor call (tl.region) - if (const auto *call = args[0].as()) { - if (call->op.same_as(RegionOp::Get())) { - auto region = RegionOp(call->args, vmap); - node->dst = region->GetBuffer(); - node->region = region->GetRanges(); - } else if (call->op.same_as(builtin::tvm_access_ptr())) { - node->dst = vmap[GetVarFromAccessPtr(args[0])]; - for (int i = 0; i < node->dst->shape.size(); i++) { - node->region.push_back(Range(0, node->dst->shape[i])); - } - } else { - ICHECK(false) << "Unsupported call op in tl.fill: " - << Downcast(call->op)->name; - } - - // Case 2: Explicit BufferRegion (legacy path) - } else if (args[0]->IsInstance()) { - auto region = Downcast(args[0]); - node->dst = region->buffer; - node->region = region->region; - - // Case 3: Vector/scalar region expressed via BufferLoad indices - } else if (args[0]->IsInstance()) { - auto buffer_load = Downcast(args[0]); - for (const auto &index : buffer_load->indices) { - if (const auto *ramp = index.as()) { - CHECK(ramp->stride.as()->value == 1) - << "Only stride 1 ramps are supported"; - const auto *lanes = ramp->lanes.as(); - CHECK(lanes) - << "Scalable vectors not supported in BufferRegion conversion"; - node->region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); - } else { - node->region.push_back(Range::FromMinExtent(index, 1)); - } - } - node->dst = buffer_load->buffer; - // Case 4: Access pointer, fill the full buffer - } else { - node->dst = vmap[GetVarFromAccessPtr(args[0])]; - for (int i = 0; i < node->dst->shape.size(); i++) { - node->region.push_back(Range(0, node->dst->shape[i])); - } - } + BufferRegion region = NormalizeToBufferRegion(args[0]); + node->dst = region->buffer; + node->region = region->region; if (args[1]->dtype != node->dst->dtype) { node->value = Cast(node->dst->dtype, args[1]); diff --git a/src/op/fill.h b/src/op/fill.h index 
8f1dd900..c10a5cfb 100644 --- a/src/op/fill.h +++ b/src/op/fill.h @@ -45,7 +45,7 @@ private: class Fill : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fill, TileOperator, FillNode); - TVM_DLL Fill(Array args, BufferMap vmap); + TVM_DLL Fill(Array args); static const Op &Get(); }; diff --git a/src/op/finalize_reducer.cc b/src/op/finalize_reducer.cc index 84b18897..effc4baf 100644 --- a/src/op/finalize_reducer.cc +++ b/src/op/finalize_reducer.cc @@ -12,6 +12,7 @@ #include #include "../target/utils.h" +#include "utils.h" namespace tvm { namespace tl { @@ -29,12 +30,14 @@ using namespace tir; * @param args TL operator arguments: expects at least two elements where * `args[0]` is an access pointer identifying the reducer variable * and `args[1]` is an integer encoding a `ReducerOpType` (e.g., Sum/Max/Min). - * @param vmap Mapping from variables to Buffers used to look up the reducer - * Buffer. */ -FinalizeReducerOp::FinalizeReducerOp(Array args, BufferMap vmap) { +FinalizeReducerOp::FinalizeReducerOp(Array args) { auto node = tvm::ffi::make_object(); - node->reducer = vmap[GetVarFromAccessPtr(args[0])]; + // Normalize any supported region expression + // (BufferRegion/BufferLoad/tl.region) to a BufferRegion, then take the + // underlying Buffer as reducer. 
+ auto region = NormalizeToBufferRegion(args[0]); + node->reducer = region->buffer; node->op = (ReducerOpType)*as_const_int(args[1]); data_ = std::move(node); } diff --git a/src/op/finalize_reducer.h b/src/op/finalize_reducer.h index ef49ee19..99e1e7cb 100644 --- a/src/op/finalize_reducer.h +++ b/src/op/finalize_reducer.h @@ -48,7 +48,7 @@ class FinalizeReducerOp : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FinalizeReducerOp, TileOperator, FinalizeReducerOpNode); - TVM_DLL FinalizeReducerOp(Array args, BufferMap vmap); + TVM_DLL FinalizeReducerOp(Array args); static const Op &Get(); }; diff --git a/src/op/gemm.cc b/src/op/gemm.cc index cece1e6f..5a98cba6 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -12,7 +12,6 @@ #include #include "../target/utils.h" -#include "region.h" #include "tcgen5_meta.h" #include "utils.h" @@ -42,8 +41,6 @@ using namespace tir; * M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool), * stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int), * (optional) kPack (Int), (optional) wg_wait (Int)] - * @param vmap Mapping from access pointer vars to Buffer objects used to - * resolve the Buffer corresponding to each pointer argument. * * @note If `kPack` is provided it must be 1; otherwise the constructor * fails with an ICHECK (runtime assertion). 
No other validation is @@ -53,12 +50,12 @@ using namespace tir; // MakeAccessPtrFromRegion moved to src/op/utils.{h,cc} -Gemm::Gemm(Array args, BufferMap vmap) { +Gemm::Gemm(Array args) { ObjectPtr node = tvm::ffi::make_object(); - node->aRegion_ = NormalizeToBufferRegion(args[0], vmap); - node->bRegion_ = NormalizeToBufferRegion(args[1], vmap); - node->cRegion_ = NormalizeToBufferRegion(args[2], vmap); + node->aRegion_ = NormalizeToBufferRegion(args[0]); + node->bRegion_ = NormalizeToBufferRegion(args[1]); + node->cRegion_ = NormalizeToBufferRegion(args[2]); node->a_ = node->aRegion_->buffer; node->b_ = node->bRegion_->buffer; @@ -83,11 +80,14 @@ Gemm::Gemm(Array args, BufferMap vmap) { if (args.size() > 15) { node->wgWait_ = args[15].as().value()->value; } - node->mbarPtr_ = args[16]; - if (node->mbarPtr_.as()) { - node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)]; - } else { - node->mbar_ = std::nullopt; + if (args.size() > 16) { + if (const auto *load = args[16].as()) { + node->mbarRegion_ = + NormalizeToBufferRegion(Downcast(args[16])); + node->mbar_ = node->mbarRegion_->buffer; + } else { + node->mbar_ = std::nullopt; + } } node->cCoords_ = Array( {args[17].as().value(), args[18].as().value()}); @@ -500,11 +500,13 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto C_buffer = T.buffer_remap.count(c_) ? 
T.buffer_remap[c_] : c_; Array new_args; + auto mbarPtr = + MakeAccessPtrFromRegion(mbarRegion_, /*rw*/ 3, /*require_2d*/ true); new_args.push_back(StringImm(ss.str())); new_args.push_back(Aptr); new_args.push_back(Bptr); new_args.push_back(BufferLoad(C_buffer, cCoords_)); - new_args.push_back(mbarPtr_); + new_args.push_back(mbarPtr); new_args.push_back(clearAccum_); auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args); diff --git a/src/op/gemm.h b/src/op/gemm.h index 1c976055..3ec58bec 100644 --- a/src/op/gemm.h +++ b/src/op/gemm.h @@ -97,7 +97,7 @@ public: // only will be enabled under cdna mfma instructions int kPack_ = 1; int wgWait_ = 0; - PrimExpr mbarPtr_; + BufferRegion mbarRegion_; std::optional mbar_; // mbar is optional, only used for TCGEN5MMA Array cCoords_; mutable GemmWarpPolicy policy_; @@ -144,7 +144,7 @@ private: class Gemm : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Gemm, TileOperator, GemmNode); - TVM_DLL Gemm(Array args, BufferMap vmap); + TVM_DLL Gemm(Array args); static const Op &Get(); }; diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index a6ddef64..511a4283 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -12,7 +12,6 @@ #include #include "../target/utils.h" -#include "region.h" #include "tcgen5_meta.h" #include "utils.h" @@ -46,19 +45,17 @@ using namespace tir; * M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool), * stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int), * (optional) kPack (Int), (optional) wg_wait (Int)] - * @param vmap Mapping from access pointer vars to Buffer objects used to - * resolve the Buffer corresponding to each pointer argument. * * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor * fails with an ICHECK (runtime assertion). No other validation is * performed here. 
*/ -GemmPy::GemmPy(Array args, BufferMap vmap) { +GemmPy::GemmPy(Array args) { ObjectPtr node = tvm::ffi::make_object(); - node->aRegion_ = NormalizeToBufferRegion(args[0], vmap); - node->bRegion_ = NormalizeToBufferRegion(args[1], vmap); - node->cRegion_ = NormalizeToBufferRegion(args[2], vmap); + node->aRegion_ = NormalizeToBufferRegion(args[0]); + node->bRegion_ = NormalizeToBufferRegion(args[1]); + node->cRegion_ = NormalizeToBufferRegion(args[2]); node->a_ = node->aRegion_->buffer; node->b_ = node->bRegion_->buffer; @@ -83,11 +80,12 @@ GemmPy::GemmPy(Array args, BufferMap vmap) { if (args.size() > 15) { node->wgWait_ = args[15].as().value()->value; } - node->mbarPtr_ = args[16]; - if (node->mbarPtr_.as()) { - node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)]; - } else { - node->mbar_ = std::nullopt; + if (args.size() > 16) { + if (const auto *load = args[16].as()) { + node->mbarRegion_ = + NormalizeToBufferRegion(Downcast(args[16])); + node->mbar_ = node->mbarRegion_->buffer; + } } node->cCoords_ = Array( {args[17].as().value(), args[18].as().value()}); diff --git a/src/op/gemm_py.h b/src/op/gemm_py.h index 0678588e..2fe47be8 100644 --- a/src/op/gemm_py.h +++ b/src/op/gemm_py.h @@ -29,8 +29,8 @@ public: int strideA_, strideB_; int offsetA_, offsetB_; PrimExpr clearAccum_ = const_false(); - PrimExpr mbarPtr_; - std::optional mbar_; // mbar is optional, only used for TCGEN5MMA + BufferRegion mbarRegion_; + tir::Buffer mbar_; // mbar is optional, only used for TCGEN5MMA Array cCoords_; // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack // only will be enabled under cdna mfma instructions @@ -59,7 +59,8 @@ public: .def_ro("offsetA", &GemmPyNode::offsetA_) .def_ro("offsetB", &GemmPyNode::offsetB_) .def_ro("clearAccum", &GemmPyNode::clearAccum_) - .def_ro("mbarPtr", &GemmPyNode::mbarPtr_) + .def_ro("mbarRegion", &GemmPyNode::mbarRegion_) + .def_ro("mbar", &GemmPyNode::mbar_) .def_ro("cCoords", &GemmPyNode::cCoords_) .def_ro("kPack", 
&GemmPyNode::kPack_) .def_ro("wgWait", &GemmPyNode::wgWait_) @@ -82,7 +83,7 @@ private: class GemmPy : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmPy, TileOperator, GemmPyNode); - TVM_DLL GemmPy(Array args, BufferMap vmap); + TVM_DLL GemmPy(Array args); static const Op &Get(); }; diff --git a/src/op/gemm_sp.cc b/src/op/gemm_sp.cc index 52a119e0..df923d0e 100644 --- a/src/op/gemm_sp.cc +++ b/src/op/gemm_sp.cc @@ -14,6 +14,7 @@ #include "../target/utils.h" #include "builtin.h" #include "gemm.h" +#include "utils.h" namespace tvm { namespace tl { @@ -79,16 +80,19 @@ std::pair GemmSPWarpPolicyNode::computeWarpPartition(int M, int N, * The populated GemmSPNode is stored in the instance's internal data_ pointer. * * @param args Positional TL call arguments in the above order. - * @param vmap BufferMap mapping access pointers (from args) to Buffer objects. * * @note An ICHECK failure is raised if a provided kPack is not 1 or 2. */ -GemmSP::GemmSP(Array args, BufferMap vmap) { +GemmSP::GemmSP(Array args) { ObjectPtr node = tvm::ffi::make_object(); - node->a_ = vmap[GetVarFromAccessPtr(args[0])]; - node->e_ = vmap[GetVarFromAccessPtr(args[1])]; - node->b_ = vmap[GetVarFromAccessPtr(args[2])]; - node->c_ = vmap[GetVarFromAccessPtr(args[3])]; + node->aRegion_ = NormalizeToBufferRegion(args[0]); + node->eRegion_ = NormalizeToBufferRegion(args[1]); + node->bRegion_ = NormalizeToBufferRegion(args[2]); + node->cRegion_ = NormalizeToBufferRegion(args[3]); + node->a_ = node->aRegion_->buffer; + node->e_ = node->eRegion_->buffer; + node->b_ = node->bRegion_->buffer; + node->c_ = node->cRegion_->buffer; node->transA_ = args[4].as().value(); node->transB_ = args[5].as().value(); node->m_ = args[6].as().value()->value; diff --git a/src/op/gemm_sp.h b/src/op/gemm_sp.h index 1eb535a5..aae5b27b 100644 --- a/src/op/gemm_sp.h +++ b/src/op/gemm_sp.h @@ -53,6 +53,7 @@ public: class GemmSPNode : public TileOperatorNode { public: + BufferRegion aRegion_, 
bRegion_, cRegion_, eRegion_; tir::Buffer a_, b_, c_, e_; bool transA_, transB_; int m_, n_, k_; @@ -75,6 +76,10 @@ public: namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("policy", &GemmSPNode::policy_) + .def_ro("aRegion", &GemmSPNode::aRegion_) + .def_ro("bRegion", &GemmSPNode::bRegion_) + .def_ro("cRegion", &GemmSPNode::cRegion_) + .def_ro("eRegion", &GemmSPNode::eRegion_) .def_ro("a", &GemmSPNode::a_) .def_ro("b", &GemmSPNode::b_) .def_ro("c", &GemmSPNode::c_) @@ -96,7 +101,7 @@ private: class GemmSP : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSP, TileOperator, GemmSPNode); - TVM_DLL GemmSP(Array args, BufferMap vmap); + TVM_DLL GemmSP(Array args); static const Op &Get(); }; diff --git a/src/op/operator.cc b/src/op/operator.cc index b751559c..302ee3e3 100644 --- a/src/op/operator.cc +++ b/src/op/operator.cc @@ -24,16 +24,14 @@ using namespace tir; * * @param call The TIR Call whose operator and arguments will be used to build * the TileOperator. - * @param vmap Buffer mapping passed through to the builder to resolve buffer - * references. * @return TileOperator The constructed TileOperator, or a default (empty) * TileOperator if no builder exists. */ -TileOperator ParseOperator(Call call, BufferMap vmap) { +TileOperator ParseOperator(Call call) { auto op_map = Op::GetAttrMap("TLOpBuilder"); Op op = call->op.as().value(); if (op_map.count(op)) { - auto tile_op = op_map[op](call->args, vmap); + auto tile_op = op_map[op](call->args); ICHECK(tile_op.defined()); return tile_op; } @@ -48,14 +46,13 @@ TileOperator ParseOperator(Call call, BufferMap vmap) { * Otherwise returns a default-constructed (empty) TileOperator. * * @param stmt TIR statement to inspect; expected to be an Evaluate of a Call. - * @param vmap Mapping of buffer variables used when building the operator. * @return TileOperator Parsed operator on success, or a default (empty) * TileOperator if `stmt` is not an Evaluate(Call). 
*/ -TileOperator ParseOperator(Stmt stmt, BufferMap vmap) { +TileOperator ParseOperator(Stmt stmt) { if (stmt.as() && stmt.as()->value.as()) { auto call = stmt.as()->value.as(); - return ParseOperator(tvm::ffi::GetRef(call), vmap); + return ParseOperator(tvm::ffi::GetRef(call)); } return TileOperator(); } diff --git a/src/op/operator.h b/src/op/operator.h index 628b83b2..0d9f859a 100644 --- a/src/op/operator.h +++ b/src/op/operator.h @@ -72,11 +72,10 @@ public: Var GetVarFromAccessPtr(const PrimExpr &expr); -TileOperator ParseOperator(Call call, BufferMap vmap); -TileOperator ParseOperator(Stmt stmt, BufferMap vmap); +TileOperator ParseOperator(Call call); +TileOperator ParseOperator(Stmt stmt); -using OpBuilderFunc = - ffi::TypedFunction, BufferMap)>; +using OpBuilderFunc = ffi::TypedFunction)>; #define TIR_REGISTER_TL_OP(Entry, OpName) \ const Op &Entry::Get() { \ @@ -85,10 +84,8 @@ using OpBuilderFunc = } \ TVM_REGISTER_OP("tl." #OpName) \ .set_attr("TScriptPrinterName", #OpName) \ - .set_attr("TLOpBuilder", \ - [](Array args, BufferMap vmap) { \ - return Entry(args, vmap); \ - }) + .set_attr( \ + "TLOpBuilder", [](Array args) { return Entry(args); }) } // namespace tl } // namespace tvm diff --git a/src/op/reduce.cc b/src/op/reduce.cc index c326f5ac..caf9198a 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -14,7 +14,6 @@ #include "../op/parallel.h" #include "../target/utils.h" #include "../transform/loop_partition.h" -#include "region.h" #include "tir/transforms/ir_utils.h" #include "tvm/tir/stmt.h" #include "utils.h" @@ -28,11 +27,11 @@ using namespace tir; // MakeAccessPtrFromRegion moved to src/op/utils.{h,cc} -ReduceOp::ReduceOp(Array args, BufferMap vmap) { +ReduceOp::ReduceOp(Array args) { ObjectPtr node = tvm::ffi::make_object(); - // Accept BufferRegion/BufferLoad/tl.region for src/dst - node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap); - node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap); + // Accept BufferRegion/BufferLoad for 
src/dst + node->srcRegion_ = NormalizeToBufferRegion(args[0]); + node->dstRegion_ = NormalizeToBufferRegion(args[1]); node->src = node->srcRegion_->buffer; node->dst = node->dstRegion_->buffer; std::string reduce_type = args[2].as().value()->value; @@ -494,7 +493,7 @@ static BufferRegion ConvertBufferToBufferRegion(const Buffer &buf) { return BufferRegion(buf, ranges); } -CumSumOp::CumSumOp(Array args, BufferMap vmap) { +CumSumOp::CumSumOp(Array args) { /// CumSum constructor arguments: /// - src: input buffer /// - dst: output buffer @@ -504,8 +503,8 @@ CumSumOp::CumSumOp(Array args, BufferMap vmap) { ObjectPtr node = tvm::ffi::make_object(); // node->src = vmap[GetVarFromAccessPtr(args[0])]; // node->dst = vmap[GetVarFromAccessPtr(args[1])]; - node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap); - node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap); + node->srcRegion_ = NormalizeToBufferRegion(args[0]); + node->dstRegion_ = NormalizeToBufferRegion(args[1]); node->src = node->srcRegion_->buffer; node->dst = node->dstRegion_->buffer; node->dim = args[2].as().value()->value; diff --git a/src/op/reduce.h b/src/op/reduce.h index eb0599eb..cab3835e 100644 --- a/src/op/reduce.h +++ b/src/op/reduce.h @@ -125,7 +125,7 @@ class ReduceOp : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceOp, TileOperator, ReduceOpNode); - TVM_DLL ReduceOp(Array args, BufferMap vmap); + TVM_DLL ReduceOp(Array args); static const Op &Get(); }; @@ -163,7 +163,7 @@ class CumSumOp : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CumSumOp, TileOperator, CumSumOpNode); - TVM_DLL CumSumOp(Array args, BufferMap vmap); + TVM_DLL CumSumOp(Array args); static const Op &Get(); }; diff --git a/src/op/region.cc b/src/op/region.cc index e4984af1..2a1f2745 100644 --- a/src/op/region.cc +++ b/src/op/region.cc @@ -1,7 +1,14 @@ /*! * \file tl/op/region.cc - * \brief Define region operator. 
+ * \brief Define region operator (bridge to carry BufferRegion via Call args). * + * Notes: + * - BufferLoad/Ramp cannot represent a general PrimExpr as a vector lane + * count. Dynamic extents like (H1 - H0) cannot be encoded as + * Ramp(lanes = H1 - H0), and lowering BufferRegion to BufferLoad loses the + * explicit extent information. + * - tl.region carries both mins and extents in Call args and lets the backend + * reconstruct a BufferRegion faithfully. */ #include "region.h" @@ -11,27 +18,7 @@ namespace tvm { namespace tl { using namespace tir; -/** - * @brief Construct a RegionOp from TL operator arguments. - * - * Parses the TL `region` operator call arguments to populate the RegionOpNode: - * - Expects args[0] to be a `BufferLoad` whose `indices` are the per-dimension - * minima. - * - args[1] must be a constant integer used as the access mask. - * - args[2 + i] provides the extent for dimension `i`. - * - * The constructor validates that the number of load indices equals `args.size() - * - 2` and will abort via ICHECK on mismatch or if args[0] is not a - * `BufferLoad`. - * - * Parameters: - * - args: TL operator call arguments in the form - * [BufferLoad(min_i...), access_mask, extent_0, extent_1, ..., - * extent_{n-1}] where n = number of dimensions. - * - vmap: BufferMap passed through by the caller (not documented here as a - * generic utility). 
- */ -RegionOp::RegionOp(Array args, BufferMap vmap) { +RegionOp::RegionOp(Array args) { size_t n = args.size(); size_t ndim = n - 2; auto load = args[0].as(); @@ -39,10 +26,24 @@ RegionOp::RegionOp(Array args, BufferMap vmap) { ICHECK(load->indices.size() == ndim) << "load->indices.size() = " << load->indices << " ndim = " << ndim; Array ranges; + // Rebuild per-axis ranges from mins (BufferLoad indices) and provided extents for (size_t i = 0; i < ndim; i++) { - PrimExpr min = load->indices[i]; + PrimExpr index = load->indices[i]; PrimExpr extent = args[2 + i]; - ranges.push_back(Range::FromMinExtent(min, extent)); + if (const auto *ramp = index.as()) { + const auto *stride_imm = ramp->stride.as(); + ICHECK(stride_imm && stride_imm->value == 1) + << "RegionOp expects stride-1 Ramp for index"; + if (const auto *lanes_imm = ramp->lanes.as()) { + if (const auto *ext_imm = extent.as()) { + ICHECK_EQ(lanes_imm->value, ext_imm->value) + << "Ramp lanes and provided extent must match"; + } + } + ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); + } else { + ranges.push_back(Range::FromMinExtent(index, extent)); + } } ObjectPtr node = tvm::ffi::make_object(); node->buffer_ = load->buffer; @@ -51,26 +52,11 @@ RegionOp::RegionOp(Array args, BufferMap vmap) { data_ = std::move(node); } -/** - * @brief Create a copy of this RegionOpNode and return it as a TileOperator. - * - * @return TileOperator A new TileOperator that owns a copied RegionOpNode. - */ TileOperator RegionOpNode::Clone() const { auto op = tvm::ffi::make_object(*this); return RegionOp(op); } -/** - * @brief Check whether the region spans the entire underlying buffer. - * - * Returns true if for every dimension the range minimum is zero and the - * range extent is structurally equal to the corresponding buffer shape - * dimension. Otherwise returns false. - * - * @return true if the region covers the full buffer in all dimensions; false - * otherwise. 
- */ bool RegionOpNode::IsFullRegion() const { for (size_t i = 0; i < ranges_.size(); i++) { if (!is_zero(ranges_[i]->min)) @@ -81,39 +67,26 @@ bool RegionOpNode::IsFullRegion() const { return true; } -/** - * @brief Lower the region operator to a TIR statement. - * - * Lowers this RegionOpNode into a TIR Stmt by delegating to the operator's - * evaluation path (currently `Evaluate(0)`). - * - * @param T Lowering context (provides buffers, producers/consumers and other - * environment required for lowering). - * @param analyzer Optional arithmetic analyzer used for simplification during - * lowering. - * @return Stmt The lowered TIR statement representing this region operation. - */ Stmt RegionOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return Evaluate(0); } -/** - * @brief Infers data layout for the region operator. - * - * This operator does not provide any layout inference; the function always - * returns an empty LayoutMap regardless of the provided arguments or inference - * level. - * - * @param T Layout inference arguments (ignored). - * @param level Inference granularity level (ignored). - * @return LayoutMap Empty map indicating no inferred layouts. - */ LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T, InferLevel level) const { return {}; } -TIR_REGISTER_TL_OP(RegionOp, region) +const Op &RegionOp::Get() { + static const Op &op = Op::Get("tl.region"); + return op; +} + +TVM_REGISTER_OP("tl.region") + .set_attr("TScriptPrinterName", "region") + .set_attr("TLOpBuilder", + [](Array args) { + return RegionOp(args); + }) .set_num_inputs(-1) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); diff --git a/src/op/region.h b/src/op/region.h index e5c478bf..24399f7a 100644 --- a/src/op/region.h +++ b/src/op/region.h @@ -1,74 +1,36 @@ /*! - * \file tl/op/op.h - * \brief Tile library operations. + * \file tl/op/region.h + * \brief Tile memory region descriptor op (bridge to carry BufferRegion via + * Call args). 
* + * Why tl.region instead of passing BufferRegion directly? + * + * - While TIR can represent a BufferRegion, when a BufferRegion is passed as a + * call argument through call_intrin/FFI, the Python->C++ conversion lowers it + * to a BufferLoad(indices). To encode an interval inside indices, the FFI + * typically uses Ramp(base, stride, lanes) to represent a contiguous slice. + * - Ramp(lanes) may only be a constant or vscale*k (scalable vector). A general + * PrimExpr (e.g., H1 - H0) is not allowed as lanes, so dynamic extents would + * make the lowered BufferLoad invalid. + * - Moreover, BufferLoad only carries indices, not per-axis extents. Downstream + * tile operators (e.g., tl.copy, tl.reduce) that require both min and extent + * cannot losslessly recover dynamic extents from a BufferLoad alone. + * + * tl.region is a small transport-only op that solves this: + * - The frontend packs buffer + mins (from BufferLoad.indices) + extents into + * Call args, allowing dynamic extents to be expressed explicitly. + * - The backend (NormalizeToBufferRegion) reconstructs a BufferRegion from the + * tl.region call without losing information. + * - The op itself carries no semantics in Lower/InferLayout and is only used as + * a bridge for argument passing. */ #ifndef TVM_TL_OP_REGION_H_ #define TVM_TL_OP_REGION_H_ #include "./operator.h" -#include -#include -#include #include -/** - * Tile operator representing a memory region (buffer + ranges) used by TL - * passes. - * - * Encapsulates the target tir::Buffer, the region extents as an Array, - * and an access mask that indicates permitted or intended accesses for lowering - * and layout inference. - */ - -/** - * Lower this RegionOp into a TIR statement representing the region access. - * - * @param T Lowering-time arguments (e.g., loop/build context and value - * mappings). - * @param analyzer Arithmetic analyzer used to simplify and reason about - * expressions. 
- * @return A tir::Stmt that implements the region access/mutation described by - * this operator. - */ - -/** - * Infer the layout mapping for this region operator. - * - * Produces a LayoutMap describing how loop/axis indices map to buffer axes for - * layout-aware scheduling and subsequent operators. - * - * @param T Layout inference arguments (e.g., input layouts and shapes). - * @param level The inference detail level to use. - * @return A LayoutMap describing inferred mappings for the operator. - */ - -/** - * Return true when this RegionOp represents the full buffer region (i.e., - * ranges cover the entire buffer extent). - */ - -/** - * Create a shallow copy of this operator as a TileOperator handle. - * - * @return A TileOperator that references a cloned RegionOpNode. - */ - -/** - * Construct a RegionOp from argument expressions and a buffer map. - * - * @param args Positional expressions used to instantiate the operator - * (semantics depend on how RegionOp is invoked in TL pipelines). - * @param vmap Mapping from Buffer to replacement Buffer or buffer metadata used - * during creation. - */ - -/** - * Return the global Op registration for RegionOp. - * - * @return Reference to the registered tvm::Op describing the RegionOp. - */ namespace tvm { namespace tl { @@ -80,6 +42,12 @@ public: Array ranges_; int access_mask_; + /*! + * access_mask_ encodes the intended access type when the region is used as + * an argument to tile operators: 1=read, 2=write, 3=read-write. The mask is + * transport metadata only and does not affect lowering. + */ + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.RegionOp", RegionOpNode, TileOperatorNode); @@ -107,8 +75,13 @@ class RegionOp : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RegionOp, TileOperator, RegionOpNode); - TVM_DLL RegionOp(Array args, BufferMap vmap); - + /*! + * Build a RegionOp from call arguments: + * - args[0]: BufferLoad whose indices are per-axis minima. 
+ * - args[1]: Integer access mask (1=r, 2=w, 3=rw). + * - args[2 + i]: Extent of axis i (supports dynamic PrimExpr). + */ + TVM_DLL RegionOp(Array args); static const Op &Get(); }; diff --git a/src/op/utils.cc b/src/op/utils.cc index 59960b57..7e56ae8c 100644 --- a/src/op/utils.cc +++ b/src/op/utils.cc @@ -12,8 +12,7 @@ namespace tl { using namespace tir; -BufferRegion NormalizeToBufferRegion(const PrimExpr &arg, - const BufferMap &vmap) { +BufferRegion NormalizeToBufferRegion(const PrimExpr &arg) { // Case 1: Already a BufferRegion if (arg->IsInstance()) { return Downcast(arg); @@ -38,23 +37,15 @@ BufferRegion NormalizeToBufferRegion(const PrimExpr &arg, return BufferRegion(load->buffer, ranges); } - // Case 3: Call nodes + // Case 3: tl.region(...) — reconstruct via RegionOp (bridge) if (const auto *call = arg.as()) { - // tl.region(...) — reconstruct via RegionOp if (call->op.same_as(RegionOp::Get())) { - RegionOp region(call->args, vmap); + RegionOp region(call->args); return BufferRegion(region->GetBuffer(), region->GetRanges()); } - // builtin.tvm_access_ptr(...) — map var to Buffer and take full region - if (call->op.same_as(builtin::tvm_access_ptr())) { - Var var = Downcast(call->args[1]); - Buffer buf = vmap.at(var); - Array ranges; - for (PrimExpr extent : buf->shape) { - ranges.push_back(Range(IntImm(extent->dtype, 0), extent)); - } - return BufferRegion(buf, ranges); - } + LOG(FATAL) << "Unsupported argument for BufferRegion (expect " + "BufferLoad/BufferRegion/tl.region): " + << arg; } LOG(FATAL) << "Unsupported argument for BufferRegion: " << arg; diff --git a/src/op/utils.h b/src/op/utils.h index 9e7880ac..d386b1a5 100644 --- a/src/op/utils.h +++ b/src/op/utils.h @@ -16,10 +16,10 @@ namespace tl { using namespace tir; -// Normalize an argument (BufferRegion/BufferLoad/tl.region/tvm_access_ptr) +// Normalize an argument (BufferRegion/BufferLoad/tl.region) // to BufferRegion so ops can uniformly consume regions. 
-TVM_DLL BufferRegion NormalizeToBufferRegion(const PrimExpr &arg, - const BufferMap &vmap); +// Note: tvm_access_ptr is no longer supported here. +TVM_DLL BufferRegion NormalizeToBufferRegion(const PrimExpr &arg); // Build a tvm_access_ptr(handle) from a BufferRegion. // - If `require_2d` is true, checks buffer ndim >= 2. diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index 873f70d0..f5ccc42b 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -437,11 +437,13 @@ private: if (op->op.as()) return; - auto p = ParseOperator(tvm::ffi::GetRef(op), GetBufferMap()); + auto p = ParseOperator(tvm::ffi::GetRef(op)); if (p.defined()) { for (const auto &arg : op->args) { if (auto buffer = getBufferFromAccessPtr(arg)) { addToUseList(buffer.value()); + } else if (auto buffer = getBufferFromRegion(arg)) { + addToUseList(buffer.value()); } } // Compute thread_var_ and thread_bounds_ @@ -495,6 +497,9 @@ private: } Optional getBufferFromAccessPtr(const PrimExpr &expr) { + if (auto bl = expr.as()) { + return bl->buffer; + } auto call = expr.as(); if (!call) { return std::nullopt; @@ -514,8 +519,18 @@ private: } } return std::nullopt; - } else if (call->op.same_as(RegionOp::Get())) { - return call->args[0].as()->buffer; + } + return std::nullopt; + } + + Optional getBufferFromRegion(const PrimExpr &expr) { + if (auto call = expr.as()) { + if (call->op.same_as(RegionOp::Get())) { + if (auto bl = call->args[0].as()) { + return bl->buffer; + } + return std::nullopt; + } } return std::nullopt; } diff --git a/src/transform/layout_reducer.cc b/src/transform/layout_reducer.cc index a3c69c43..660fc6fd 100644 --- a/src/transform/layout_reducer.cc +++ b/src/transform/layout_reducer.cc @@ -277,7 +277,7 @@ private: if (op->op.same_as(Fill::Get())) { ICHECK(!op->args.empty()); if (auto arg0_call = op->args[0].as()) { - // Case 1: tl.region(...) — extract buffer var from its first arg + // tl.region(...) 
— extract buffer var from its first arg if (arg0_call.value()->op.same_as(RegionOp::Get())) { ICHECK(!arg0_call.value()->args.empty()); if (auto bl = arg0_call.value()->args[0].as()) { @@ -285,15 +285,14 @@ private: if (reducer_info_map_.count(var)) { ICHECK(inside_reducer_range_.count(var) == 0) << "T.fill on reducer must be enclosed with a " - "T.finalize_reducer " - "before next."; + "T.finalize_reducer before next."; inside_reducer_range_.Set(var, reducer_info_map_.Get(var).value()); } } } - // Case 2: builtin.tvm_access_ptr(...) — existing path - else if (arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) { + // builtin.tvm_access_ptr(...) — existing path (legacy) + if (arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) { ICHECK(arg0_call.value()->args.size() > 1); if (auto var = arg0_call.value()->args[1].as(); var && reducer_info_map_.count(var.value())) { @@ -305,10 +304,33 @@ private: var.value(), reducer_info_map_.Get(var.value()).value()); } } + } else if (auto bl = op->args[0].as()) { + Var var = bl->buffer->data; + if (reducer_info_map_.count(var)) { + ICHECK(inside_reducer_range_.count(var) == 0) + << "T.fill on reducer must be enclosed with a T.finalize_reducer " + "before next."; + inside_reducer_range_.Set(var, reducer_info_map_.Get(var).value()); + } } } else if (op->op.same_as(FinalizeReducerOp::Get())) { ICHECK(op->args.size() == 1); - auto var = GetVarFromAccessPtr(op->args[0]); + Var var; + if (auto bl = op->args[0].as()) { + var = bl->buffer->data; + } else if (auto reg_call = op->args[0].as()) { + if (reg_call.value()->op.same_as(RegionOp::Get())) { + if (auto bl2 = reg_call.value()->args[0].as()) { + var = bl2->buffer->data; + } else { + LOG(FATAL) << "tl.region expects BufferLoad as first arg"; + } + } else { + var = GetVarFromAccessPtr(op->args[0]); + } + } else { + var = GetVarFromAccessPtr(op->args[0]); + } ICHECK(inside_reducer_range_.count(var) == 1) << "T.finalize_reducer must have a pairing T.fill ahead of it, " 
"enclosing a reduction range."; diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index 4c0ccfaf..4392f319 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -606,8 +606,7 @@ private: if (call && call->op.as()) return Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); - auto tile_op = - ParseOperator(tvm::ffi::GetRef(op), buffer_data_to_buffer_); + auto tile_op = ParseOperator(tvm::ffi::GetRef(op)); if (!tile_op.defined()) return IRMutatorWithAnalyzer::VisitStmt_(op); AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) { diff --git a/testing/python/issue/test_tilelang_issue_830.py b/testing/python/issue/test_tilelang_issue_830.py index ab593712..950b8583 100644 --- a/testing/python/issue/test_tilelang_issue_830.py +++ b/testing/python/issue/test_tilelang_issue_830.py @@ -17,7 +17,15 @@ def _empty_kernel(): return empty_kernel +@tilelang.testing.requires_cuda def test_empty_kernel_lowering(): + # Ensure a valid CUDA runtime context is current on this thread for the + # target device before using driver API calls. Without this, calls like + # cuModuleLoadData can fail with CUDA_ERROR_INVALID_CONTEXT, especially + # for kernels that don't touch any device memory or streams beforehand + # (e.g., "empty" kernels) and therefore haven't triggered context + # creation implicitly. 
+ torch.cuda.set_device(0) kernel = _empty_kernel() kernel() @@ -59,7 +67,9 @@ def _empty_kernel_with_binding_variants(use_tuple_binding: bool = False): return kernel_with_tuple_kernel_binding if use_tuple_binding else kernel_with_scalar_kernel_binding +@tilelang.testing.requires_cuda def test_empty_kernel_with_binding_variants(): + torch.cuda.set_device(0) kernel = _empty_kernel_with_binding_variants() kernel() diff --git a/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/intrinsics/mfma_macro_generator.py index 84e4c21b..02c0b039 100644 --- a/tilelang/intrinsics/mfma_macro_generator.py +++ b/tilelang/intrinsics/mfma_macro_generator.py @@ -2,14 +2,15 @@ from __future__ import annotations from tilelang import tvm as tvm import tilelang.language as T from tvm import DataType -from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion +from tvm import tir +from tvm.ir import Range +from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad from tvm.runtime import convert -from .utils import ( - mfma_store_index_map,) +from .utils import (mfma_store_index_map) from typing import Literal, Callable from tilelang.utils import is_fragment -from tilelang.utils.language import to_buffer_region +from tilelang.utils.language import get_buffer_region_from_load from .mfma_layout import ( shared_16x4_to_local_64x1_layout_A, shared_4x16_to_local_64x1_layout_B, @@ -268,7 +269,7 @@ class MatrixCoreIntrinEmitter: _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False) # legalize shared buffer to region - A_region = to_buffer_region(A_shared_buf) + A_region = self._legalize_to_buffer_region(A_shared_buf) A_buf = A_region.buffer A_base0 = A_region.region[-2].min A_base1 = A_region.region[-1].min @@ -314,7 +315,7 @@ class MatrixCoreIntrinEmitter: _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True) # legalize shared buffer to region - B_region = to_buffer_region(B_shared_buf) + B_region = self._legalize_to_buffer_region(B_shared_buf) 
B_buf = B_region.buffer B_base0 = B_region.region[-2].min B_base1 = B_region.region[-1].min @@ -655,6 +656,33 @@ class MatrixCoreIntrinEmitter: forward_index_fn=forward_index, ) + @staticmethod + def _legalize_to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> BufferRegion: + """ + Convert Buffer/BufferRegion/BufferLoad to a BufferRegion. + + - Buffer -> full-region BufferRegion covering entire shape + - BufferRegion -> returned as-is + - BufferLoad -> best-effort convert via get_buffer_region_from_load; + if scalar, fall back to 1-sized ranges at given indices + """ + if isinstance(obj, BufferRegion): + return obj + if isinstance(obj, Buffer): + mins = [tir.IntImm("int32", 0) for _ in obj.shape] + ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)] + return BufferRegion(obj, ranges) + if isinstance(obj, BufferLoad): + region = get_buffer_region_from_load(obj) + if region is not None: + return region + # Fallback: scalar load -> 1-sized ranges at indices + mins = [idx for idx in obj.indices] + ones = [tir.IntImm("int32", 1) for _ in obj.indices] + ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, ones)] + return BufferRegion(obj.buffer, ranges) + raise ValueError(f"Unsupported argument type for BufferRegion: {type(obj)}") + class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index 8c546c63..aab2a49e 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -3,14 +3,16 @@ import tilelang.language as T from typing import Literal, Callable from tilelang.common import TransformKind from tvm import DataType -from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion +from tvm import tir +from tvm.ir import Range +from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad from tilelang import tvm as tvm from tvm.runtime import convert from 
.utils import ( mma_store_index_map, get_ldmatrix_offset, ) -from tilelang.utils import is_fragment, to_buffer_region +from tilelang.utils import is_fragment, get_buffer_region_from_load from tilelang.intrinsics.mma_layout import ( shared_16x8_to_mma_32x4_layout_sr_a, shared_16x8_to_mma_32x4_layout_sr_b, @@ -243,7 +245,7 @@ class TensorCoreIntrinEmitter: thread_binding = self.get_thread_binding() # legalize shared buffer to region - A_region = to_buffer_region(A_shared_buf) + A_region = self._legalize_to_buffer_region(A_shared_buf) A_buf = A_region.buffer A_base0 = A_region.region[-2].min A_base1 = A_region.region[-1].min @@ -294,7 +296,7 @@ class TensorCoreIntrinEmitter: thread_binding = self.get_thread_binding() # legalize shared buffer to region - A_region = to_buffer_region(A_shared_buf) + A_region = self._legalize_to_buffer_region(A_shared_buf) A_buf = A_region.buffer A_base0 = A_region.region[-2].min A_base1 = A_region.region[-1].min @@ -360,7 +362,7 @@ class TensorCoreIntrinEmitter: thread_binding = self.get_thread_binding() # legalize shared buffer to region - B_region = to_buffer_region(B_shared_buf) + B_region = self._legalize_to_buffer_region(B_shared_buf) B_buf = B_region.buffer B_base0 = B_region.region[-2].min B_base1 = B_region.region[-1].min @@ -397,7 +399,7 @@ class TensorCoreIntrinEmitter: thread_binding = self.get_thread_binding() # legalize shared buffer to region - B_region = to_buffer_region(B_shared_buf) + B_region = self._legalize_to_buffer_region(B_shared_buf) B_buf = B_region.buffer B_base0 = B_region.region[-2].min B_base1 = B_region.region[-1].min @@ -798,6 +800,33 @@ class TensorCoreIntrinEmitter: forward_index_fn=forward_index, ) + @staticmethod + def _legalize_to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> BufferRegion: + """ + Convert Buffer/BufferRegion/BufferLoad to a BufferRegion. 
+ + - Buffer -> full-region BufferRegion covering entire shape + - BufferRegion -> returned as-is + - BufferLoad -> best-effort convert via get_buffer_region_from_load; + if scalar, fall back to 1-sized ranges at given indices + """ + if isinstance(obj, BufferRegion): + return obj + if isinstance(obj, Buffer): + mins = [tir.IntImm("int32", 0) for _ in obj.shape] + ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)] + return BufferRegion(obj, ranges) + if isinstance(obj, BufferLoad): + region = get_buffer_region_from_load(obj) + if region is not None: + return region + # Fallback: scalar load -> 1-sized ranges at indices + mins = [idx for idx in obj.indices] + ones = [tir.IntImm("int32", 1) for _ in obj.indices] + ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, ones)] + return BufferRegion(obj.buffer, ranges) + raise ValueError(f"Unsupported argument type for BufferRegion: {type(obj)}") + class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): """ diff --git a/tilelang/intrinsics/mma_sm70_macro_generator.py b/tilelang/intrinsics/mma_sm70_macro_generator.py index b20a6a90..78248081 100644 --- a/tilelang/intrinsics/mma_sm70_macro_generator.py +++ b/tilelang/intrinsics/mma_sm70_macro_generator.py @@ -5,7 +5,7 @@ from tvm import DataType from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion from tilelang import tvm as tvm from tvm.runtime import convert -from tilelang.utils import is_fragment, to_buffer_region +from tilelang.utils import is_fragment from tilelang.intrinsics.mma_sm70_layout import ( shared_16x4_to_mma_a_32x4_layout, shared_4x16_to_mma_b_32x4_layout, @@ -207,7 +207,7 @@ class TensorCoreIntrinEmitter: mma_load_layout = mma_load_a_32x4_to_shared_16x4_layout # legalize shared buffer to region - A_region = to_buffer_region(A_shared_buf) + A_region = self._legalize_to_buffer_region(A_shared_buf) A_buf = A_region.buffer A_base0 = A_region.region[-2].min A_base1 = A_region.region[-1].min @@ -248,7 
+248,7 @@ class TensorCoreIntrinEmitter: mma_load_layout = mma_load_b_32x4_to_shared_16x4_layout_trans if b_transposed else mma_load_b_32x4_to_shared_4x16_layout # legalize shared buffer to region - B_region = to_buffer_region(B_shared_buf) + B_region = self._legalize_to_buffer_region(B_shared_buf) B_buf = B_region.buffer B_base0 = B_region.region[-2].min B_base1 = B_region.region[-1].min diff --git a/tilelang/language/atomic.py b/tilelang/language/atomic.py index 6e5fa88c..56f87473 100644 --- a/tilelang/language/atomic.py +++ b/tilelang/language/atomic.py @@ -4,10 +4,9 @@ from __future__ import annotations import tilelang.language as T -from tvm import ir, tir +from tvm import ir from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op -from tilelang.language.utils import buffer_region_to_tile_region, buffer_load_to_tile_region -from tilelang.utils.language import get_buffer_region_from_load, legalize_pairwise_extents +from tilelang.utils.language import to_buffer_region, legalize_pairwise_extents _MEMORY_ORDER_ID_MAP = { "relaxed": 0, @@ -203,24 +202,8 @@ def atomic_add(dst: Buffer, dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent) src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent) - def _to_region(data, access_type, extent): - if isinstance(data, tir.Var) and T.has_let_value(data): - data = T.get_let_value(data) - if isinstance(data, tir.Buffer): - zeros = [tir.IntImm("int32", 0) for _ in extent] - return buffer_load_to_tile_region(tir.BufferLoad(data, zeros), access_type, extent) - elif isinstance(data, tir.BufferRegion): - return buffer_region_to_tile_region(data, access_type, extent) - elif isinstance(data, tir.BufferLoad): - region = get_buffer_region_from_load(data) - if region is None: - return buffer_load_to_tile_region(data, access_type, extent) - return buffer_region_to_tile_region(region, access_type, extent) - else: - return buffer_load_to_tile_region(data, access_type, extent) - - value = 
_to_region(value, "r", src_extent) - dst = _to_region(dst, "w", dst_extent) + value = to_buffer_region(value, access_type="r", extents=src_extent) + dst = to_buffer_region(dst, access_type="w", extents=dst_extent) # Note: tile-region-based atomic operations don't support return_prev yet # This would need to be implemented in the tile runtime diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index 62de13d0..d59d73e8 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -3,11 +3,11 @@ from __future__ import annotations from typing import Literal from tilelang import language as T from tilelang.utils.language import ( + to_buffer_region, get_buffer_region_from_load, legalize_pairwise_extents, ) from tvm import ir, tir -from tilelang.language.utils import buffer_region_to_tile_region, buffer_load_to_tile_region def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, @@ -69,27 +69,9 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, # - otherwise -> error src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent) - def _to_region(data, access_type, extent): - if isinstance(data, tir.Var) and T.has_let_value(data): - data = T.get_let_value(data) - if isinstance(data, tir.Buffer): - # Restrict a raw buffer to the computed copy extent by creating - # a BufferLoad at origin and passing the extents explicitly. 
- zeros = [tir.IntImm("int32", 0) for _ in extent] - return buffer_load_to_tile_region(tir.BufferLoad(data, zeros), access_type, extent) - elif isinstance(data, tir.BufferRegion): - return buffer_region_to_tile_region(data, access_type, extent) - elif isinstance(data, tir.BufferLoad): - region = get_buffer_region_from_load(data) - if region is None: - return buffer_load_to_tile_region(data, access_type, extent) - return buffer_region_to_tile_region(region, access_type, extent) - else: - return buffer_load_to_tile_region(data, access_type, extent) - # Use legalized extents for src and dst respectively. - src = _to_region(src, "r", src_extent) - dst = _to_region(dst, "w", dst_extent) + src = to_buffer_region(src, access_type="r", extents=src_extent) + dst = to_buffer_region(dst, access_type="w", extents=dst_extent) if coalesced_width is None: coalesced_width = -1 # PrimExpr can not be None @@ -129,6 +111,7 @@ def c2d_im2col(img: tir.Buffer, eviction_policy = 0 else: eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] - return tir.call_intrin("handle", tir.op.Op.get("tl.c2d_im2col"), img.access_ptr("r"), - col.access_ptr("w"), nhw_step, c_step, kernel, stride, dilation, pad, - eviction_policy) + img_region = to_buffer_region(img, access_type="r") + col_region = to_buffer_region(col, access_type="w") + return tir.call_intrin("handle", tir.op.Op.get("tl.c2d_im2col"), img_region, col_region, + nhw_step, c_step, kernel, stride, dilation, pad, eviction_policy) diff --git a/tilelang/language/experimental/gemm_sp.py b/tilelang/language/experimental/gemm_sp.py index e966e7d6..7cc3d736 100644 --- a/tilelang/language/experimental/gemm_sp.py +++ b/tilelang/language/experimental/gemm_sp.py @@ -3,6 +3,7 @@ from __future__ import annotations from tilelang.primitives.gemm.base import GemmWarpPolicy import tilelang.language as T from tvm import tir +from tilelang.utils.language import to_buffer_region def gemm_sp( @@ -62,17 +63,18 @@ def gemm_sp( 
K_A = A_sparse.shape[0] if transpose_A else A_sparse.shape[1] K_B = B.shape[1] if transpose_B else B.shape[0] assert K_A * 2 == K_B, f"T.gemm_sp K shape check failed: K_A = {K_A}, K_B = {K_B}" - Aptr = A_sparse.access_ptr("r") - Bptr = B.access_ptr("r") - Cptr = C.access_ptr("rw") - Eptr = E.access_ptr("r") + # Build tl.region descriptors for operands + A_arg = to_buffer_region(A_sparse, access_type="r") + E_arg = to_buffer_region(E, access_type="r") + B_arg = to_buffer_region(B, access_type="r") + C_arg = to_buffer_region(C, access_type="rw") return tir.call_intrin( "handle", tir.op.Op.get("tl.gemm_sp"), - Aptr, - Eptr, - Bptr, - Cptr, + A_arg, + E_arg, + B_arg, + C_arg, transpose_A, transpose_B, M, diff --git a/tilelang/language/fill.py b/tilelang/language/fill.py index ad74720f..fbbcf1b6 100644 --- a/tilelang/language/fill.py +++ b/tilelang/language/fill.py @@ -2,12 +2,7 @@ from __future__ import annotations from tvm import tir from tilelang.language import has_let_value, get_let_value -from tilelang.utils.language import get_buffer_region_from_load -from tilelang.language.utils import ( - buffer_to_tile_region, - buffer_region_to_tile_region, - buffer_load_to_tile_region, -) +from tilelang.utils.language import get_buffer_region_from_load, to_buffer_region def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.PrimExpr): @@ -24,26 +19,21 @@ def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.Prim if isinstance(buffer, tir.Var) and has_let_value(buffer): buffer = get_let_value(buffer) - # Convert to a tl.region descriptor (PrimExpr) with write access - region_call = None + # Build tl.region as argument if isinstance(buffer, tir.Buffer): - region_call = buffer_to_tile_region(buffer, "w") + extents = list(buffer.shape) elif isinstance(buffer, tir.BufferRegion): extents = [r.extent for r in buffer.region] - region_call = buffer_region_to_tile_region(buffer, "w", extents) elif isinstance(buffer, tir.BufferLoad): region = 
get_buffer_region_from_load(buffer) if region is not None: extents = [r.extent for r in region.region] - region_call = buffer_region_to_tile_region(region, "w", extents) else: - # Fallback: treat element access as 1-extent per dim - region_call = buffer_load_to_tile_region(buffer, "w", [1] * len(buffer.indices)) + extents = [tir.IntImm("int32", 1) for _ in buffer.indices] else: - # As-is fallback (rare): pass through for downstream handling - region_call = buffer - - return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), region_call, value) + extents = [] + return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), + to_buffer_region(buffer, access_type="w", extents=extents), value) def clear(buffer: tir.Buffer | tir.Var): diff --git a/tilelang/language/gemm.py b/tilelang/language/gemm.py index 0f2e82d7..2bfd3a0c 100644 --- a/tilelang/language/gemm.py +++ b/tilelang/language/gemm.py @@ -7,10 +7,11 @@ from tilelang.utils.language import ( to_buffer_region, retrieve_shape, retrieve_stride, - retrieve_ptr, retrieve_offset, prim_expr_equal, ) +from tilelang.language.utils import ( + buffer_region_to_tile_region,) from tilelang.env import env as _env @@ -50,17 +51,17 @@ def _gemm_impl( C = legalize_arguments(C) mbar = legalize_arguments(mbar) if mbar is not None else None - # Normalize A/B/C to BufferRegion to pass into tl.gemm - A = to_buffer_region(A) - B = to_buffer_region(B) - C = to_buffer_region(C) + # Normalize A/B/C to BufferRegion for shape/stride/offset analysis + A_region = to_buffer_region(A) + B_region = to_buffer_region(B) + C_region = to_buffer_region(C) - A_shape = retrieve_shape(A) - B_shape = retrieve_shape(B) - C_shape = retrieve_shape(C) + A_shape = retrieve_shape(A_region) + B_shape = retrieve_shape(B_region) + C_shape = retrieve_shape(C_region) - A_stride = retrieve_stride(A) - B_stride = retrieve_stride(B) + A_stride = retrieve_stride(A_region) + B_stride = retrieve_stride(B_region) assert len(C_shape) == 2, "current only support C as a 2D 
tensor" assert len(A_shape) >= 2, "current only support A as a 2D or higher-order tensor" @@ -82,18 +83,22 @@ def _gemm_impl( stride_a = A_stride[-2] stride_b = B_stride[-2] - A_offset = retrieve_offset(A) - B_offset = retrieve_offset(B) + A_offset = retrieve_offset(A_region) + B_offset = retrieve_offset(B_region) assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0" assert B_offset[-2] == 0, "The offset of the first dimension of B must be 0" offset_a = A_offset[-1] offset_b = B_offset[-1] - mbarptr = retrieve_ptr(mbar, "rw") if mbar is not None else tir.const(0, "uint32") - C_coords = [r.min for r in C.region] - return tir.call_intrin("handle", tir.op.Op.get(op_key), A, B, C, transpose_A, transpose_B, M, N, - K, policy, clear_accum, stride_a, stride_b, offset_a, offset_b, k_pack, - wg_wait, mbarptr, C_coords[0], C_coords[1]) + mbar = to_buffer_region(mbar, access_type="rw") if mbar is not None else tir.const(0, "uint32") + C_coords = [r.min for r in C_region.region] + # Convert BufferRegion to tl.region calls for arguments + A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape]) + B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape]) + C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape]) + return tir.call_intrin("handle", tir.op.Op.get(op_key), A_arg, B_arg, C_arg, transpose_A, + transpose_B, M, N, K, policy, clear_accum, stride_a, stride_b, offset_a, + offset_b, k_pack, wg_wait, mbar, C_coords[0], C_coords[1]) # Public wrappers diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py index 9d84e0b2..3c4d8187 100644 --- a/tilelang/language/reduce.py +++ b/tilelang/language/reduce.py @@ -2,7 +2,7 @@ from __future__ import annotations from tvm import tir from tilelang.language import copy, macro, alloc_shared, alloc_fragment -from tilelang.language.utils import buffer_to_tile_region +from tilelang.utils.language import to_buffer_region from tilelang.utils.language 
import is_shared, is_fragment from tvm.script.ir_builder import IRBuilder @@ -51,8 +51,8 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea tir.call_intrin( "handle", tir.op.Op.get("tl.reduce"), - buffer_to_tile_region(red_frag_in, "r"), - buffer_to_tile_region(red_frag_out, "w"), + to_buffer_region(red_frag_in, access_type="r"), + to_buffer_region(red_frag_out, access_type="w"), reduce_type, dim, clear, @@ -66,8 +66,8 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea tir.call_intrin( "handle", tir.op.Op.get("tl.reduce"), - buffer_to_tile_region(red_frag_in, "r"), - buffer_to_tile_region(out, "w"), + to_buffer_region(red_frag_in, access_type="r"), + to_buffer_region(out, access_type="w"), reduce_type, dim, clear, @@ -79,8 +79,8 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea tir.call_intrin( "handle", tir.op.Op.get("tl.reduce"), - buffer_to_tile_region(buffer, "r"), - buffer_to_tile_region(red_frag_out, "w"), + to_buffer_region(buffer, access_type="r"), + to_buffer_region(red_frag_out, access_type="w"), reduce_type, dim, clear, @@ -90,8 +90,8 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea tir.call_intrin( "handle", tir.op.Op.get("tl.reduce"), - buffer_to_tile_region(buffer, "r"), - buffer_to_tile_region(out, "w"), + to_buffer_region(buffer, access_type="r"), + to_buffer_region(out, access_type="w"), reduce_type, dim, clear, @@ -246,8 +246,8 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) - tir.call_intrin( "handle", tir.op.Op.get("tl.cumsum"), - buffer_to_tile_region(cumsum_smem, "r"), - buffer_to_tile_region(cumsum_smem, "w"), + to_buffer_region(cumsum_smem, access_type="r"), + to_buffer_region(cumsum_smem, access_type="w"), dim, reverse, ) @@ -300,8 +300,8 @@ def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse return tir.call_intrin( "handle", 
tir.op.Op.get("tl.cumsum"), - buffer_to_tile_region(src, "r"), - buffer_to_tile_region(dst, "w"), + to_buffer_region(src, access_type="r"), + to_buffer_region(dst, access_type="w"), dim, reverse, ) @@ -323,7 +323,7 @@ def finalize_reducer(reducer: tir.Buffer): return tir.call_intrin( "handle", tir.op.Op.get("tl.finalize_reducer"), - reducer.access_ptr("w"), + to_buffer_region(reducer, access_type="w"), ) diff --git a/tilelang/language/utils.py b/tilelang/language/utils.py index ad8b83dd..75fea4c0 100644 --- a/tilelang/language/utils.py +++ b/tilelang/language/utils.py @@ -1,97 +1,38 @@ from tilelang import tvm as tvm from tvm import tir -from tvm.tir import PrimExpr, Buffer, BufferLoad, op +from tvm.tir import PrimExpr, BufferLoad, op from tilelang import language as T def region(buffer: BufferLoad, access_type: str, *args: PrimExpr): - """ - Create a tile memory-region descriptor for a BufferLoad. - - Maps access_type ('r', 'w', 'rw') to the numeric codes expected by the `tl.region` intrinsic - (1, 2, 3 respectively) and returns a tir.Call representing the region with the provided extents. - - Parameters: - buffer (tir.BufferLoad): The BufferLoad that identifies the underlying buffer and indices. - access_type (str): One of 'r', 'w', or 'rw' indicating read, write, or read-write access. - *args (tir.PrimExpr): Extent expressions for each region dimension. - - Returns: - tir.Call: A call to the `tl.region` intrinsic describing the memory region. - - Raises: - KeyError: If access_type is not one of 'r', 'w', or 'rw'. - """ + """Create a tl.region call for a BufferLoad and extents.""" access_type = {"r": 1, "w": 2, "rw": 3}[access_type] return T.call_intrin("handle", op.Op.get("tl.region"), buffer, access_type, *args) -def buffer_to_tile_region(buffer: Buffer, access_type: str): - """Convert a TVM buffer to a tile region descriptor. 
- - Args: - buffer (tir.Buffer): The buffer to convert - access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write - - Returns: - tir.Call: A region descriptor covering the entire buffer - """ - mins = [0 for _ in buffer.shape] - extents = [x for x in buffer.shape] - return region(T.BufferLoad(buffer, mins), access_type, *extents) - - def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: list[PrimExpr]): - """Convert a buffer load operation to a tile region descriptor. - - Args: - load (tir.BufferLoad): The buffer load operation - access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write - extents (List[tir.PrimExpr]): List of expressions defining the region size - - Returns: - tir.Call: A region descriptor for the loaded area - """ - indices = load.indices - + """Convert a BufferLoad to a tl.region call with explicit extents.""" + indices = list(load.indices) if len(indices) > len(extents): - # (f"mismatch between indices and extents for buffer load {load}: indices = {indices}, extents = {extents}, " - # f"region will be expanded in the last 2 dimensions") - new_extents = [] - for _ in range(len(indices) - len(extents)): - new_extents.append(1) - for extent in extents: - new_extents.append(extent) - extents = new_extents + extents = [tir.IntImm("int32", 1) for _ in range(len(indices) - len(extents)) + ] + list(extents) assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}" return region(load, access_type, *extents) def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str, extents: list[tir.PrimExpr]): - """Convert a buffer region to a tile region descriptor. 
- - Args: - buffer_region (tir.BufferRegion): The buffer region to convert - access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write - - Returns: - tir.Call: A region descriptor for the specified buffer region - """ - mins = [x.min for x in buffer_region.region] - region_extents = [x.extent for x in buffer_region.region] - assert len(region_extents) >= len( - extents - ), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}" - - # Clamp extents element-wise so that the produced region respects the - # requested copy/fill extent, supporting dynamic PrimExpr via tir.min. + """Clamp extents and return a tl.region call.""" + mins = [r.min for r in buffer_region.region] + region_extents = [r.extent for r in buffer_region.region] + assert len(region_extents) >= len(extents), ( + f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}" + ) clamped_extents = [ tir.min(region_extents[i], extents[i]) if i < len(extents) else region_extents[i] for i in range(len(region_extents)) ] - - return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *clamped_extents) + return region(tir.BufferLoad(buffer_region.buffer, mins), access_type, *clamped_extents) def index_to_coordinates(index, shape) -> list[PrimExpr]: diff --git a/tilelang/tileop/gemm/gemm_base.py b/tilelang/tileop/gemm/gemm_base.py index 021f59a4..581272cf 100644 --- a/tilelang/tileop/gemm/gemm_base.py +++ b/tilelang/tileop/gemm/gemm_base.py @@ -123,6 +123,10 @@ class GemmBase: def mbarptr(self) -> PrimExpr: return getattr(self.gemm_node, "mbarPtr", tvm.tir.const(0, "uint32")) + @property + def mbar(self) -> tir.Buffer: + return getattr(self.gemm_node, "mbar", None) + @property def C_coords(self): coords = getattr(self.gemm_node, "cCoords", None) diff --git a/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/tileop/gemm/gemm_tcgen05.py index 52c192e5..c2c8c1c8 100644 --- a/tilelang/tileop/gemm/gemm_tcgen05.py 
+++ b/tilelang/tileop/gemm/gemm_tcgen05.py @@ -94,9 +94,11 @@ class GemmTCGEN5(GemmBase): if self.wg_wait != -1: raise ValueError("TCGEN5MMA currently requires wg_wait == -1") - mbarptr = self.mbarptr - if mbarptr == 0: - raise ValueError("TCGEN5MMA requires a valid mbarrier pointer") + mbar = self.mbar + if mbar == 0: + raise ValueError("TCGEN5MMA requires a valid mbarrier") + + mbarptr = mbar.access_ptr("rw") C_coords = self.C_coords if len(C_coords) != 2: @@ -110,11 +112,10 @@ class GemmTCGEN5(GemmBase): B_shared = self.BRegion C_local = self.C clear_accum = self.clear_accum - mbar = self.mbarptr @T.prim_func def _gemm_ss() -> None: if thread_var // 32 == 0: - mma_emitter.tcgen05mma(A_shared, B_shared, C_local, mbar, clear_accum) + mma_emitter.tcgen05mma(A_shared, B_shared, C_local, mbarptr, clear_accum) return _Simplify(_gemm_ss, inline_let=True) diff --git a/tilelang/utils/__init__.py b/tilelang/utils/__init__.py index e13905f8..a713df8e 100644 --- a/tilelang/utils/__init__.py +++ b/tilelang/utils/__init__.py @@ -15,5 +15,6 @@ from .language import ( retrive_ptr_from_buffer_region, # noqa: F401 is_full_region, # noqa: F401 to_buffer_region, # noqa: F401 + get_buffer_region_from_load, # noqa: F401 ) from .deprecated import deprecated # noqa: F401 diff --git a/tilelang/utils/language.py b/tilelang/utils/language.py index e9fe13da..41da8ab0 100644 --- a/tilelang/utils/language.py +++ b/tilelang/utils/language.py @@ -1,10 +1,10 @@ from __future__ import annotations from tvm.tir import Buffer, BufferLoad, BufferRegion, PrimExpr +from tilelang.language.utils import region as _make_region_call from functools import reduce from tvm import IRModule, DataType from tvm.tir import PrimFunc from tvm import ir, tir - # Scope Checkers for TVM Buffers # These utility functions check the memory scope of a given TVM buffer. 
@@ -159,7 +159,8 @@ def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc: return func -def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> tir.BufferRegion | None: +def get_buffer_region_from_load(buffer_load: tir.BufferLoad, + extents: list[PrimExpr] | None = None) -> tir.BufferRegion | None: """ Get the buffer region from a buffer load. @@ -170,45 +171,71 @@ def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> tir.BufferRegion buffer, indices = buffer_load.buffer, buffer_load.indices regions = [] found_ramp: bool = False - for indice in indices: + + if extents is not None: + assert len(extents) == len(indices), "extents should have the same length as indices" + for i, indice in enumerate(indices): if isinstance(indice, tir.Ramp): + assert extents is None, "extents should be provided for BufferLoad with Ramp indices" regions.append(ir.Range.from_min_extent(indice.base, indice.lanes)) found_ramp = True elif isinstance(indice, tir.PrimExpr): - regions.append(ir.Range.from_min_extent(indice, 1)) + if extents is not None: + regions.append(ir.Range.from_min_extent(indice, extents[i])) + found_ramp = True + else: + regions.append(ir.Range.from_min_extent(indice, 1)) else: - raise ValueError("Unsupported type: ", type(indice)) + raise ValueError(f"Unsupported type: {type(indice)} for index {i}") if found_ramp: return tir.BufferRegion(buffer, regions) else: return None -def to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> BufferRegion: +def to_buffer_region(obj: Buffer | BufferLoad | BufferRegion | tir.Var, + access_type: str = "rw", + extents: list[PrimExpr] | None = None) -> PrimExpr | BufferRegion: """ - Convert Buffer/BufferRegion/BufferLoad to a BufferRegion. + Convert to/from the tl.region representation. 
- - Buffer -> full-region BufferRegion covering entire shape - - BufferRegion -> returned as-is - - BufferLoad -> best-effort convert via get_buffer_region_from_load; - if scalar, fall back to 1-sized ranges at given indices + - Buffer/BufferLoad/BufferRegion -> returns a tl.region call (PrimExpr) + - tl.region Call -> returns the decoded BufferRegion for analysis """ + from tilelang.language.frame import has_let_value, get_let_value + if isinstance(obj, tir.Var) and has_let_value(obj): + obj = get_let_value(obj) + # Encode into tl.region call (when extents is provided), otherwise return BufferRegion for analysis if isinstance(obj, tir.BufferRegion): - return obj + if extents is None: + return obj + mins = [r.min for r in obj.region] + exts = [r.extent for r in obj.region] + assert len(extents) == len(exts) + exts = [tir.min(exts[i], extents[i]) for i in range(len(exts))] + return _make_region_call(tir.BufferLoad(obj.buffer, mins), access_type, *exts) if isinstance(obj, tir.Buffer): mins = [tir.IntImm("int32", 0) for _ in obj.shape] - ranges = [ir.Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)] - return tir.BufferRegion(obj, ranges) + if extents is None: + ranges = [ir.Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)] + return tir.BufferRegion(obj, ranges) + exts = list(extents) + return _make_region_call(tir.BufferLoad(obj, mins), access_type, *exts) if isinstance(obj, tir.BufferLoad): - region = get_buffer_region_from_load(obj) - if region is not None: - return region - # Fallback: scalar load -> 1-sized ranges at indices - mins = [idx for idx in obj.indices] - ones = [tir.IntImm("int32", 1) for _ in obj.indices] - ranges = [ir.Range.from_min_extent(m, e) for m, e in zip(mins, ones)] - return tir.BufferRegion(obj.buffer, ranges) - raise ValueError(f"Unsupported argument type for BufferRegion: {type(obj)}") + if extents is None: + region = get_buffer_region_from_load(obj) + if region is not None: + return region + mins = [idx for idx in 
obj.indices] + ones = [tir.IntImm("int32", 1) for _ in obj.indices] + ranges = [ir.Range.from_min_extent(m, e) for m, e in zip(mins, ones)] + return tir.BufferRegion(obj.buffer, ranges) + exts = list(extents) + if len(obj.indices) > len(exts): + exts = [tir.IntImm("int32", 1) for _ in range(len(obj.indices) - len(exts))] + exts + assert len(obj.indices) == len(exts) + return _make_region_call(obj, access_type, *exts) + raise ValueError(f"Unsupported argument type for to_buffer_region: {type(obj)}") def retrieve_shape(obj: Buffer | BufferRegion | BufferLoad) -> list: -- GitLab From f0c721a467ed0e535b160e3f7e76709faa77cf57 Mon Sep 17 00:00:00 2001 From: Yunqian Fan Date: Wed, 26 Nov 2025 15:44:00 +0800 Subject: [PATCH 051/139] [Enhancement] add more dtype and fix mma.ws for fp16 for tcgen05 (#1327) * feat: add fp8 variants; add placeholder for fp6/fp4 in meta support ld with pack for fp32 dtype add dump add tempalte expand remove unused dtype and change to rebased apis * fix: when atom-m!=128, enable_ws * fix: typo in tcgen05 meta; dispatch in gemm sm100 --- .../example_tilelang_gemm_fp8_sm100.py | 126 +++ src/op/copy.cc | 14 +- src/op/gemm_py.cc | 2 + src/op/tcgen5_meta.h | 38 +- src/tl_templates/cuda/copy_sm100.h | 35 +- src/tl_templates/cuda/gemm_sm100.h | 82 +- src/tl_templates/cuda/tcgen_05_ld.h | 755 +++++++++++++++++- tilelang/intrinsics/mma_macro_generator.py | 3 + .../intrinsics/tcgen05_macro_generator.py | 9 +- tilelang/jit/adapter/wrapper.py | 1 + tilelang/tileop/gemm/gemm_tcgen05.py | 5 +- 11 files changed, 980 insertions(+), 90 deletions(-) create mode 100644 examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py new file mode 100644 index 00000000..4628a997 --- /dev/null +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py @@ -0,0 +1,126 @@ +import torch +import tilelang +import tilelang.language as T +from tilelang.utils.tensor 
import map_torch_type + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm_v2( + A_shared, + B_shared, + C_tmem, + trans_A, + trans_B, + mbar=mbar, + wg_wait=-1, + clear_accum=(k == 0), + ) + T.mbarrier_wait_parity(mbar, k % 2) + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +M, N, K = 4096, 4096, 8192 +block_M, block_N, block_K = 64, 256, 32 +trans_A, trans_B = False, True +num_stages = 2 +threads = 256 +for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]: + for tvm_acc_dtype in ["float16", "float32"]: # , torch.float16]: + torch_fp8_dtype = map_torch_type(tvm_fp8_dtype) + torch_acc_dtype = map_torch_type(tvm_acc_dtype) + print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}") + in_dtype, out_dtype, 
accum_dtype = tvm_fp8_dtype, tvm_acc_dtype, tvm_acc_dtype + + func = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + ) + jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT: True, + }, + ) + # jit_kernel.export_ptx("./dump.ptx") + # jit_kernel.export_sources("./dump.cu") + + a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype) + b = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype) + + c = jit_kernel(a, b) + ref_c = (a.to(torch.half) @ b.T.to(torch.half)).float() + c = c.float() + diff = calc_diff(c, ref_c) + # assert diff < 1e-3, f"{diff}" + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] diff = {diff}") + + profiler = jit_kernel.get_profiler() + latency = profiler.do_bench() + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms") + print( + f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS" + ) diff --git a/src/op/copy.cc b/src/op/copy.cc index 9b93fea1..b0cac131 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -1118,6 +1118,11 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, bool is_ld = false; // tcgen05.ld (tensor memory -> register) bool is_st = false; // tcgen05.st (register -> tensor memory) bool is_cp = false; // tcgen05.cp (shared memory -> tensor memory) + bool src_needs_pack = + 16 == src->dtype.bits(); // if needs .pack::16b when is_ld + bool dst_needs_unpack = + 16 == dst->dtype.bits(); // if needs .unpack::16b when is_st + if (src.scope() == "shared.tmem" && dst.scope() == "local.fragment") { is_ld = true; } else if (src.scope() == "local.fragment" && dst.scope() == "shared.tmem") { @@ -1125,9 +1130,8 @@ Stmt CopyNode::LowerTmemCopy(const 
LowerArgs &T, } else if (src.scope() == "shared.dyn" && dst.scope() == "shared.tmem") { is_cp = true; } else { - ICHECK(0) << "Unsupported tensor memory copy: " - << "src scope = " << src.scope() - << ", dst scope = " << dst.scope(); + ICHECK(0) << "Unsupported tensor memory copy: " << "src scope = " + << src.scope() << ", dst scope = " << dst.scope(); } // Currently tcgen05.cp is not supported // TODO (mzw) Support tcgen05.cp @@ -1247,8 +1251,10 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, : relative_wg_idx * (num_chunks_each_wg * meta.width); have_succeeded = true; Array args; + const char *bool_str = src_needs_pack ? "true" : "false"; args.push_back(StringImm(meta.intrinsics_name + "<" + - std::to_string(num_chunks_each_wg) + ">")); + std::to_string(num_chunks_each_wg) + ", " + + bool_str + ">")); args.push_back( BufferLoad(src, {(int)logical_row_min, (int)logical_col_min})); // Will be translated later diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index 511a4283..aa6c0282 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -344,6 +344,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { result.push_back(Integer(meta.atom_m)); result.push_back(Integer(meta.atom_n)); result.push_back(Integer(meta.atom_k)); + result.push_back(Integer(meta.enable_ws)); + result.push_back(Integer(meta.enable_2cta)); } return result; }); diff --git a/src/op/tcgen5_meta.h b/src/op/tcgen5_meta.h index bb63c8dc..3d994bf5 100644 --- a/src/op/tcgen5_meta.h +++ b/src/op/tcgen5_meta.h @@ -15,16 +15,19 @@ using runtime::DataType; struct TCGEN5MMAMeta { int atom_m, atom_n, atom_k; + bool enable_ws, enable_2cta; }; inline std::pair GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { // TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA. 
#define FAIL \ - return { false, TCGEN5MMAMeta{0, 0, 0} } -#define SUCCESS(atom_m, atom_n, atom_k) \ return { \ - true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \ + false, TCGEN5MMAMeta { 0, 0, 0, false, false } \ + } +#define SUCCESS(atom_m, atom_n, atom_k, use_ws, use_2cta) \ + return { \ + true, TCGEN5MMAMeta { atom_m, atom_n, atom_k, use_ws, use_2cta } \ } std::vector ws_valid_atom_ns = {256, 128, 64}; if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) && @@ -34,39 +37,52 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { if (M % 128 == 0) { for (int atom_n = 256; atom_n >= 16; atom_n -= 16) if (N % atom_n == 0) - SUCCESS(128, atom_n, 16); + SUCCESS(128, atom_n, 16, false, false); FAIL; } else if (M % 64 == 0) { for (int atom_n : ws_valid_atom_ns) if (N % atom_n == 0) - SUCCESS(64, atom_n, 16); + SUCCESS(64, atom_n, 16, true, false); FAIL; } else if (M % 32 == 0) { for (int atom_n : ws_valid_atom_ns) if (N % atom_n == 0) - SUCCESS(32, atom_n, 16); + SUCCESS(32, atom_n, 16, true, false); FAIL; } else { FAIL; } - } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) && - (c_dtype.is_float() && c_dtype.bits() == 32)) { + } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e4m3() || + ab_dtype.is_float8_e5m2() || ab_dtype.is_float8_e5m2fnuz() || + ab_dtype.is_float6_e2m3fn() || ab_dtype.is_float6_e3m2fn() || + ab_dtype.is_float4_e2m1fn()) && + ((c_dtype.is_float() && c_dtype.bits() == 32) || + (c_dtype.is_float16() && c_dtype.bits() == 16))) { if (K % 32 != 0) FAIL; if (M % 128 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 32, true, false); for (int atom_n = 256; atom_n >= 16; atom_n -= 16) if (N % atom_n == 0) - SUCCESS(128, atom_n, 32); + SUCCESS(128, atom_n, 32, false, true); + for (int atom_n = 256; atom_n >= 8; atom_n -= 8) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 32, false, false); FAIL; } else if (M % 64 == 0) { for (int atom_n : ws_valid_atom_ns) if 
(N % atom_n == 0) - SUCCESS(64, atom_n, 32); + SUCCESS(64, atom_n, 32, true, false); + for (int atom_n = 256; atom_n >= 8; atom_n -= 8) + if (N % atom_n == 0) + SUCCESS(64, atom_n, 32, false, false); FAIL; } else if (M % 32 == 0) { for (int atom_n : ws_valid_atom_ns) if (N % atom_n == 0) - SUCCESS(32, atom_n, 32); + SUCCESS(32, atom_n, 32, true, false); FAIL; } else { FAIL; diff --git a/src/tl_templates/cuda/copy_sm100.h b/src/tl_templates/cuda/copy_sm100.h index c4047c34..aa898bcc 100644 --- a/src/tl_templates/cuda/copy_sm100.h +++ b/src/tl_templates/cuda/copy_sm100.h @@ -51,6 +51,21 @@ __device__ __forceinline__ void st_global_256(fp8_e4_32_t *ptr, : : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); } +__device__ __forceinline__ ulonglong4 ld_global_256(const fp8_e5_32_t *ptr) { + ulonglong4 ret; + asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];" + : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w) + : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ void st_global_256(fp8_e5_32_t *ptr, + fp8_e5_32_t &val8) { + ulonglong4 &val = *((ulonglong4 *)&val8); + asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};" + : + : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); +} __device__ __forceinline__ unsigned long long pack_bfloat16x4(const bfloat16_t x, const bfloat16_t y, const bfloat16_t z, @@ -95,38 +110,38 @@ __device__ __forceinline__ void tcgen05_ld_core(uint32_t const &tmem_start_col, } } -template +template __device__ __forceinline__ void tcgen05_ld_32dp32bNx(uint32_t const &tmem_start_col, uint32_t const &tmem_col_offset, dst_t *dst_ptr) { - tcgen05_ld_core(tmem_start_col + tmem_col_offset, - dst_ptr); + tcgen05_ld_core, 7, N>( + tmem_start_col + tmem_col_offset, dst_ptr); tl::fence_view_async_tmem_load(); } -template +template __device__ __forceinline__ void tcgen05_ld_32dp64bNx(uint32_t const &tmem_start_col, uint32_t const &tmem_col_offset, dst_t *dst_ptr) { - tcgen05_ld_core(tmem_start_col + tmem_col_offset, - 
dst_ptr); + tcgen05_ld_core, 7, N>( + tmem_start_col + tmem_col_offset, dst_ptr); tl::fence_view_async_tmem_load(); } -template +template __device__ __forceinline__ void tcgen05_ld_32dp128bNx(uint32_t const &tmem_start_col, uint32_t const &tmem_col_offset, dst_t *dst_ptr) { - tcgen05_ld_core( + tcgen05_ld_core, 6, N>( tmem_start_col + tmem_col_offset, dst_ptr); tl::fence_view_async_tmem_load(); } -template +template __device__ __forceinline__ void tcgen05_ld_32dp256bNx(uint32_t const &tmem_start_col, uint32_t const &tmem_col_offset, dst_t *dst_ptr) { - tcgen05_ld_core( + tcgen05_ld_core, 5, N>( tmem_start_col + tmem_col_offset, dst_ptr); tl::fence_view_async_tmem_load(); } diff --git a/src/tl_templates/cuda/gemm_sm100.h b/src/tl_templates/cuda/gemm_sm100.h index 856d37dd..84e22f24 100644 --- a/src/tl_templates/cuda/gemm_sm100.h +++ b/src/tl_templates/cuda/gemm_sm100.h @@ -243,47 +243,99 @@ struct DispatchInstruction -struct DispatchInstruction> { - using MMA = MMA_Traits, - Int, integral_constant, - integral_constant, - integral_constant, - integral_constant>; + using MMA = + MMA_Traits, Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; }; template -struct DispatchInstruction> { using MMA = - MMA_Traits, - Int, integral_constant, + MMA_Traits, Int, integral_constant, integral_constant, integral_constant, integral_constant>; }; template -struct DispatchInstruction> { - using MMA = MMA_Traits, - Int, integral_constant, + using MMA = MMA_Traits, Int, + integral_constant, integral_constant, integral_constant, integral_constant>; }; +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +template +struct DispatchInstruction> { + using MMA = + MMA_Traits, Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; template -struct DispatchInstruction> { using MMA = - MMA_Traits, - Int, 
integral_constant, + MMA_Traits, Int, integral_constant, integral_constant, integral_constant, integral_constant>; }; +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + template diff --git a/src/tl_templates/cuda/tcgen_05_ld.h b/src/tl_templates/cuda/tcgen_05_ld.h index b2eb2f81..9e5e3420 100644 --- a/src/tl_templates/cuda/tcgen_05_ld.h +++ b/src/tl_templates/cuda/tcgen_05_ld.h @@ -10,7 +10,9 @@ namespace tl { // 32 data path lanes, 32-bit pattern, repeated N times -class tmem_ld_32dp32bNx { +template class tmem_ld_32dp32bNx; + +template <> class tmem_ld_32dp32bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { @@ -180,9 +182,180 @@ public: } } }; +template <> class tmem_ld_32dp32bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, + "N must be a power of 2 and lies between 1 ~ 128"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + 
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.pack::16b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.pack::16b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.pack::16b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, 
%62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.pack::16b.x128.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, 
%106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), 
"=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; // 16 data path lanes, 64-bit pattern, repeated N times -class tmem_ld_16dp64bNx { +template class tmem_ld_16dp64bNx; +template <> class tmem_ld_16dp64bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { @@ -352,39 +525,43 @@ public: } } }; - -// 16 data path lanes, 128-bit pattern, repeated N times -class tmem_ld_16dp128bNx { +template <> class tmem_ld_16dp64bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 64, - "N must be a power of 2 and lies between 1 ~ 64"); + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, + "N must be a power of 2 and lies between 1 ~ 128"); if constexpr (N == 1) { - asm volatile("tcgen05.ld.sync.aligned.16x128b.x1.b32" + asm volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x2.b32" "{%0, %1}," "[%2];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) : "r"(src_addr)); - } else if constexpr (N == 2) { - asm volatile("tcgen05.ld.sync.aligned.16x128b.x2.b32" + } else if constexpr (N == 4) { + asm 
volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x4.b32" "{%0, %1, %2, %3}," "[%4];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), "=r"(dst_ptr[3]) : "r"(src_addr)); - } else if constexpr (N == 4) { - asm volatile("tcgen05.ld.sync.aligned.16x128b.x4.b32" + } else if constexpr (N == 8) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x8.b32" "{%0, %1, %2, %3, %4, %5, %6, %7}," "[%8];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) : "r"(src_addr)); - } else if constexpr (N == 8) { + } else if constexpr (N == 16) { asm volatile( - "tcgen05.ld.sync.aligned.16x128b.x8.b32" + "tcgen05.ld.sync.aligned.16x64b.pack::16b.x16.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " "%14, %15}," "[%16];\n" @@ -395,9 +572,9 @@ public: "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), "=r"(dst_ptr[15]) : "r"(src_addr)); - } else if constexpr (N == 16) { + } else if constexpr (N == 32) { asm volatile( - "tcgen05.ld.sync.aligned.16x128b.x16.b32" + "tcgen05.ld.sync.aligned.16x64b.pack::16b.x32.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " "%26, %27, %28, %29, %30, %31}," @@ -414,9 +591,9 @@ public: "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) : "r"(src_addr)); - } else if constexpr (N == 32) { + } else if constexpr (N == 64) { asm volatile( - "tcgen05.ld.sync.aligned.16x128b.x32.b32" + "tcgen05.ld.sync.aligned.16x64b.pack::16b.x64.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " "%28, " @@ -449,9 +626,9 @@ public: "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), "=r"(dst_ptr[63]) : "r"(src_addr)); - } else if constexpr (N == 64) { + } else if constexpr (N == 128) { asm volatile( - 
"tcgen05.ld.sync.aligned.16x128b.x64.b32" + "tcgen05.ld.sync.aligned.16x64b.pack::16b.x128.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " "%28, " @@ -519,32 +696,39 @@ public: } }; -// 16 data path lanes, 256-bit pattern, repeated N times -class tmem_ld_16dp256bNx { +// 16 data path lanes, 128-bit pattern, repeated N times +template class tmem_ld_16dp128bNx; +template <> class tmem_ld_16dp128bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, - "N must be a power of 2 and lies between 1 ~ 32"); + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 64, + "N must be a power of 2 and lies between 1 ~ 64"); if constexpr (N == 1) { - asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32" + asm volatile("tcgen05.ld.sync.aligned.16x128b.x1.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.x2.b32" "{%0, %1, %2, %3}," "[%4];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), "=r"(dst_ptr[3]) : "r"(src_addr)); - } else if constexpr (N == 2) { - asm volatile("tcgen05.ld.sync.aligned.16x256b.x2.b32" + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.x4.b32" "{%0, %1, %2, %3, %4, %5, %6, %7}," "[%8];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) : "r"(src_addr)); - } else if constexpr (N == 4) { + } else if constexpr (N == 8) { asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x4.b32" + "tcgen05.ld.sync.aligned.16x128b.x8.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " "%14, %15}," "[%16];\n" @@ -555,9 +739,9 @@ public: "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), "=r"(dst_ptr[15]) : "r"(src_addr)); - 
} else if constexpr (N == 8) { + } else if constexpr (N == 16) { asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x8.b32" + "tcgen05.ld.sync.aligned.16x128b.x16.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " "%26, %27, %28, %29, %30, %31}," @@ -574,9 +758,9 @@ public: "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) : "r"(src_addr)); - } else if constexpr (N == 16) { + } else if constexpr (N == 32) { asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x16.b32" + "tcgen05.ld.sync.aligned.16x128b.x32.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " "%28, " @@ -609,9 +793,492 @@ public: "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), "=r"(dst_ptr[63]) : "r"(src_addr)); - } else if constexpr (N == 32) { + } else if constexpr (N == 64) { asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x32.b32" + "tcgen05.ld.sync.aligned.16x128b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + 
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), 
"=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; +template <> class tmem_ld_16dp128bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 64, + "N must be a power of 2 and lies between 1 ~ 64"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.pack::16b.x1.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.pack::16b.x2.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.pack::16b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.pack::16b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + 
"tcgen05.ld.sync.aligned.16x128b.pack::16b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.pack::16b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + 
"=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.pack::16b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), 
"=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + 
"=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; + +// 16 data path lanes, 256-bit pattern, repeated N times +template class tmem_ld_16dp256bNx; +template <> class tmem_ld_16dp256bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, + "N must be a power of 2 and lies between 1 ~ 32"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x2.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + 
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + 
"=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + 
"=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; +template <> class tmem_ld_16dp256bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, + "N must be a power of 2 and lies between 1 ~ 32"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.pack::16b.x1.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : 
"=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.pack::16b.x2.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.pack::16b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.pack::16b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.pack::16b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, 
%8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.pack::16b.x32.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " "%28, " @@ -681,32 +1348,32 @@ public: // 32 data path lanes, 64-bit pattern, repeated N times // (conducted with 2x16dp64bNx) -class tmem_ld_32dp64bNx { 
+template class tmem_ld_32dp64bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - tmem_ld_16dp64bNx::copy(src_addr, dst_ptr); - tmem_ld_16dp64bNx::copy(src_addr + (16 << 16), dst_ptr + N); + tmem_ld_16dp64bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp64bNx::copy(src_addr + (16 << 16), dst_ptr + N); } }; // 32 data path lanes, 128-bit pattern, repeated N times -class tmem_ld_32dp128bNx { +template class tmem_ld_32dp128bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - tmem_ld_16dp128bNx::copy(src_addr, dst_ptr); - tmem_ld_16dp128bNx::copy(src_addr + (16 << 16), dst_ptr + N * 2); + tmem_ld_16dp128bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp128bNx::copy(src_addr + (16 << 16), dst_ptr + N * 2); } }; // 32 data path lanes, 256-bit pattern, repeated N times -class tmem_ld_32dp256bNx { +template class tmem_ld_32dp256bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - tmem_ld_16dp256bNx::copy(src_addr, dst_ptr); - tmem_ld_16dp256bNx::copy(src_addr + (16 << 16), dst_ptr + N * 4); + tmem_ld_16dp256bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp256bNx::copy(src_addr + (16 << 16), dst_ptr + N * 4); } }; diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index aab2a49e..6e49b058 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -47,7 +47,10 @@ class TensorCoreIntrinEmitter: "int8": "int8", "int32": "int32", "float8_e4m3": "e4m3", + "float8_e4m3fn": "e4m3", + "float8_e4m3fnuz": "e4m3", "float8_e5m2": "e5m2", + "float8_e5m2fnuz": "e5m2", } # Represent the thread binding in the form of (tx, warp_n, warp_m) diff --git a/tilelang/intrinsics/tcgen05_macro_generator.py b/tilelang/intrinsics/tcgen05_macro_generator.py index e53ff7cb..966f4dc4 100644 --- a/tilelang/intrinsics/tcgen05_macro_generator.py +++ 
b/tilelang/intrinsics/tcgen05_macro_generator.py @@ -169,12 +169,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): accum_dtype_in_bits = DataType(accum_dtype).bits meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim) - if len(meta) != 3: + if len(meta) != 5: raise ValueError( f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, " f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}") - atom_m, atom_n, atom_k = (int(x) for x in meta) - enable_ws = atom_m != 128 + atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta) # by default, we utilize non-swizzle layout offset a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * @@ -382,10 +381,10 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): k = int(self.chunk) meta = self.get_tcgen5_mma_meta(m, n, k) - if len(meta) != 3: + if len(meta) != 5: raise ValueError(f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, " f"A dtype={self.a_dtype}, accum dtype={self.accum_dtype}") - atom_m, atom_n, _ = (int(x) for x in meta) + atom_m, atom_n, _, _, _ = (int(x) for x in meta) if m % atom_m != 0 or n % atom_n != 0: raise ValueError( diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py index 48b8e908..75607976 100644 --- a/tilelang/jit/adapter/wrapper.py +++ b/tilelang/jit/adapter/wrapper.py @@ -144,6 +144,7 @@ class TLCUDASourceWrapper: "float16": "half_t", "bfloat16": "bfloat16_t", "float8_e4m3": "fp8_e4_t", + "float8_e4m3fn": "fp8_e4_t", "float8_e5m2": "fp8_e5_t", "float64": "double", "int64": "int64_t", diff --git a/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/tileop/gemm/gemm_tcgen05.py index c2c8c1c8..76f919e0 100644 --- a/tilelang/tileop/gemm/gemm_tcgen05.py +++ b/tilelang/tileop/gemm/gemm_tcgen05.py @@ -85,6 +85,9 @@ class GemmTCGEN5(GemmBase): raise ValueError(f"TCGEN5MMA currently only supports gemm_ss, got " f"A scope {self.A.scope()}, B scope {self.B.scope()}") + atom_m, atom_n, atom_k, 
enable_ws, enable_2cta = mma_emitter.get_tcgen5_mma_meta( + self.M, self.N, self.K) + if self.A.scope() not in {"shared", "shared.dyn", "shared.tmem"}: raise ValueError(f"Unsupported A scope for TCGEN5MMA: {self.A.scope()}") if self.B.scope() not in {"shared", "shared.dyn"}: @@ -105,7 +108,7 @@ class GemmTCGEN5(GemmBase): raise ValueError("TCGEN5MMA expects 2D coordinates for C buffer access") accum_dtype = str(self.C.dtype) - if accum_dtype != "float32": + if accum_dtype not in ["float32", 'float16']: raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}") A_shared = self.ARegion -- GitLab From 17718bec9d4b10cc7360ee333c560ede675de66a Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Wed, 26 Nov 2025 19:16:29 +0800 Subject: [PATCH 052/139] [Refactor] Enhance CopyNode's IterVar Creation and Range Handling (#1346) * [Refactor] Enhance CopyNode's IterVar Creation and Range Handling This commit refines the `MakeIterVars` method in `CopyNode` to select base ranges based on memory scope levels, ensuring that the chosen ranges are not smaller than the original source ranges. Additionally, it updates the Python `copy` function to clarify range handling, including broadcasting logic and extent alignment. These changes improve the robustness and clarity of the copy operation's implementation. * test fix --- src/op/copy.cc | 88 ++++++++++++++++++++++++++++++++-- tilelang/language/copy.py | 25 +++++++--- tilelang/language/customize.py | 5 +- 3 files changed, 105 insertions(+), 13 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index b0cac131..1bd548bc 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -179,15 +179,95 @@ TileOperator CopyNode::Clone() const { * copy operation. */ Array CopyNode::MakeIterVars() const { + // Choose the range set from the lowest-level memory scope between src and + // dst. 
Scope levels: global < shared/shared.dyn/shared.tmem < local.fragment + // (fragment) + auto scope_level = [](const Buffer &b) -> int { + String s = b.scope(); + if (s == "local.fragment" || s == "local") + return 2; + if (s == "shared" || s == "shared.dyn" || s == "shared.tmem") + return 1; + // default to global level for unknown scopes + return 0; + }; + + int src_level = scope_level(src); + int dst_level = scope_level(dst); + bool base_is_src = (src_level >= dst_level); + const Array &base_ranges = base_is_src ? src_range : dst_range; + + // Sanity check: when switching away from the original (src_range), + // ensure the chosen base ranges are not provably smaller than the original + // per dimension. This guards against generating undersized loop domains. + // Improved logic: use two pointers to traverse both base_ranges and + // src_range, skipping dimensions with extent == 1. The number of non-1 + // extents must match. + arith::Analyzer analyzer; + + size_t base_dim = 0, src_dim = 0; + while (base_dim < base_ranges.size() && src_dim < src_range.size()) { + // Skip base extents that are 1 + while (base_dim < base_ranges.size() && + is_one(base_ranges[base_dim]->extent)) { + ++base_dim; + } + // Skip src extents that are 1 + while (src_dim < src_range.size() && is_one(src_range[src_dim]->extent)) { + ++src_dim; + } + // Both indices now at non-1, or at end + if (base_dim < base_ranges.size() && src_dim < src_range.size()) { + PrimExpr base_ext = base_ranges[base_dim]->extent; + PrimExpr src_ext = src_range[src_dim]->extent; + // Only fail if base extent is provably smaller than src extent + if (analyzer.CanProve(base_ext < src_ext)) { + std::ostringstream oss; + oss << "Selected loop range is smaller than original src range at " + "matched non-1 dimension: " + << "base(extent=" << base_ext + << ", scope=" << (base_is_src ? 
src.scope() : dst.scope()) + << ", min=" << base_ranges[base_dim]->min + << ", base_dim=" << base_dim << ") < src(extent=" << src_ext + << ", min=" << src_range[src_dim]->min << ", src_dim=" << src_dim + << ", scope=" << src.scope() << ") for src=" << src->name + << ", dst=" << dst->name << "\n"; + oss << "src buffer: " << src->name << ", scope=" << src.scope() << "\n"; + oss << "dst buffer: " << dst->name << ", scope=" << dst.scope() << "\n"; + oss << "base_ranges[" << base_dim + << "]: min=" << base_ranges[base_dim]->min + << ", extent=" << base_ext << "\n"; + oss << "src_ranges[" << src_dim << "]: min=" << src_range[src_dim]->min + << ", extent=" << src_ext << "\n"; + LOG(FATAL) << oss.str(); + } + ++base_dim; + ++src_dim; + } + } + + // Any remaining unmatched dimensions in either range must all have extent == + // 1 + while (base_dim < base_ranges.size()) { + ICHECK(is_one(base_ranges[base_dim]->extent)) + << "base_ranges has extra non-1 extent at dim " << base_dim; + ++base_dim; + } + while (src_dim < src_range.size()) { + ICHECK(is_one(src_range[src_dim]->extent)) + << "src_range has extra non-1 extent at dim " << src_dim; + ++src_dim; + } + Array loop_vars; size_t idx = 0; - for (size_t i = 0; i < src_range.size(); i++) { - if (is_one(src_range[i]->extent)) + for (size_t i = 0; i < base_ranges.size(); i++) { + if (is_one(base_ranges[i]->extent)) continue; - Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype); + Var var = Var(std::string{char('i' + idx)}, base_ranges[i]->extent->dtype); idx++; loop_vars.push_back( - {Range(0, src_range[i]->extent), var, IterVarType::kDataPar}); + {Range(0, base_ranges[i]->extent), var, IterVarType::kDataPar}); } return loop_vars; } diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index d59d73e8..965919fd 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -27,6 +27,22 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, Returns: tir.Call: A handle to 
the copy operation + + Range handling notes: + - Accepts `Buffer`/`BufferRegion`/`BufferLoad` on either side. Extents are + derived as follows: `Buffer -> shape`, `BufferRegion -> [r.extent]`, + `BufferLoad -> extents from its inferred/encoded region`. + - If both `src` and `dst` are scalar `BufferLoad` without region extents, + lowers to a direct store: `dst[...] = src`. + - If one side is missing extents, it is treated as all-ones with the other + side's rank to enable broadcasting. + - Extents are right-aligned and legalized via `legalize_pairwise_extents`: + per tail-dimension, equal keeps as-is, a `1` broadcasts to the other, + otherwise a conservative `tir.max` is used to remain safe for dynamic + shapes. + - The finalized extents are encoded with `tl.region` via `to_buffer_region` + and passed through to the backend; low-level loop construction and any + scope-specific decisions happen during lowering. """ if isinstance(src, tir.Buffer) and isinstance(dst, tir.Buffer): ir.assert_structural_equal(src.shape, dst.shape) @@ -57,16 +73,11 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, return tir.BufferStore(dst.buffer, src, dst.indices) assert src_extent or dst_extent, "Can't deduce copy extents from args" - # Treat missing extent as length-matched ones to enable broadcasting logic. + # Treat missing extent as length-matched ones to enable broadcasting. src_extent = list(src_extent) if src_extent else [1] * len(dst_extent) dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent) - # Align and broadcast extents from the right (tail) side independently - # for src and dst, so we can pass them unchanged into _to_region. - # Rules per-dim from the right: - # - equal -> keep both - # - one is 1 -> set that side to the other side's dim - # - otherwise -> error + # Align and broadcast extents from the right (tail) side. 
src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent) # Use legalized extents for src and dst respectively. diff --git a/tilelang/language/customize.py b/tilelang/language/customize.py index 3d40ce47..720c9e99 100644 --- a/tilelang/language/customize.py +++ b/tilelang/language/customize.py @@ -46,8 +46,9 @@ def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer: Returns: Buffer: A new buffer view with the specified shape """ - assert prim_expr_equal(bits_product(shape, src.dtype), - bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed." + assert prim_expr_equal( + bits_product(shape, src.dtype), bits_product(src.shape, src.dtype) + ), f"T.reshape/view shape check failed. src {src} src.shape: {src.shape}, src.dtype: {src.dtype}, target shape: {shape}, target dtype: {src.dtype}" return T.Tensor(shape, src.dtype, src.data) -- GitLab From 4f844000e3d36b9ff2c7bc4f44bbcea8c92bd152 Mon Sep 17 00:00:00 2001 From: Kuris <227995639+kurisu6912@users.noreply.github.com> Date: Wed, 26 Nov 2025 19:27:43 +0800 Subject: [PATCH 053/139] [Fix] Fix missing `not` rewrite in frontend (#1348) --- .../language/test_tilelang_language_frontend_v2.py | 13 +++++++++++++ tilelang/language/v2/ast.py | 12 ++++++++++-- tilelang/language/v2/builder.py | 9 +++++---- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/testing/python/language/test_tilelang_language_frontend_v2.py b/testing/python/language/test_tilelang_language_frontend_v2.py index 299a4127..ee694104 100644 --- a/testing/python/language/test_tilelang_language_frontend_v2.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -466,5 +466,18 @@ def test_buffer_slice_step(): pass +def test_boolop(): + a = Var('a', 'int32') + b = Var('b', 'int32') + c = Var('c', 'int32') + d = Var('d', 'int32') + + @T.macro + def cond(): + return not (a < b and b < c and a * d < b * d) or b * d < c * d + + cond() + + if __name__ == '__main__': tilelang.testing.main() diff --git 
a/tilelang/language/v2/ast.py b/tilelang/language/v2/ast.py index 307efdac..c6dfecf1 100644 --- a/tilelang/language/v2/ast.py +++ b/tilelang/language/v2/ast.py @@ -78,7 +78,7 @@ def quote_expr(expr: str, **kws) -> ast.expr: Operator = Literal['Add', 'Sub', 'Mult', 'MatMult', 'Div', 'Mod', 'Pow', 'LShift', 'RShift', 'BitOr', 'BitXor', 'BitAnd', 'FloorDiv'] -BoolOp = Literal['And', 'Or'] +BoolOp = Literal['And', 'Or', 'Not'] def get_operator_name(operator: ast.operator) -> Operator: @@ -217,11 +217,13 @@ class BaseBuilder: def aug_assign_slice(self, op: Operator, target: Any, sl: slice, aug_value: Any): eval_aug_assign(op, target, sl, aug_value) - def boolop(self, op: BoolOp, left: Any, right: Callable[[], Any]) -> Any: + def boolop(self, op: BoolOp, left: Any, right: Callable[[], Any] | None = None) -> Any: if op == 'And': return left and right() if op == 'Or': return left or right() + if op == 'Not': + return not left raise ValueError(f'Unknown boolop: {op}') def ifexp(self, cond: Any, then: Callable[[], Any], otherwise: Callable[[], Any]) -> Any: @@ -517,6 +519,12 @@ class DSLMutator(ast.NodeTransformer): ) return last + def visit_UnaryOp(self, node: ast.UnaryOp): + node = self.generic_visit(node) + if isinstance(node.op, ast.Not): + return quote_expr("__tb.boolop('Not', operand)", operand=node.operand, span=node) + return node + def visit_Compare(self, node: ast.Compare) -> ast.expr: node = self.generic_visit(node) left = node.left diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index c54b0701..aea425ad 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -148,8 +148,7 @@ class Builder(BaseBuilder): @classmethod def current(cls) -> Self: - builder = thread_local_storage.builder - assert builder is not None, "No active Builder found in the current thread." 
+ builder = getattr(thread_local_storage, 'builder', None) return builder @contextmanager @@ -424,7 +423,7 @@ class Builder(BaseBuilder): else: return super().aug_assign_slice(op, target, sl, aug_value) - def boolop(self, op, left, right): + def boolop(self, op, left, right=None): left = unwrap_cond(left) if isinstance(left, PrimExpr): with self.with_frame(BoolOpFrame()): @@ -432,6 +431,8 @@ class Builder(BaseBuilder): return tir.And(left, right()) if op == 'Or': return tir.Or(left, right()) + if op == 'Not': + return tir.Not(left) raise RuntimeError(f"Unsupported boolean operator: {op}") else: return super().boolop(op, left, right) @@ -562,7 +563,7 @@ class Macro(Generic[_P, _T]): return self.ir_gen.source def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: - builder = Builder.current() + builder = Builder.current() or Builder() with builder.macro(self.name, self.annotations): res = self.ir_gen.gen(builder)(*args, **kwargs) return res -- GitLab From 6bae64f6ebf5737bb8648b81584cd1b644e003d2 Mon Sep 17 00:00:00 2001 From: Gongen-Ali Date: Wed, 26 Nov 2025 19:48:57 +0800 Subject: [PATCH 054/139] [Enhancement] Add support for k_pack in gemm_mfma (#1344) * add support for k_pack * support benchmark on ROCm * fix format --- benchmark/matmul_fp8/benchmark_matmul.py | 6 +++- src/tl_templates/hip/hip_fp8.h | 38 +++++++++++++++++++++ tilelang/intrinsics/mfma_macro_generator.py | 9 ++--- tilelang/tileop/gemm/gemm_mfma.py | 18 +++++----- 4 files changed, 58 insertions(+), 13 deletions(-) diff --git a/benchmark/matmul_fp8/benchmark_matmul.py b/benchmark/matmul_fp8/benchmark_matmul.py index 36b91035..796f7b90 100644 --- a/benchmark/matmul_fp8/benchmark_matmul.py +++ b/benchmark/matmul_fp8/benchmark_matmul.py @@ -1,5 +1,6 @@ import argparse import itertools +import torch import logging import tilelang import tilelang.language as T @@ -99,6 +100,7 @@ def get_configs(args, kwargs): block_K=[64, 128], num_stages=[0, 1, 2, 3], thread_num=[128, 256], + k_pack=[1, 2], 
policy=[T.GemmWarpPolicy.Square], enable_rasteration=[True, False], ) @@ -125,6 +127,7 @@ def matmul( block_K=None, num_stages=None, thread_num=None, + k_pack=None, policy=None, enable_rasteration=None, ): @@ -156,7 +159,7 @@ def matmul( # Use half-precision for input data to reduce memory bandwidth, # accumulate in float for better numerical accuracy - dtype = "float8_e4m3" + dtype = "float8_e4m3fnuz" if torch.version.hip is not None else "float8_e4m3" accum_dtype = "float" @T.prim_func @@ -210,6 +213,7 @@ def matmul( C_local, transpose_B=True, policy=policy, + k_pack=k_pack, ) # Write back the results from C_local to the global memory C T.copy(C_local, C_shared) diff --git a/src/tl_templates/hip/hip_fp8.h b/src/tl_templates/hip/hip_fp8.h index 0000745b..b32f84dc 100644 --- a/src/tl_templates/hip/hip_fp8.h +++ b/src/tl_templates/hip/hip_fp8.h @@ -127,3 +127,41 @@ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z, res.y = *reinterpret_cast(&b); return res; } + +__device__ fp8_e4_16_t make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, + fp8_e4_t x3, fp8_e4_t x4, fp8_e4_t x5, + fp8_e4_t x6, fp8_e4_t x7, fp8_e4_t y0, + fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3, + fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6, + fp8_e4_t y7) { + signed char x0_char = *reinterpret_cast(&x0); + signed char x1_char = *reinterpret_cast(&x1); + signed char x2_char = *reinterpret_cast(&x2); + signed char x3_char = *reinterpret_cast(&x3); + signed char x4_char = *reinterpret_cast(&x4); + signed char x5_char = *reinterpret_cast(&x5); + signed char x6_char = *reinterpret_cast(&x6); + signed char x7_char = *reinterpret_cast(&x7); + signed char y0_char = *reinterpret_cast(&y0); + signed char y1_char = *reinterpret_cast(&y1); + signed char y2_char = *reinterpret_cast(&y2); + signed char y3_char = *reinterpret_cast(&y3); + signed char y4_char = *reinterpret_cast(&y4); + signed char y5_char = *reinterpret_cast(&y5); + signed char y6_char = *reinterpret_cast(&y6); + signed char 
y7_char = *reinterpret_cast(&y7); + int a = (x3_char << 24) | (x2_char << 16) | (x1_char << 8) | x0_char; + int b = (x7_char << 24) | (x6_char << 16) | (x5_char << 8) | x4_char; + int c = (y3_char << 24) | (y2_char << 16) | (y1_char << 8) | y0_char; + int d = (y7_char << 24) | (y6_char << 16) | (y5_char << 8) | y4_char; + fp8_e4_8_t res_x; + res_x.x = *reinterpret_cast(&a); + res_x.y = *reinterpret_cast(&b); + fp8_e4_8_t res_y; + res_y.x = *reinterpret_cast(&c); + res_y.y = *reinterpret_cast(&d); + fp8_e4_16_t res; + res.x = res_x; + res.y = res_y; + return res; +} \ No newline at end of file diff --git a/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/intrinsics/mfma_macro_generator.py index 02c0b039..618a9981 100644 --- a/tilelang/intrinsics/mfma_macro_generator.py +++ b/tilelang/intrinsics/mfma_macro_generator.py @@ -372,8 +372,8 @@ class MatrixCoreIntrinEmitter: a_is_fragment = is_fragment(A_local_buf) b_is_fragment = is_fragment(B_local_buf) - a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0 - b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0 + a_local_stride: PrimExpr = k_inner * warp_rows * k_pack * local_size_a if a_is_fragment else 0 + b_local_stride: PrimExpr = k_inner * warp_cols * k_pack * local_size_b if b_is_fragment else 0 @T.macro def _warp_mfma(A_local_buf, B_local_buf, C_local_buf): @@ -543,7 +543,8 @@ class MatrixCoreIntrinEmitter: return local_id base_fragment = T.Fragment( - [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], + [micro_size_s, micro_size_r * + self.k_pack] if is_sr_axis_order else [micro_size_r * self.k_pack, micro_size_s], forward_thread_fn=forward_thread, forward_index_fn=forward_index, ) @@ -552,7 +553,7 @@ class MatrixCoreIntrinEmitter: chunk = self.chunk warp_s = warp_rows if matrix_is_a else warp_cols - warp_r = chunk // micro_size_r + warp_r = chunk // (micro_size_r * self.k_pack) block_s = block_row_warps if 
matrix_is_a else block_col_warps replicate = block_col_warps if matrix_is_a else block_row_warps diff --git a/tilelang/tileop/gemm/gemm_mfma.py b/tilelang/tileop/gemm/gemm_mfma.py index 45a53d3c..862ec725 100644 --- a/tilelang/tileop/gemm/gemm_mfma.py +++ b/tilelang/tileop/gemm/gemm_mfma.py @@ -28,6 +28,7 @@ class GemmMFMA(GemmBase): warp_row_tiles=warp_row_tiles, warp_col_tiles=warp_col_tiles, chunk=self.chunk, + k_pack=self.k_pack, ) if self.is_gemm_ss(): @@ -75,6 +76,7 @@ class GemmMFMA(GemmBase): warp_col_tiles=warp_col_tiles, chunk=self.chunk, thread_var=thread_var, + k_pack=self.k_pack, ) in_dtype = self.in_dtype @@ -110,11 +112,11 @@ class GemmMFMA(GemmBase): B_shared into local fragments, then issues Matrix Core mfma ops, accumulating into C_local. """ - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + A_local = T.alloc_local((warp_rows * local_size_a * self.k_pack), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b * self.k_pack), in_dtype) if clear_accum: T.clear(C_buf) - for ki in T.serial(0, (block_K // micro_size_k)): + for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): # Load A into fragment mfma_emitter.ldmatrix_a( A_local, @@ -145,12 +147,12 @@ class GemmMFMA(GemmBase): B_shared into local fragments, then issues Matrix Core mfma ops, accumulating into C_local. """ - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + A_local = T.alloc_local((warp_rows * local_size_a * self.k_pack), in_dtype) if clear_accum: T.clear(C_buf) - for ki in T.serial(0, (block_K // micro_size_k)): + for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): # Load A into fragment mfma_emitter.ldmatrix_a( @@ -177,10 +179,10 @@ class GemmMFMA(GemmBase): B_shared into local fragments, then issues Matrix Core mfma ops, accumulating into C_local. 
""" - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b * self.k_pack), in_dtype) if clear_accum: T.clear(C_buf) - for ki in T.serial(0, (block_K // micro_size_k)): + for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): # Load B into fragment mfma_emitter.ldmatrix_b( @@ -207,7 +209,7 @@ class GemmMFMA(GemmBase): accumulating into C_local. """ - for ki in T.serial(0, (block_K // micro_size_k)): + for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))): # Perform Matrix Multiplication mfma_emitter.mfma(A_buf, B_buf, C_buf, ki) -- GitLab From b8240b7ae9387ba7143e6243b59069c3a04a12e9 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Thu, 27 Nov 2025 14:28:14 +0800 Subject: [PATCH 055/139] Add sparse fine-tuning kernel for deepseek sparse attention to example (#1296) * [EXAMPLE] add example for dsa sparse finetuning * [Refactor] --- examples/dsa_sparse_finetune/dsa.py | 252 +++++++++++ examples/dsa_sparse_finetune/index.py | 79 ++++ examples/dsa_sparse_finetune/indexer_bwd.py | 265 +++++++++++ .../indexer_topk_reducesum.py | 277 ++++++++++++ .../dsa_sparse_finetune/sparse_mla_bwd.py | 420 ++++++++++++++++++ .../dsa_sparse_finetune/sparse_mla_fwd.py | 332 ++++++++++++++ .../sparse_mla_topk_reducesum.py | 241 ++++++++++ examples/dsa_sparse_finetune/utils.py | 75 ++++ 8 files changed, 1941 insertions(+) create mode 100644 examples/dsa_sparse_finetune/dsa.py create mode 100644 examples/dsa_sparse_finetune/index.py create mode 100644 examples/dsa_sparse_finetune/indexer_bwd.py create mode 100644 examples/dsa_sparse_finetune/indexer_topk_reducesum.py create mode 100644 examples/dsa_sparse_finetune/sparse_mla_bwd.py create mode 100644 examples/dsa_sparse_finetune/sparse_mla_fwd.py create mode 100644 examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py create mode 100644 examples/dsa_sparse_finetune/utils.py diff --git a/examples/dsa_sparse_finetune/dsa.py 
b/examples/dsa_sparse_finetune/dsa.py new file mode 100644 index 00000000..1ca28241 --- /dev/null +++ b/examples/dsa_sparse_finetune/dsa.py @@ -0,0 +1,252 @@ +from typing import Optional +import torch +import torch.nn.functional as F +from indexer_topk_reducesum import indexer_topk_reducesum_interface +from indexer_bwd import indexer_bwd_interface +from sparse_mla_fwd import sparse_mla_fwd_interface +from sparse_mla_bwd import sparse_mla_bwd +from sparse_mla_topk_reducesum import sparse_mla_topk_reducesum_interface +from einops import einsum, repeat +from utils import get_abs_err, get_err_ratio + + +class RegsiterLossFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, loss): + ctx.save_for_backward(loss) + return x + + @staticmethod + def backward(ctx, grad): + loss = ctx.saved_tensors + return grad, torch.ones(1, dtype=loss[0].dtype, device=loss[0].device) + + +register_loss = RegsiterLossFunction.apply + + +def ref_deepseek_sparse_attention_innner( + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, + index_sm_scale: Optional[float] = None, +): + dtype = q.dtype + q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32), + (q, kv, index_q, index_k, weights)) + + index_sm_scale = index_q.shape[-1]**-0.5 + b, s = index_q.shape[:2] + + # tl_topk_indices = tl_topk_indices.to(torch.int64) + # tl_topk_indices[tl_topk_indices == -1] = s + + casual_mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) + index_logits = einsum(index_q, index_k, 'b s1 h k, b s2 k -> b s1 h s2') + index_logits = F.relu(index_logits) + index_logits = (index_logits * weights.unsqueeze(-1)).sum( + dim=-2, dtype=torch.float32) * index_sm_scale + index_logits = torch.where(casual_mask, index_logits, float('-inf')) + topk_indices = torch.topk(index_logits, k=topk, dim=-1).indices + topk_logits = torch.gather( + 
F.pad(index_logits, (0, 1), value=float('-inf')), dim=-1, index=topk_indices) + topk_score = F.log_softmax(topk_logits, dim=-1, dtype=torch.float32) + index_topk_score = topk_score + + if sm_scale is None: + sm_scale = kv.shape[-1]**-0.5 + + h = q.shape[-2] + index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device="cuda")\ + .scatter_(dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool))[:, :, :-1] + mask = repeat(casual_mask & index_mask, 'b s1 s2 -> b s1 h s2', h=h) + k, v = kv, kv[..., :dim_v] + logits = einsum(q, k, 'b s1 h d, b s2 d -> b s1 h s2') * sm_scale + logits = torch.where(mask, logits, float('-inf')) + attn_score = F.softmax(logits, dim=-1, dtype=torch.float32) + o = einsum(attn_score, v, 'b s1 h s2, b s2 d -> b s1 h d') + + attn_score = attn_score.sum(dim=-2) # [b, s1, s2] + attn_topk_score = torch.gather(F.pad(attn_score, (0, 1)), dim=-1, index=topk_indices) + attn_topk_score = attn_topk_score / attn_topk_score.sum(dim=-1, keepdim=True) + + loss = F.kl_div( + index_topk_score.clip(-100, 0), + attn_topk_score.detach().log().clip(-100, 0), + log_target=True, + reduction="sum") + o = register_loss(o, loss) + + return o.to(dtype), topk_indices + + +def ref_deepseek_sparse_attention( + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + offsets: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, + index_sm_scale: Optional[float] = None, +): + all_o, all_topk_indices = [], [] + for i in range(offsets.shape[0] - 1): + o, topk_indices = ref_deepseek_sparse_attention_innner( + q[None, offsets[i]:offsets[i + 1]], + kv[None, offsets[i]:offsets[i + 1]], + index_q[None, offsets[i]:offsets[i + 1]], + index_k[None, offsets[i]:offsets[i + 1]], + weights[None, offsets[i]:offsets[i + 1]], + topk, + dim_v, + sm_scale, + index_sm_scale, + ) + all_o.append(o.squeeze(0)) + all_topk_indices.append(topk_indices.squeeze(0)) + o = torch.cat(all_o, 
dim=0) + topk_indices = torch.cat(all_topk_indices, dim=0) + return o, topk_indices + + +class DSAFunction(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + offsets: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, + ): + # topk_indices, index_score = ref_index_score(index_q, weights, index_k, topk) + topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k, + topk, offsets) + o, lse = sparse_mla_fwd_interface( + q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v) + ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, + offsets) + ctx.topk = topk + ctx.dim_v = dim_v + ctx.sm_scale = sm_scale + return o, topk_indices + + @staticmethod + def backward( + ctx, + do: torch.Tensor, + _1: torch.Tensor, + ): + q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets = ctx.saved_tensors + attn_score = sparse_mla_topk_reducesum_interface( + q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets, + dim_v=ctx.dim_v).squeeze(-2) + dq, dkv = sparse_mla_bwd( + q, + kv.unsqueeze(-2), + o, + do, + topk_indices.unsqueeze(-2), + lse, + offsets, + sm_scale=ctx.sm_scale) + dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score, + index_score, topk_indices, offsets) + return dq, dkv.squeeze(-2), dindex_q, dindex_k, dweights, None, None, None, None + + +def deepseek_sparse_attention( + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + offsets: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, +): + return DSAFunction.apply(q, kv, index_q, index_k, weights, offsets, topk, dim_v, sm_scale) + + +def test_kernel( + B=1, + S=2048, + H=16, + D=512, + tail_D=64, + index_D=128, + 
topk=64, +): + torch.manual_seed(42) + q = torch.randn((S, H, D + tail_D)).cuda().bfloat16().requires_grad_() + kv = torch.randn((S, D + tail_D)).cuda().bfloat16().requires_grad_() + index_q = torch.randn((S, H, index_D)).cuda().bfloat16().requires_grad_() + weights = torch.randn((S, H)).cuda().bfloat16().requires_grad_() + index_k = torch.randn((S, index_D)).cuda().bfloat16().requires_grad_() + do = torch.randn((S, H, D)).cuda().bfloat16().requires_grad_() + offsets = torch.tensor([0, S // 2, S], dtype=torch.int32).cuda() + + o, topk_indices = deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D) + o.backward(do) + q_grad, q.grad = q.grad, None + kv_grad, kv.grad = kv.grad, None + index_q_grad, index_q.grad = index_q.grad, None + index_k_grad, index_k.grad = index_k.grad, None + weights_grad, weights.grad = weights.grad, None + + ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights, + offsets, topk, D) + ref_o.backward(do) + ref_q_grad, q.grad = q.grad, None + ref_kv_grad, kv.grad = kv.grad, None + ref_index_q_grad, index_q.grad = index_q.grad, None + ref_index_k_grad, index_k.grad = index_k.grad, None + ref_weights_grad, weights.grad = weights.grad, None + + print(f"o err: {get_abs_err(o, ref_o):.6f} ratio: {get_err_ratio(o, ref_o):.6f}") + print( + f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}" + ) + print( + f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}" + ) + print( + f"index_q.grad err: {get_abs_err(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f} ratio: {get_err_ratio(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f}" + ) + print( + f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}" + ) + print( + f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: 
{get_err_ratio(weights_grad, ref_weights_grad):.6f}" + ) + + intersections = [] + for j in range(S): + ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy() + trt_np = topk_indices[j].cpu().to(torch.int32).numpy() + + mask = (trt_np != -1) + + set_ref = set(ref_np[mask]) + set_trt = set(trt_np[mask]) + intersection = set_ref & set_trt + intersections.append(len(intersection) / len(set_ref)) + print("average intersections: {:.4f}".format(sum(intersections) / len(intersections))) + + +test_kernel() diff --git a/examples/dsa_sparse_finetune/index.py b/examples/dsa_sparse_finetune/index.py new file mode 100644 index 00000000..92ce687f --- /dev/null +++ b/examples/dsa_sparse_finetune/index.py @@ -0,0 +1,79 @@ +# Modified from: https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/index.py +import torch +import torch.nn.functional as F +import functools +from typing import Callable, Any + + +def tensor_cache(fn: Callable[..., torch.Tensor],) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent result of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + If the function is called again with the same input tensors, it will return the cached result. + + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. 
+ """ + last_args: tuple | None = None + last_kwargs: dict | None = None + last_result: Any = None + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal last_args, last_kwargs, last_result + + if (last_args is not None and last_kwargs is not None) and \ + (len(args) == len(last_args) and len(kwargs) == len(last_kwargs)) and \ + all(a is b for a, b in zip(args, last_args, strict=False)) and \ + all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()): + return last_result + + result = fn(*args, **kwargs) + last_args, last_kwargs, last_result = args, kwargs, result + return result + + return wrapper + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_cu_seqlens_from_lens( + lens: torch.LongTensor, + dtype: torch.dtype | None = torch.int32, +) -> torch.LongTensor: + return F.pad(lens.cumsum(dim=0, dtype=dtype), (1, 0)) + + +@tensor_cache +def prepare_lens_from_cu_seqlens(cu_seqlens: torch.LongTensor,) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.cat([ + torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) + for n in prepare_lens(cu_seqlens).unbind() + ]) + + +@tensor_cache +def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return prepare_position_ids(cu_seqlens).eq(0).cumsum(0) - 1 + + +@tensor_cache +def prepare_token_indices(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + position_ids = prepare_position_ids(cu_seqlens) + return torch.stack([prepare_sequence_ids(cu_seqlens), position_ids], 1).to(cu_seqlens) diff --git a/examples/dsa_sparse_finetune/indexer_bwd.py b/examples/dsa_sparse_finetune/indexer_bwd.py new file mode 100644 index 00000000..5430c1c0 --- /dev/null +++ b/examples/dsa_sparse_finetune/indexer_bwd.py @@ -0,0 +1,265 @@ +import torch +import 
torch.nn.functional as F +from einops import einsum, repeat + +import tilelang as tl +import tilelang.language as T +from typing import Optional +from index import prepare_token_indices + +from utils import get_abs_err, get_err_ratio + +BF16 = "bfloat16" +FP32 = "float32" +INT32 = "int32" + +pass_configs = { + tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +} + + +@tl.jit(pass_configs=pass_configs) +def tl_indexer_bwd_impl( + heads: int, + dim: int, + topk: int, + sm_scale: Optional[float] = None, + block_I: int = 32, + num_stages: int = 0, + num_threads: int = 128, +): + assert num_stages == 0 + assert topk == tl.math.next_power_of_2(topk) + assert topk % block_I == 0 + assert heads <= 64 and heads % 8 == 0 + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + dtype: str = BF16 + accum_dtype: str = FP32 + index_q_shape = [seq_len, heads, dim] + weights_shape = [seq_len, heads] + index_k_shape = [seq_len, dim] + shape_p = [seq_len, topk] + topk_indices_shape = [seq_len, topk] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + if sm_scale is None: + sm_scale = dim**-0.5 + + @T.prim_func + def tl_indexer_bwd_kernel( + IndexQ: T.Tensor(index_q_shape, dtype), + Weights: T.Tensor(weights_shape, dtype), + IndexK: T.Tensor(index_k_shape, dtype), + dIndexQ: T.Tensor(index_q_shape, dtype), + dWeights: T.Tensor(weights_shape, dtype), + dIndexK: T.Tensor(index_k_shape, dtype), + AttnScore: T.Tensor(shape_p, FP32), + IndexScore: T.Tensor(shape_p, FP32), + TopkIndices: T.Tensor(topk_indices_shape, INT32), + Offsets: T.Tensor(offsets_shape, INT32), + TokenIndices: T.Tensor(token_indices_shape, INT32), + ): + with T.Kernel(seq_len, threads=num_threads) as (bx): + i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1] + bos = Offsets[i_b] + num_blocks = T.ceildiv(topk, block_I) + + index_q_shared = T.alloc_shared([heads, dim], dtype=dtype) + weights_shared = 
T.alloc_shared([heads], dtype=dtype) + + d_index_q_frag = T.alloc_fragment([heads, dim], dtype=accum_dtype) + d_weights_frag = T.alloc_fragment([heads], dtype=accum_dtype) + + T.copy(IndexQ[bos + i_t, :, :], index_q_shared) + T.copy(Weights[bos + i_t, :], weights_shared) + T.fill(d_index_q_frag, 0) + T.fill(d_weights_frag, 0) + + for i, j in T.Parallel(heads, dim): + index_q_shared[i, j] = index_q_shared[i, j] * sm_scale + + for bi_i in T.Pipelined(num_blocks, num_stages=num_stages): + + i_st = bi_i * block_I + i_ed = (bi_i + 1) * block_I + + indices_shared = T.alloc_shared([block_I], dtype=INT32) + T.copy(TopkIndices[bos + i_t, i_st:i_ed], indices_shared) + + index_k_shared = T.alloc_shared([block_I, dim], dtype=dtype) + for i, j in T.Parallel(block_I, dim): + pos = indices_shared[i] + index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t), + IndexK[bos + pos, j], 0) + + attn_score_shared = T.alloc_shared([block_I], dtype=accum_dtype) + index_score_shared = T.alloc_shared([block_I], dtype=accum_dtype) + for i in T.Parallel(block_I): + attn_score_shared[i] = AttnScore[bos + i_t, i_st + i] + index_score_shared[i] = IndexScore[bos + i_t, i_st + i] + + logits = T.alloc_fragment((block_I, heads), accum_dtype) + T.gemm( + index_k_shared, + index_q_shared, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + for i, j in T.Parallel(block_I, heads): + logits[i, j] = T.max(logits[i, j], 0) + + # dw + d_weights_i = T.alloc_fragment((block_I, heads), accum_dtype) + for i, j in T.Parallel(block_I, heads): + d_weights_i[i, + j] = (index_score_shared[i] - attn_score_shared[i]) * logits[i, j] + T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False) + + d_logits_qk = T.alloc_shared((block_I, heads), accum_dtype) + d_logits_qk_cast1 = T.alloc_fragment((block_I, heads), dtype) + d_logits_qk_cast2 = T.alloc_fragment((block_I, heads), dtype) + + for i, j in T.Parallel(block_I, heads): + d_relu = T.alloc_var(accum_dtype) + if logits[i, j] > 0: + 
d_relu = 1.0 + else: + d_relu = 0.0 + d_logits_qk[i, j] = (index_score_shared[i] - + attn_score_shared[i]) * d_relu * weights_shared[j] + + # dq + T.copy(d_logits_qk, d_logits_qk_cast1) + T.gemm( + d_logits_qk_cast1, # [BS, HQ] + index_k_shared, # [BS, K] + d_index_q_frag, # [HQ, K] + transpose_A=True, + transpose_B=False, + clear_accum=False, + ) + + # dk + T.copy(d_logits_qk, d_logits_qk_cast2) + d_index_k_frag = T.alloc_fragment([block_I, dim], dtype=accum_dtype) + T.gemm( + d_logits_qk_cast2, # [BS, HQ] + index_q_shared, # [HQ, K] + d_index_k_frag, # [BS, K] + transpose_A=False, + transpose_B=False, + clear_accum=True, + ) + + for i, j in T.Parallel(block_I, dim): + pos = indices_shared[i] + if ((pos > -1) & (pos <= i_t)): + T.atomic_add(dIndexK[bos + pos, j], d_index_k_frag[i, j]) + + for i, j in T.Parallel(heads, dim): + d_index_q_frag[i, j] = d_index_q_frag[i, j] * sm_scale + + T.copy(d_index_q_frag, dIndexQ[bos + i_t, :, :]) + T.copy(d_weights_frag, dWeights[bos + i_t, :]) + + return tl_indexer_bwd_kernel + + +def indexer_bwd_interface( + q: torch.Tensor, + weights: torch.Tensor, + k: torch.Tensor, + attn_score: torch.Tensor, + index_score: torch.Tensor, + topk_indices: torch.Tensor, + offsets: torch.Tensor, +): + _, heads, dim, topk = *q.shape, topk_indices.shape[-1] + token_indices = prepare_token_indices(offsets) + dq = torch.zeros_like(q) + dweights = torch.zeros_like(weights) + dk = torch.zeros_like(k) + kernel = tl_indexer_bwd_impl(heads, dim, topk) + kernel(q, weights, k, dq, dweights, dk, attn_score, index_score, topk_indices, offsets, + token_indices) + return dq, dweights, dk + + +def ref_indexer_bwd(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, + TopkIndices: torch.Tensor, AttnScore: torch.Tensor, + offsets: torch.Tensor) -> torch.Tensor: + Q.requires_grad_(True) + Weights.requires_grad_(True) + K.requires_grad_(True) + softmax_scale = Q.shape[-1]**-0.5 + all_loss = [] + all_log_topk_prob = [] + for i in range(offsets.shape[0] - 1): + 
assert (offsets[i + 1] - offsets[i]).item() >= TopkIndices.shape[-1] + q = Q[offsets[i]:offsets[i + 1]] + weights = Weights[offsets[i]:offsets[i + 1]] + k = K[offsets[i]:offsets[i + 1]] + topk_indices = TopkIndices[offsets[i]:offsets[i + 1]] + attn_score = AttnScore[offsets[i]:offsets[i + 1]] + s = q.shape[0] + mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) + logits = einsum(q, k, 's1 h k, s2 k -> s1 h s2') * softmax_scale + logits = F.relu(logits) + score = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) + score = torch.where(mask, score, float('-inf')) + topk_value = torch.gather(score, dim=-1, index=topk_indices.to(torch.int64)) + log_topk_prob = F.log_softmax(topk_value, dim=-1, dtype=torch.float32) + loss = F.kl_div( + log_topk_prob.clip(-100, 0), + attn_score.log().clip(-100, 0), + log_target=True, + reduction="sum") + all_loss.append(loss) + all_log_topk_prob.append(log_topk_prob) + loss = torch.stack(all_loss).sum() + loss.backward() + log_topk_prob = torch.cat(all_log_topk_prob, dim=0) + return log_topk_prob.exp(), Q.grad, Weights.grad, K.grad + + +def test_kernel( + B=1, + S=2048, + H=16, + D=128, + topk=64, +): + torch.manual_seed(42) + q = torch.randn((S, H, D)).cuda().bfloat16() + w = torch.randn((S, H)).cuda().bfloat16() + k = torch.randn((S, D)).cuda().bfloat16() + offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda() + + all_attn_score = [] + for i in range(offsets.shape[0] - 1): + seq_len = (offsets[i + 1] - offsets[i]).item() + mask = (torch.arange(seq_len)[:, None] >= torch.arange(topk)[None, :]).to(q.device) + logits = torch.ones(seq_len, topk).cuda() + logits = torch.where(mask, logits, float('-inf')) + attn_score = F.softmax(logits, dim=-1, dtype=torch.float32) + all_attn_score.append(attn_score) + attn_score = torch.cat(all_attn_score, dim=0) + + topk_indices = repeat( + torch.arange(topk, dtype=torch.int32).cuda(), 'k -> s k', s=S).contiguous() + index_score, ref_dq, ref_dw, ref_dk = 
ref_indexer_bwd(q, w, k, topk_indices, attn_score, + offsets) + + dq, dw, dk = indexer_bwd_interface(q, w, k, attn_score, index_score, topk_indices, offsets) + + print(f"dq err: {get_abs_err(dq, ref_dq):.6f} ratio: {get_err_ratio(dq, ref_dq):.6f}") + print(f"dq err: {get_abs_err(dw, ref_dw):.6f} ratio: {get_err_ratio(dw, ref_dw):.6f}") + print(f"dq err: {get_abs_err(dk, ref_dk):.6f} ratio: {get_err_ratio(dk, ref_dk):.6f}") + + +if __name__ == '__main__': + test_kernel() diff --git a/examples/dsa_sparse_finetune/indexer_topk_reducesum.py b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py new file mode 100644 index 00000000..b7fa6627 --- /dev/null +++ b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py @@ -0,0 +1,277 @@ +import math +import torch +import torch.nn.functional as F +from einops import einsum + +import tilelang as tl +import tilelang.language as T +from typing import Optional +from index import prepare_token_indices + +from utils import get_abs_err, get_err_ratio + +BF16 = "bfloat16" +FP32 = "float32" +INT32 = "int32" + +pass_configs = { + tl.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, + tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +} + + +@tl.jit(pass_configs=pass_configs) +def tl_indexer_topk_reducesum_impl( + heads: int, + dim: int, + topk: int, + sm_scale: Optional[float] = None, + block_K: int = 32, + dtype: str = FP32, + num_stages: int = 0, + num_threads: int = 128, +): + assert topk == tl.math.next_power_of_2(topk) + assert topk % block_K == 0 + assert heads <= 64 and heads % 8 == 0 + assert num_stages == 0 + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + + index_q_shape = [seq_len, heads, dim] + weights_shape = [seq_len, heads] + index_k_shape = [seq_len, dim] + topk_indices_shape = [seq_len, topk] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + + N = 2 * topk + num_iters = int(round(math.log2(N))) + if sm_scale 
is None: + sm_scale = dim**-0.5 + + @T.macro + def bitonic_sort( + topk_index_shared: T.SharedBuffer([N], dtype=INT32), + topk_value_shared: T.SharedBuffer([N], dtype=FP32), + ): + T.sync_threads() + for i1 in T.serial(num_iters): + for i2 in T.serial(i1 + 1): + for i in T.Parallel(N): + ascending = (i & (1 << (i1 + 1))) != 0 + j = i ^ (1 << (i1 - i2)) + if i < j and \ + ((ascending and topk_value_shared[i] > topk_value_shared[j]) or ( + not ascending and topk_value_shared[i] < topk_value_shared[j])): + val = topk_value_shared[i] + topk_value_shared[i] = topk_value_shared[j] + topk_value_shared[j] = val + idx = topk_index_shared[i] + topk_index_shared[i] = topk_index_shared[j] + topk_index_shared[j] = idx + T.sync_threads() + + @T.prim_func + def tl_indexer_topk_reducesum_kernel( + IndexQ: T.Tensor(index_q_shape, dtype), + Weights: T.Tensor(weights_shape, dtype), + IndexK: T.Tensor(index_k_shape, dtype), + TopkIndices: T.Tensor(topk_indices_shape, INT32), + ReduceSum: T.Tensor(topk_indices_shape, FP32), + Offsets: T.Tensor(offsets_shape, INT32), + TokenIndices: T.Tensor(token_indices_shape, INT32), + ): + with T.Kernel(seq_len, threads=num_threads) as (bx): + i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1] + bos, eos = Offsets[i_b], Offsets[i_b + 1] + num_blocks = T.ceildiv(i_t + 1, block_K) + + topk_index_shared = T.alloc_shared([N], dtype=INT32) + topk_value_shared = T.alloc_shared([N], dtype=FP32) + + T.fill(topk_index_shared, -1) + T.fill(topk_value_shared, float('-inf')) + T.sync_threads() + + index_q_shared = T.alloc_shared([heads, dim], dtype=dtype) + T.copy(IndexQ[bos + i_t, :, :], index_q_shared) + T.sync_threads() + + weights_frag = T.alloc_shared([heads], dtype=dtype) + T.copy(Weights[bos + i_t, :], weights_frag) + T.sync_threads() + + for i, j in T.Parallel(heads, dim): + index_q_shared[i, j] = index_q_shared[i, j] * sm_scale + T.sync_threads() + + for bk_i in T.Pipelined(num_blocks, num_stages=num_stages): + k_st = bk_i * block_K + k_ed = 
T.min((bk_i + 1) * block_K, eos - bos) + + index_k_shared = T.alloc_shared([block_K, dim], dtype=dtype) + for i, j in T.Parallel(block_K, dim): + index_k_shared[i, j] = T.if_then_else(k_st + i < k_ed, IndexK[bos + k_st + i, + j], 0) + T.sync_threads() + + logits = T.alloc_fragment((block_K, heads), FP32) + T.gemm( + index_k_shared, + index_q_shared, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + T.sync_threads() + + for i, j in T.Parallel(block_K, heads): + logits[i, j] = T.max(logits[i, j], 0) * weights_frag[j] + T.sync_threads() + + logits_sum = T.alloc_fragment(block_K, FP32) + T.reduce_sum(logits, logits_sum, dim=1) + T.sync_threads() + + offset = T.alloc_var(INT32) + if k_st >= topk: + offset = topk + (k_st % topk) + else: + offset = k_st + T.sync_threads() + for i in T.Parallel(block_K): + if k_st + i > i_t: + logits_sum[i] = float('-inf') + j = offset + i + topk_index_shared[j] = k_st + i + topk_value_shared[j] = logits_sum[i] + T.sync_threads() + + if k_ed > topk and k_ed % topk == 0: + bitonic_sort(topk_index_shared, topk_value_shared) + + bitonic_sort(topk_index_shared, topk_value_shared) + + logits_max_frag = T.alloc_fragment([1], dtype=FP32) + logits_frag = T.alloc_fragment([topk], dtype=FP32) + reducesum_shared = T.alloc_shared([topk], dtype=FP32) + + T.copy(topk_value_shared[:topk], logits_frag) + T.sync_threads() + + T.reduce_max(logits_frag, logits_max_frag, dim=-1) + T.sync_threads() + + for i in T.Parallel(topk): + logits_frag[i] = T.exp(logits_frag[i] - logits_max_frag[0]) + T.sync_threads() + + lse_frag = T.alloc_fragment([1], dtype=FP32) + T.reduce_sum(logits_frag, lse_frag) + T.sync_threads() + + for i in T.Parallel(topk): + reducesum_shared[i] = logits_frag[i] / lse_frag[0] + T.sync_threads() + + # for i in T.Parallel(topk): + # reducesum_shared[i] = logits_frag[i] + # T.sync_threads() + + for i in T.Parallel(topk): + if topk_index_shared[i] > i_t: + topk_index_shared[i] = -1 + T.sync_threads() + + 
T.copy(topk_index_shared[:topk], TopkIndices[bos + i_t, :]) + T.copy(reducesum_shared[:topk], ReduceSum[bos + i_t, :]) + + return tl_indexer_topk_reducesum_kernel + + +def indexer_topk_reducesum_interface( + q: torch.Tensor, + weights: torch.Tensor, + k: torch.Tensor, + topk: int, + offsets: torch.Tensor, + dtype: str = BF16, +): + seq_len, heads, dim = q.shape + kernel = tl_indexer_topk_reducesum_impl(heads=heads, dim=dim, topk=topk, dtype=dtype) + token_indices = prepare_token_indices(offsets) + topk_indices = torch.zeros((seq_len, topk), device=q.device, dtype=torch.int32) + topk_score = torch.zeros((seq_len, topk), device=q.device, dtype=torch.float32) + kernel(q, weights, k, topk_indices, topk_score, offsets, token_indices) + return topk_indices, topk_score + + +def ref_index_score(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, topk: int, + offsets: torch.Tensor) -> torch.Tensor: + all_topk_indices = [] + all_topk_score = [] + for i in range(offsets.shape[0] - 1): + assert (offsets[i + 1] - offsets[i]).item() >= topk + q = Q[offsets[i]:offsets[i + 1]] + weights = Weights[offsets[i]:offsets[i + 1]] + k = K[offsets[i]:offsets[i + 1]] + softmax_scale = q.shape[-1]**-0.5 + s = q.shape[0] + mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) + logits = einsum(q, k, 's1 h k, s2 k -> s1 h s2') + logits = F.relu(logits) + logits = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * softmax_scale + logits = torch.where(mask, logits, float('-inf')) + topk_logits, topk_indices = torch.topk(logits, k=topk, dim=-1) + topk_score = F.softmax(topk_logits, dim=-1, dtype=torch.float32) + all_topk_indices.append(topk_indices) + all_topk_score.append(topk_score) + topk_indices = torch.cat(all_topk_indices, dim=0) + topk_score = torch.cat(all_topk_score, dim=0) + return topk_indices, topk_score + + +def test_kernel( + B=1, + S=2048, + H=64, + D=128, + topk=64, +): + torch.manual_seed(42) + + q = torch.randn((S, H, 
D)).cuda().bfloat16() + weights = torch.randn((S, H)).cuda().bfloat16() + k = torch.randn((S, D)).cuda().bfloat16() + offsets = torch.tensor([0, S], dtype=torch.int32).cuda() + + ref_topk_indices, ref_topk_score = ref_index_score(q, weights, k, topk, offsets) + + topk_indices, topk_score = indexer_topk_reducesum_interface(q, weights, k, topk, offsets) + + for j in range(S): + ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy() + trt_np = topk_indices[j].cpu().to(torch.int32).numpy() + + ref_np_val = ref_topk_score[j] + trt_np_val = topk_score[j] + + mask = (ref_np_val > 0).cpu().numpy() + + set_ref = set(ref_np[mask]) + set_trt = set(trt_np[mask]) + intersection = set_ref & set_trt + + print("idx:", j, "selected/all:", len(intersection), "/", len(set_ref), "=", + len(intersection) / len(set_ref)) + + print( + f"err: {get_abs_err(ref_np_val, trt_np_val):.6f} ratio: {get_err_ratio(ref_np_val, trt_np_val):.6f}" + ) + + +if __name__ == '__main__': + test_kernel() diff --git a/examples/dsa_sparse_finetune/sparse_mla_bwd.py b/examples/dsa_sparse_finetune/sparse_mla_bwd.py new file mode 100644 index 00000000..33c21cb4 --- /dev/null +++ b/examples/dsa_sparse_finetune/sparse_mla_bwd.py @@ -0,0 +1,420 @@ +# ruff: noqa +import tilelang +from tilelang import language as T +import torch +from index import prepare_token_indices + +from utils import assert_tensors_similar + + +@tilelang.jit(out_idx=[-1]) +def preprocess( + H, + D, + block_ND=32, + num_stages=5, + dtype="bfloat16", + accum_dtype="float", +): + assert dtype == "bfloat16" + assert accum_dtype == "float" + + S = T.symbolic('S') + + shape = [S, H, D] + + @T.prim_func + def preprocess_kernel( + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + Delta: T.Tensor([S, H], accum_dtype), + ): + with T.Kernel(H, T.ceildiv(S, block_ND)) as (bx, by): + o = T.alloc_fragment([block_ND, block_ND], accum_dtype) + do = T.alloc_fragment([block_ND, block_ND], accum_dtype) + delta = T.alloc_fragment([block_ND], 
accum_dtype) + acc = T.alloc_fragment([block_ND, block_ND], accum_dtype) + T.clear(acc) + for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): + T.copy(O[by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], o) + T.copy(dO[by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], + do) + for i, j in T.Parallel(block_ND, block_ND): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[by * block_ND:(by + 1) * block_ND, bx]) + + return preprocess_kernel + + +@tilelang.jit(out_idx=[-1]) +def postprocess( + D, + D_tail, + kv_group=1, + block_N=64, + threads=128, + dtype="bfloat16", + accum_dtype="float", +): + assert dtype == "bfloat16" + assert accum_dtype == "float" + S_kv = T.symbolic('S_kv') + + dkv_shape = [S_kv, kv_group, D + D_tail] + + @T.prim_func + def postprocess_kernel( + dKV: T.Tensor(dkv_shape, accum_dtype), + dKV_out: T.Tensor(dkv_shape, dtype), + ): + with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, threads=threads) as (bx, by): + T.copy( + dKV[bx * block_N:(bx + 1) * block_N, by, :], + dKV_out[bx * block_N:(bx + 1) * block_N, by, :], + ) + + return postprocess_kernel + + +@tilelang.jit( + out_idx=[-2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) +def bwd( + H, + D, + D_tail, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + block_size=32, + num_stages=0, + threads=128, + indices_dtype="int32", + dtype="bfloat16", + accum_dtype="float", +): + assert is_causal == True, 'non-casual is not supported now' + assert topk % block_size == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded' + assert dtype == "bfloat16" + assert accum_dtype == "float" + assert indices_dtype == "int32" + + if sm_scale is None: + sm_scale = (D + D_tail)**(-0.5) + + B_plus_one = T.symbolic('B_plus_one') + S = T.symbolic('S') + + H_kv = H // kv_group + q_shape = [S, H, D + D_tail] + 
k_shape = [S, kv_group, D + D_tail] + o_shape = [S, H, D] + indices_shape = [S, kv_group, topk] + delta_shape = [S, H] + lse_shape = [S, H] + offsets_shape = [B_plus_one] + token_indices_shape = [S, 2] + assert indices_dtype == "int32" + assert dtype == "bfloat16" + assert accum_dtype == "float" + + H = H_kv + padded_H = max(tilelang.math.next_power_of_2(H_kv), 16) + BS = block_size + NS = tilelang.cdiv(topk, block_size) + + split_store = 2 + + @T.prim_func + def sparse_mla_bwd_kernel( + Q: T.Tensor(q_shape, dtype), + KV: T.Tensor(k_shape, dtype), + dO: T.Tensor(o_shape, dtype), + Indices: T.Tensor(indices_shape, indices_dtype), + Lse: T.Tensor(lse_shape, accum_dtype), + Delta: T.Tensor(delta_shape, accum_dtype), + Offsets: T.Tensor(offsets_shape, indices_dtype), + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), + dQ: T.Tensor(q_shape, dtype), + dKV: T.Tensor(k_shape, accum_dtype), + ): + with T.Kernel(S, kv_group, threads=threads) as (b_s_i, bz): + Q_shared = T.alloc_shared([padded_H, D], dtype) + Q_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + KV_shared = T.alloc_shared([BS, D], dtype) + KV_tail_shared = T.alloc_shared([BS, D_tail], dtype) + dO_shared = T.alloc_shared([padded_H, D], dtype) + mask = T.alloc_fragment([BS], "bool") + + P_shared_cast = T.alloc_shared([padded_H, BS], dtype) + dP_shared_cast = T.alloc_shared([padded_H, BS], dtype) + dQ_shared = T.alloc_shared([padded_H, D], dtype) + dQ_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + + acc_p = T.alloc_fragment([padded_H, BS], accum_dtype) + acc_dp = T.alloc_fragment([padded_H, BS], accum_dtype) + acc_dq = T.alloc_fragment([padded_H, D], accum_dtype) + acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype) + acc_dkv = T.alloc_fragment([BS, D], accum_dtype) + acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype) + acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype) + acc_dkv_tail_shared = T.view( + KV_tail_shared, shape=[BS // 
split_store, D_tail], dtype=accum_dtype) + + b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] + bos, eos = Offsets[b_i], Offsets[b_i + 1] + + max_kv_i = s_i + + T.copy(Q[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D], Q_shared) + T.copy(Q[bos + s_i, bz * padded_H:(bz + 1) * padded_H, D:], Q_tail_shared) + T.copy(dO[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D], dO_shared) + + T.clear(acc_dq) + T.clear(acc_dq_tail) + + T.annotate_layout({ + dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared), + dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared), + }) + + # Process each block of indices + for i_i in T.Pipelined(NS, num_stages=num_stages): + # Check which indices are valid + for bi_i in T.Parallel(BS): + mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & ( + Indices[bos + s_i, bz, i_i * BS + bi_i] != -1) + + # Compute attention scores + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype)) + + # Load KV, V for this block of indices + for bi_i, d_i in T.Parallel(BS, D): + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, + d_i] + + T.gemm( + Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + + for bi_i, d_i in T.Parallel(BS, D_tail): + KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], + bz, D + d_i] + T.gemm( + Q_tail_shared, + KV_tail_shared, + acc_p, + transpose_B=True, + policy=T.GemmWarpPolicy.FullCol) + + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_p[h_i, bi_i] = T.exp(acc_p[h_i, bi_i] * sm_scale - + Lse[bos + s_i, bz * padded_H + h_i]) + + T.copy(acc_p, P_shared_cast) + + T.gemm( + dO_shared, + KV_shared, + acc_dp, + transpose_B=True, + policy=T.GemmWarpPolicy.FullCol, + clear_accum=True) + + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * ( + acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale + + 
T.copy(acc_dp, dP_shared_cast) + T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol) + + T.gemm( + dP_shared_cast, + Q_shared, + acc_dkv, + transpose_A=True, + policy=T.GemmWarpPolicy.FullCol, + clear_accum=True) + T.gemm( + P_shared_cast, + dO_shared, + acc_dkv, + transpose_A=True, + policy=T.GemmWarpPolicy.FullCol) + + T.clear(acc_dkv_tail) + T.gemm( + dP_shared_cast, + Q_tail_shared, + acc_dkv_tail, + transpose_A=True, + policy=T.GemmWarpPolicy.FullCol) + + for s in range(split_store): + for bi_i, d_i in T.Parallel(BS, D): + if bi_i < BS // split_store: + acc_dkv_shared[bi_i, d_i] = acc_dkv[bi_i + s * (BS // split_store), d_i] + + for bi_i, d_i in T.Parallel(BS, D_tail): + if bi_i < BS // split_store: + acc_dkv_tail_shared[bi_i, + d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), + d_i] + + for bi_i, d_i in T.Parallel(BS // split_store, D // 4): + T.atomic_addx4( + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * + (BS // split_store)], bz, d_i * 4], + acc_dkv_shared[bi_i, d_i * 4]) + + # Atomically update dKV, dKV_tail tensors + for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): + T.atomic_addx4( + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * + (BS // split_store)], bz, D + d_i * 4], + acc_dkv_tail_shared[bi_i, d_i * 4]) + + # Store the accumulated dQ + T.copy(acc_dq, dQ_shared) + T.copy(acc_dq_tail, dQ_tail_shared) + + T.copy(dQ_shared, dQ[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D]) + T.copy(dQ_tail_shared, dQ[bos + s_i, bz * padded_H:(bz + 1) * padded_H, D:]) + + return sparse_mla_bwd_kernel + + +def sparse_mla_bwd(q, + kv, + o, + do, + indices, + lse, + offsets, + sm_scale=None, + is_casual=True, + return_kernel=False, + delta=None): + assert q.is_contiguous() + assert kv.is_contiguous() + assert indices.is_contiguous() + assert lse.is_contiguous() + S, H, dim_plus_tail_dim = q.shape + S_kv, kv_group, _ = kv.shape + 
assert kv.shape[-1] == dim_plus_tail_dim + assert S == S_kv + # dim should be assigned + D = 512 + + D_tail = dim_plus_tail_dim - D + topk = indices.shape[-1] + assert indices.shape == (S, kv_group, topk) + assert lse.shape == (S, H) + + token_indices = prepare_token_indices(offsets) + + # Get kernels + preprocess_kernel = preprocess(H, D) + bwd_kernel = bwd(H, D, D_tail, topk, kv_group, sm_scale, is_casual) + postprocess_kernel = postprocess(D, D_tail, kv_group) + + if delta is None: + delta = preprocess_kernel(o, do) + dkv = torch.zeros_like(kv, dtype=torch.float32) + dq = bwd_kernel(q, kv, do, indices, lse, delta, offsets, token_indices, dkv) + dkv = postprocess_kernel(dkv) + + return dq, dkv + + +def ref_sparse_mla_bwd_interface(q, + kv, + o, + do, + indices, + lse, + offsets, + sm_scale=None, + is_casual=True): + from sparse_mla_fwd import ref_sparse_mla_fwd_interface + q = q.detach().clone() + kv = kv.detach().clone() + q.requires_grad = True + kv.requires_grad = True + o = ref_sparse_mla_fwd_interface(q, kv, indices, offsets, sm_scale, is_casual) + o.backward(do) + return q.grad, kv.grad + + +def test_sparse_mla_bwd(B=1, + S=2048, + H=64, + HKV=1, + DQKV=576, + DV=512, + topk=512, + dtype=torch.bfloat16, + check_correctness=True): + # Prepare data + q = torch.randn((S, H, DQKV), dtype=dtype, device='cuda').requires_grad_(True) + kv = torch.randn((S, HKV, DQKV), dtype=dtype, device='cuda').requires_grad_(True) + do = torch.randn((S, H, DV), dtype=dtype, device='cuda') + offsets = torch.tensor([0, S], dtype=torch.int32, device="cuda") + + indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device='cuda') + for i in range(offsets.shape[0] - 1): + seq_len = (offsets[i + 1] - offsets[i]).item() + assert seq_len >= topk + for t in range(seq_len): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[offsets[i] + t, h, :len(i_i)] = i_i + + # Forward + from sparse_mla_fwd import sparse_mla_fwd_interface + tl_out, tl_lse = 
sparse_mla_fwd_interface(q, kv, indices, offsets) + + tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets) + ref_dq, ref_dkv = ref_sparse_mla_bwd_interface(q, kv, None, do, indices, None, offsets) + + if check_correctness: + assert_tensors_similar(tl_dq, ref_dq, eps=1e-4, name="dq") + assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv") + print("assert_tensors_similar passed") + + per_token_flop = 2 * sum([ + H * DV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DV * topk, + ]) + from tilelang.profiler import do_bench + + def fn(): + return sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets) + + ms = do_bench(fn, rep=100, warmup=250) + print(f"Average time: {ms:.3f} ms") + print(f'bwd io bandwidth = ', + (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) + print(f'bwd tflops = ', per_token_flop * S / (ms * 1e-3) / 1e12) + + +if __name__ == "__main__": + test_sparse_mla_bwd( + B=1, + S=2048, + H=64, + HKV=1, + DQKV=576, + DV=512, + topk=512, + dtype=torch.bfloat16, + check_correctness=True) diff --git a/examples/dsa_sparse_finetune/sparse_mla_fwd.py b/examples/dsa_sparse_finetune/sparse_mla_fwd.py new file mode 100644 index 00000000..5f03dfbb --- /dev/null +++ b/examples/dsa_sparse_finetune/sparse_mla_fwd.py @@ -0,0 +1,332 @@ +# ruff: noqa +import torch +import tilelang +from tilelang import language as T +from index import prepare_token_indices + +from utils import assert_tensors_similar + + +@tilelang.jit( + out_idx=[-2, -1], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def sparse_mla_fwd( + heads, + dim, + tail_dim, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + CP0=True, + block_I=32, + num_stages=2, + threads=128, +): + assert dim == tilelang.math.next_power_of_2( + dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2( + 
tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert (topk % + block_I == 0), "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim))**0.5 + else: + sm_scale = sm_scale + + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + + head_kv = heads // kv_group + q_shape = [seq_len, heads, dim + tail_dim] + kv_shape = [seq_len, kv_group, dim + tail_dim] + o_shape = [seq_len, heads, dim] + indices_shape = [seq_len, kv_group, topk] + lse_shape = [seq_len, heads] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + indices_dtype = "int32" + dtype = "bfloat16" + accum_dtype = "float" + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert ( + kv_group == 1 + ), "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + BI = block_I + NI = tilelang.cdiv(topk, block_I) + D = dim + D_tail = tail_dim + + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + ): + with T.Kernel( + seq_len * REPLICATE_H, kv_group, threads=threads) as ( + bx, + by, + ): + Q_shared = T.alloc_shared([H_per_block, D], 
dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared = T.alloc_shared([BI, D], dtype) + K_tail_shared = T.alloc_shared([BI, D_tail], dtype) + mask = T.alloc_fragment([BI], "bool") + + acc_o = T.alloc_fragment([H_per_block, D], accum_dtype) + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + S_shared = T.alloc_shared([H_per_block, BI], dtype) + sumexp = T.alloc_fragment([H_per_block], accum_dtype) + sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) + alpha = T.alloc_fragment([H_per_block], accum_dtype) + m_i = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) + + T.fill(acc_o, 0) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + + b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) + b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] + bos, eos = Offsets[b_i], Offsets[b_i + 1] + g_i = by + q_i = s_i + max_kv_i = q_i + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + T.copy(Q[bos + s_i, H0:H1, :D], Q_shared) + T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared) + + for i_i in T.Pipelined(NI, num_stages=num_stages): + + for bi_i in T.Parallel(BI): + mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & ( + Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) + + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, + d_i] + for bi_i, d_i in T.Parallel(BI, D_tail): + K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], + g_i, D + d_i] + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm( + Q_shared, + KV_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.gemm( + Q_tail_shared, + K_tail_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy(m_i, 
m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + alpha[h_i] = T.exp((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i] + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + # Rescale + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = T.log(sumexp[h_i]) + m_i[h_i] * sm_scale + + T.copy(acc_o, Output[bos + s_i, H0:H1, :]) + T.copy(sumexp, Lse[bos + s_i, H0:H1]) + + return main + + +def sparse_mla_fwd_interface(q, + kv, + indices, + offsets, + sm_scale=None, + return_p_sum: bool = False, + d_v=512, + block_I=32, + num_stages=2, + threads=128): + is_casual = True + assert return_p_sum == False, "This kernel file is for fwd only" + assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() + seq_len, heads, dim_plus_tail_dim = q.shape + seq_len_kv, kv_group, _ = kv.shape + assert seq_len == seq_len_kv + + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" + dim = d_v + + assert kv.shape[-1] == dim_plus_tail_dim + tail_dim = dim_plus_tail_dim - dim + _, _, topk = indices.shape + assert indices.shape == (seq_len, kv_group, topk) + + token_indices = prepare_token_indices(offsets) + + kernel = sparse_mla_fwd( + heads, + dim, + tail_dim, + topk, + kv_group, + sm_scale, + is_casual, + block_I=block_I, + num_stages=num_stages, + threads=threads) + out, lse = kernel(q, kv, indices, offsets, token_indices) + return out, lse + + +def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, 
is_casual=True): + Q = Q.float() + KV = KV.float() + all_o = [] + for i in range(offsets.shape[0] - 1): + q = Q[None, offsets[i]:offsets[i + 1]] + kv = KV[None, offsets[i]:offsets[i + 1]] + indices = Indices[None, offsets[i]:offsets[i + 1]].clone() + + indices = indices.transpose(1, 2) + b, sq, h, dim_q = q.shape + b, sk, g, _ = kv.shape + + assert kv.shape[-1] == 576, "you should assign dim otherwise" + dim = 512 + k = kv + v = kv[..., :dim] + + b, _, _, dim_v = v.shape + g_index = g + h_index = h // g + compressed_casual_mask = torch.arange( + 0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( + 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda").view(1, -1) + + indices[indices > sk] = sk + mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) + mask = mask[..., :-1] + mask = mask & compressed_casual_mask.view(1, 1, sq, sk) + mask[:, :, :1 - 1, 0] = True + mask = mask.view(b, g_index, 1, sq, sk) + + q = q.view(b, sq, g, -1, dim_q) + score = torch.einsum("bmghd,bngd->bghmn", q, k) + sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale + score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) + p = score.softmax(dim=-1) + p = p.view(b, g_index, h_index, -1, sq, sk) + p = p.view(b, g, -1, sq, sk) + o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v) + o = o.reshape(b, sq, h, dim_v) + all_o.append(o.squeeze(0)) + o = torch.cat(all_o, dim=0) + return o.to(torch.bfloat16) + + +def test_sparse_mla_fwd(B=1, + S=4096, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256): + torch.random.manual_seed(0) + q = torch.randn((S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((S, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) + offsets = torch.tensor([0, S // 2 - 1, S], dtype=torch.int32, device="cuda") + + indices = torch.full((S, HKV, topk), S, dtype=torch.int32, 
device="cuda") + for i in range(offsets.shape[0] - 1): + seq_len = (offsets[i + 1] - offsets[i]).item() + assert seq_len >= topk + for t in range(seq_len): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[offsets[i] + t, h, :len(i_i)] = i_i + + tl_out, tl_lse = sparse_mla_fwd_interface( + q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) + + if check_correctness: + # otherwise may cause out of memory + ref_out = ref_sparse_mla_fwd_interface(q, kv, indices, offsets) + assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out") + print("assert_tensors_similar passed") + + def fn(): + return sparse_mla_fwd_interface( + q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) + + from tilelang.profiler import do_bench + + ms = do_bench( + fn, + rep=100, + warmup=250, + ) + print(f"Average time: {ms:.3f} ms") + print("fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) + print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + + +if __name__ == "__main__": + test_sparse_mla_fwd( + B=1, + S=4096, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=1024, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256) diff --git a/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py new file mode 100644 index 00000000..94bdb8fb --- /dev/null +++ b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py @@ -0,0 +1,241 @@ +# ruff: noqa +import torch +import torch.nn as nn +import torch.nn.functional as F +import tilelang +from tilelang import language as T +from einops import repeat, rearrange, einsum +from index import prepare_token_indices +from utils import get_abs_err, get_err_ratio + +BF16 = "bfloat16" +FP32 = "float32" +INT32 = "int32" + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + 
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +} + + +@tilelang.jit(pass_configs=pass_configs) +def tl_sparse_mla_topk_reducesum_impl( + heads, + dim, + tail_dim, + topk, + kv_group=1, + sm_scale=None, + block_I=32, + num_stages=2, + threads=128, +): + assert dim == tilelang.math.next_power_of_2( + dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2( + tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert (topk % + block_I == 0), "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim))**0.5 + + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + seq_len_kv = T.symbolic("seq_len_kv") + + head_kv = heads // kv_group + indices_dtype = "int32" + dtype = "bfloat16" + accum_dtype = "float" + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert ( + kv_group == 1 + ), "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + BI = block_I + NI = tilelang.cdiv(topk, block_I) + D = dim + D_tail = tail_dim + + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + q_shape = [seq_len, heads, dim + tail_dim] + kv_shape = [seq_len_kv, kv_group, dim + tail_dim] + indices_shape = [seq_len, kv_group, topk] + lse_shape = [seq_len, heads] + reducesum_shape = [seq_len, kv_group, REPLICATE_H, topk] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + + @T.prim_func + def tl_sparse_mla_topk_reducesum_kernel( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: 
T.Tensor(indices_shape, indices_dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore + ReduceSum: T.Tensor(reducesum_shape, accum_dtype), # type: ignore + ): + with T.Kernel( + seq_len * REPLICATE_H, kv_group, threads=threads) as ( + bx, + by, + ): + Q_shared = T.alloc_shared([H_per_block, D], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared = T.alloc_shared([BI, D], dtype) + K_tail_shared = T.alloc_shared([BI, D_tail], dtype) + mask = T.alloc_fragment([BI], "bool") + + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + reducesum = T.alloc_fragment([BI], accum_dtype) + lse = T.alloc_fragment([H_per_block], accum_dtype) + + T.fill(lse, 0) + + b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) + b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] + bos, eos = Offsets[b_i], Offsets[b_i + 1] + r_i = bx % REPLICATE_H + g_i = by + q_i = s_i + max_kv_i = q_i + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + T.copy(Q[bos + s_i, H0:H1, :D], Q_shared) + T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared) + T.copy(Lse[bos + s_i, H0:H1], lse) + + for i_i in T.Pipelined(NI, num_stages=num_stages): + + for bi_i in T.Parallel(BI): + mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & ( + Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) + + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, + d_i] + for bi_i, d_i in T.Parallel(BI, D_tail): + K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], + g_i, D + d_i] + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm( + Q_shared, + KV_shared, + acc_s, + transpose_B=True, + 
policy=T.GemmWarpPolicy.FullRow, + ) + T.gemm( + Q_tail_shared, + K_tail_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - lse[h_i]) + T.reduce_sum(acc_s, reducesum, dim=0) + T.copy(reducesum, ReduceSum[bos + s_i, g_i, r_i, i_i * BI:i_i * BI + BI]) + + return tl_sparse_mla_topk_reducesum_kernel + + +def sparse_mla_topk_reducesum_interface( + q: torch.Tensor, + kv: torch.Tensor, + topk_indices: torch.Tensor, + lse: torch.Tensor, + offsets: torch.Tensor, + dim_v: int, +): + assert kv.shape[-2] == 1 + seq_len, heads, dim_plus_tail_dim, topk = *q.shape, topk_indices.shape[-1] + REPLICATE_H = max(heads // 64, 1) + tail_dim = dim_plus_tail_dim - dim_v + token_indices = prepare_token_indices(offsets) + + reducesum = torch.zeros([seq_len, 1, REPLICATE_H, topk], dtype=torch.float32, device=q.device) + kernel = tl_sparse_mla_topk_reducesum_impl(heads=heads, dim=dim_v, tail_dim=tail_dim, topk=topk) + kernel(q, kv, topk_indices, lse, offsets, token_indices, reducesum) + reducesum = reducesum.sum(dim=-2) # [batch, seq_len, 1, RH, topk] -> [batch, seq_len, 1, topk] + attn_score = reducesum / reducesum.sum(dim=-1, keepdim=True) + + return attn_score + + +def ref_mla_topk_softmax(Q: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, + offsets: torch.Tensor): + # q: [batch, seq_len, heads, dim] + # k: [batch, seq_len, dim] + sm_scale = Q.shape[-1]**-0.5 + all_lse = [] + all_topk_score = [] + for i in range(offsets.shape[0] - 1): + q = Q[offsets[i]:offsets[i + 1]] + k = K[offsets[i]:offsets[i + 1]] + topk_indices = TopkIndices[offsets[i]:offsets[i + 1]] + seq_len = q.shape[0] + mask = (torch.arange(seq_len)[:, None] + >= torch.arange(seq_len)[None, :]).unsqueeze(-2).cuda() + logits = einsum(q, k, 's1 h d, s2 d -> s1 h s2') * sm_scale + logits = torch.where(mask, logits, float('-inf')) + score = F.softmax(logits, dim=-1, dtype=torch.float32) + 
score_sum = score.sum(dim=-2) + topk_score = torch.gather(score_sum, dim=-1, index=topk_indices.to(torch.int64)) + topk_score = topk_score / topk_score.sum(dim=-1, keepdim=True) + max_logits = logits.amax(dim=-1).to(torch.float32) + lse = torch.log( + (logits - max_logits.unsqueeze(-1).to(torch.float32)).exp().sum(dim=-1)) + max_logits + all_lse.append(lse) + all_topk_score.append(topk_score) + lse = torch.cat(all_lse, dim=0) + topk_score = torch.cat(all_topk_score, dim=0) + return lse, topk_score + + +def test_kernel( + B=1, + S=2048, + H=16, + D=512, + tail_D=64, + topk=128, +): + torch.manual_seed(42) + + q = torch.randn((S, H, D + tail_D)).cuda().bfloat16() + kv = torch.randn((S, D + tail_D)).cuda().bfloat16() + offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda() + + topk_indices = repeat( + torch.arange(topk, dtype=torch.int32).cuda(), 'k -> s k', s=S).contiguous() + + lse, ref_attn_score = ref_mla_topk_softmax(q, kv, topk_indices, offsets) + + kv = kv.unsqueeze(-2) + topk_indices = topk_indices.unsqueeze(-2) + + attn_score = sparse_mla_topk_reducesum_interface( + q, kv, topk_indices, lse, offsets, dim_v=D).squeeze(-2) + print( + f"attn_score err: {get_abs_err(attn_score, ref_attn_score):.6f} ratio: {get_err_ratio(attn_score, ref_attn_score):.6f}" + ) + + +if __name__ == '__main__': + test_kernel() diff --git a/examples/dsa_sparse_finetune/utils.py b/examples/dsa_sparse_finetune/utils.py new file mode 100644 index 00000000..691af64d --- /dev/null +++ b/examples/dsa_sparse_finetune/utils.py @@ -0,0 +1,75 @@ +import torch + + +def get_abs_err(y, x): + x = x.to(torch.float32) + y = y.to(torch.float32) + return (x - y).flatten().abs().max().item() + + +def get_err_ratio(y, x): + x = x.to(torch.float32) + y = y.to(torch.float32) + err = (x - y).flatten().square().mean().sqrt().item() + base = (x).flatten().square().mean().sqrt().item() + return err / base + + +def calculate_tensor_similarity(x, y, name="tensor"): + """ + Calculate similarity between two 
tensors using a normalized dot product metric. + + Unlike torch.testing.assert_close which uses absolute/relative tolerance based on + element-wise differences, this function computes a global similarity score: + sim = 2 * / (||x||^2 + ||y||^2) + + This metric is scale-invariant and measures the cosine-like similarity normalized + by the magnitude of both tensors. It returns 1 for identical tensors and values + closer to 0 for dissimilar ones. This is particularly useful for comparing tensors + with varying magnitudes where relative errors matter more than absolute differences. + + Args: + x: First tensor to compare + y: Second tensor to compare + name: Name of the tensor for logging purposes + + Returns: + Similarity score in range [0, 1] where 1 means identical + """ + x, y = x.data.double(), y.data.double() + denominator = (x * x + y * y).sum() + if denominator == 0: + print(f"\033[33mWARNING: {name} all zero\033[0m") + return 1 + sim = 2 * (x * y).sum() / denominator + return sim + + +def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): + """ + Assert that two tensors are similar using a global similarity metric. + + Key differences from torch.testing.assert_close: + - torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking + that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers + and requires all elements to satisfy the tolerance. + - assert_tensors_similar: Uses a single global similarity score (1 - sim) where sim is the + normalized dot product. It's more robust to outliers and focuses on overall + tensor similarity rather than element-wise precision. This is better suited for + comparing large tensors where a few outlier elements shouldn't fail the test. 
+ + Args: + x: First tensor to compare + y: Second tensor to compare + eps: Maximum allowed difference (1 - similarity), default 1e-8 + name: Name of the tensor for error messages + raise_assert: Whether to raise assertion error on failure + """ + sim = calculate_tensor_similarity(x, y, name) + diff = 1. - sim + if not (0 <= diff <= eps): + print( + f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m" + ) + if raise_assert: + assert False # noqa: B011 -- GitLab From 1e92d11cd252e014c44a1c0dc94deaade14c7d2f Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Fri, 28 Nov 2025 03:28:14 +0800 Subject: [PATCH 056/139] [Refactor] Improve assertion handling in CodeGenCHost and ArgBinder (#1352) * [Refactor] Improve assertion handling in CodeGenCHost and ArgBinder This commit refines the assertion message generation in CodeGenCHost by optimizing the handling of equality checks and reducing buffer size for error messages. Additionally, it enhances the ArgBinder by introducing a nullable guard mechanism for assertions, allowing for more precise error handling when binding arguments. The changes improve the clarity and efficiency of assertion handling across the codebase. * [Enhancement] Update matmul kernel and optimize argument binding This commit enhances the matmul kernel by introducing additional tensor parameters and refining the pipeline stages for improved performance. It also updates the argument binding mechanism to include a flag indicating whether buffers are used, enhancing the efficiency of buffer management. Furthermore, the optimization phase in the engine is improved by adding a simplification step, ensuring better performance and clarity in the generated code. 
* lint fix * [Enhancement] Add tensor checks documentation and improve argument binding assertions This commit introduces a new documentation page for host-side tensor checks, detailing the automatic validations performed by TileLang on kernel arguments. It enhances the ArgBinder by adding assertions for non-null pointers when arguments are used, improving error handling. Additionally, the optimization phase in the engine is updated to include a simplification step, ensuring better performance and clarity in the generated code. * [Enhancement] Update .gitignore and refine matmul kernel for improved performance This commit adds host checks logs to the .gitignore file to prevent unnecessary log files from being tracked. Additionally, it refines the matmul kernel by adjusting pipeline stages, updating tensor parameters, and enhancing argument handling for better performance. The changes also include improved error messages in the argument binding process, ensuring clearer diagnostics for users. * lint fix * lint fix * [Refactor] Simplify tensor_null_test function and remove ptr_null_test This commit refactors the tensor_null_test function by adding a with_bias parameter and removing the ptr_null_test function, which was previously unused. The run_test function is updated to reflect these changes, streamlining the testing process for tensor operations. 
* lint fix * fix --- .gitignore | 3 + docs/compiler_internals/tensor_checks.md | 387 ++++++++++++++++++ docs/index.md | 1 + examples/quickstart.py | 2 +- maint/host_checks/01_num_args_mismatch.py | 21 + maint/host_checks/02_pointer_type_error.py | 22 + maint/host_checks/03_ndim_mismatch.py | 19 + maint/host_checks/04_dtype_mismatch.py | 19 + maint/host_checks/05_shape_mismatch.py | 19 + maint/host_checks/06_strides_mismatch.py | 19 + maint/host_checks/07_device_type_mismatch.py | 18 + maint/host_checks/08_device_id_mismatch.py | 25 ++ maint/host_checks/09_null_data_pointer.py | 25 ++ maint/host_checks/10_scalar_type_mismatch.py | 15 + maint/host_checks/README.md | 21 + maint/host_checks/common.py | 50 +++ maint/host_checks/run_all.py | 71 ++++ src/runtime/error_helpers.cc | 60 +++ src/target/codegen_c_host.cc | 81 +--- src/transform/arg_binder.cc | 205 ++++------ src/transform/arg_binder.h | 2 +- src/transform/make_packed_api.cc | 109 ++++- src/transform/merge_if_stmt.cc | 45 +- src/transform/merge_if_stmt.h | 52 +++ .../python/jit/test_tilelang_jit_nullptr.py | 74 +--- tilelang/engine/phase.py | 1 + tilelang/jit/adapter/tvm_ffi.py | 17 - 27 files changed, 1100 insertions(+), 283 deletions(-) create mode 100644 docs/compiler_internals/tensor_checks.md create mode 100644 maint/host_checks/01_num_args_mismatch.py create mode 100644 maint/host_checks/02_pointer_type_error.py create mode 100644 maint/host_checks/03_ndim_mismatch.py create mode 100644 maint/host_checks/04_dtype_mismatch.py create mode 100644 maint/host_checks/05_shape_mismatch.py create mode 100644 maint/host_checks/06_strides_mismatch.py create mode 100644 maint/host_checks/07_device_type_mismatch.py create mode 100644 maint/host_checks/08_device_id_mismatch.py create mode 100644 maint/host_checks/09_null_data_pointer.py create mode 100644 maint/host_checks/10_scalar_type_mismatch.py create mode 100644 maint/host_checks/README.md create mode 100644 maint/host_checks/common.py create mode 100644 
maint/host_checks/run_all.py create mode 100644 src/runtime/error_helpers.cc create mode 100644 src/transform/merge_if_stmt.h diff --git a/.gitignore b/.gitignore index 752f6cb7..730398df 100644 --- a/.gitignore +++ b/.gitignore @@ -108,3 +108,6 @@ cmake-build-*/ # pre-commit cache .pre-commit-cache/* + +# host checks logs +maint/host_checks/logs/* diff --git a/docs/compiler_internals/tensor_checks.md b/docs/compiler_internals/tensor_checks.md new file mode 100644 index 00000000..b4d2a0b3 --- /dev/null +++ b/docs/compiler_internals/tensor_checks.md @@ -0,0 +1,387 @@ +# Tensor Checks (Host-Side Auto-Validation) + +This page explains the host-side checks that TileLang automatically inserts into the generated host stub for kernels. When you pass `torch.Tensor` or any DLPack-compatible object to a TileLang kernel, the host stub validates argument count, pointer kinds, dtype, shape, strides, device, and more — so you don’t need to handwrite Python checks. This keeps the ABI stable and significantly reduces Python overhead compared to doing equivalent checks in Python or via pybind. + +## Why Host-Side Checks +- ABI stability: the entry is based on TVM FFI + DLPack, consistently accepting tensors and scalars. +- Lower overhead: shifting checks from Python into C reduces interpreter/property-access costs; the call overhead is lower than pybind-based approaches. +- Focused error reporting: assertions are raised close to the call site with precise “which field failed” messages. + +## How To Inspect Host Source +You can inspect the auto-generated host source (with all checks and the final device-kernel call) for debugging: + +```python +print(matmul_relu_kernel.get_host_source()) +``` + +--- + +## What The Host Checks + +### 1) Argument count and pointer kind +- `num_args` must match the number of formal parameters; otherwise the kernel returns `-1` with an error message. 
+- Each argument’s FFI type must be a pointer kind (for DLTensor/handle) or a valid scalar type; otherwise you’ll see errors like `Expect arg[i] to be pointer` or a scalar type error. + +### 2) Tensor checks (per tensor, after nullability decision) +- Nullability + - If the tensor is “statically reachable/used” by the function body, the handle must be non-NULL; otherwise: `xxx is expected to have non-NULL pointer`. + - If an input tensor is not used by the function (statically unreachable), NULL is allowed; other field checks are executed only when `handle != NULL`. +- Rank (`ndim`) + - Runtime `ndim` must equal the compile-time rank. +- Data type (`dtype`) + - Match the triple `(code, bits, lanes)` with tolerance: + - `float8_e4m3`: accept `e4m3`, `e4m3fn`, `e4m3fnuz`. + - `float8_e5m2`: accept `e5m2`, `e5m2fnuz`. + - `bool`: accept `int8/uint8` with `bits=8` (same lanes), `kDLBool(code=6, bits=1 or 8)`, and any `bitwidth=1` (lanes must match). + - For packed-bit dtypes (e.g., `Int(1)`, `Int(4)`, `UInt(4)`), strict dtype checking is skipped. +- Shape + - Each runtime dimension is bound to the compile-time shape (constants or symbols) and checked for consistency. + - Linear equations among symbolic dims can be solved on the fly (when there’s only one unknown at a given check point), enabling cross-tensor constraints. +- Strides + - If `buffer_type = AutoBroadcast`: allow `strides == NULL` and derive strides from `shape`. If explicit `strides` is present, bind to compile-time constraints and check for equality. + - Otherwise: check per-dimension; if `strides == NULL`, derive from `shape` and compare (e.g., contiguous: `strides[-1] == 1`, `strides[-2] == shape[-1]`). +- `byte_offset` + - Must be 0 (non-zero raises an error) to keep addressing simple and aligned. +- Device info + - Assert `device_type == target backend` (CUDA/ROCM/Metal/OneAPI/WebGPU/CPU, etc.). Error messages include a DLPack code legend. 
+ - When multiple tensors participate, assert that `device_id` matches across them. +- Data pointer + - Must be non-NULL when the tensor is required to be non-null by the nullability rule. + +### 3) Scalar checks +- `T.int*` family: require integer; error: `Expect arg[i] to be int`. +- `T.bool`: require boolean; error: `Expect arg[i] to be boolean`. + +--- + +## Shapes and Symbolic Equations: Linear Solving +When shapes are symbolic, the host binds and (when possible) solves linear relations at runtime (only one unknown per check point). Example: + +```python +@T.prim_func +def main( + A: T.Tensor((m,), dtype), + B: T.Tensor((m + n,), dtype), + C: T.Tensor((n * k,), dtype), +): + ... +``` + +This enables enforcing cross-tensor relationships like `len(B) == m + n` and `len(C) == n * k` at runtime. + +--- + +## Nullability Rules and Examples +Which tensors may be NULL? + +- Rule: If an input tensor is not used by the function under static analysis (i.e., the access is statically unreachable), it is considered Nullable; otherwise it must be non-NULL. +- Examples: + +1) Must be non-NULL (used) +```python +@T.prim_func +def main(A: T.Tensor((M, K), dtype)): + A[0] = 1 +``` +Passing `None` raises: `main.A_handle is expected to have non-NULL pointer`. + +2) Still must be non-NULL (constant-true branch) +```python +some_cond: bool = True +@T.prim_func +def main(A: T.Tensor((M, K), dtype)): + if some_cond: + A[0] = 1 +``` + +3) Nullable (constant-false branch, statically unreachable) +```python +some_cond: bool = False +@T.prim_func +def main(A: T.Tensor((M, K), dtype)): + if some_cond: + A[0] = 1 +``` + +4) Must be non-NULL (runtime condition) +```python +@T.prim_func +def main(A: T.Tensor((M, K), dtype), some_cond: T.bool): + if some_cond: + A[0] = 1 +``` +Since `some_cond` is only known at runtime, static analysis cannot prove `A` is unused; `A` is thus non-nullable. 
+ +--- + +## Device Type Codes (DLPack) +Supported and referenced device codes in error messages: `1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, 14=OneAPI, 15=WebGPU`. +Kernels assert that `device_type` matches the target backend, and require `device_id` consistency across tensors. + +--- + +## Common Error Examples (What you’ll see) +- Argument count mismatch (num_args) + - Trigger: missing/extra argument + - Error: `: num_args should be N; expected: , got: N` + +- Pointer-typed argument expected + - Trigger: scalar passed where a tensor is expected + - Error: `: Expect arg[i] to be pointer` + +- Rank (ndim) mismatch + - Trigger: runtime rank differs from compile-time rank + - Error: `..ndim is expected to equal R, but got mismatched ndim` + +- Dtype mismatch + - Trigger: dtype not equal to the compiled dtype and not within the tolerance set + - Error: `..dtype is expected to be , but got incompatible dtype` + +- Shape constraint violation + - Trigger: a dimension doesn’t match a constant/symbol binding + - Error: `Argument ..shape[i] has an unsatisfied constraint: ... == ` + +- Strides check failed (e.g., non-contiguous layout) + - Trigger: transposed/sliced tensors that violate expected strides + - Error: `Argument ..strides[j] has an unsatisfied constraint: ... == ` + +- Device type mismatch + - Trigger: calling a CUDA kernel with CPU tensors, etc. + - Error: `..device_type mismatch [expected: ()] ...` + +- Device id mismatch + - Trigger: mixing tensors from different GPUs + - Error: `Argument ..device_id has an unsatisfied constraint: ... == ...` + +- NULL data pointer + - Trigger: tensor required to be non-null has a NULL data pointer + - Error: `. 
is expected to have non-NULL data pointer, but got NULL`
+
+- Scalar type mismatch
+  - Trigger: passing float to `T.int32`, or non-boolean to `T.bool`
+  - Error: `: Expect arg[i] to be int/boolean`
+
+---
+
+## Troubleshooting Tips
+- Print the host source: `print(fn.get_host_source())` to see the exact assertion and expected vs. actual fields.
+- Fix strides: call `.contiguous()` for non-contiguous tensors, or avoid generating transposed/sliced layouts that break assumptions.
+- Align devices: ensure all participating tensors share the same `device_type` and `device_id`.
+- Align dtype: use `.to()` or construct tensors with the correct dtype; pay attention to `float8` and `bool` tolerance.
+- Dynamic shapes: ensure cross-tensor linear relations can be uniquely determined at the check point (only one unknown at a time).
+
+---
+
+## FAQ
+- Can I disable the checks?
+  - Not recommended and usually not supported. Checks are done on the host to preserve ABI stability and fail early close to the device call.
+- Is the overhead noticeable?
+  - No. The checks are lightweight (branches and field reads); the dominant cost is the Python→C boundary crossing itself, so these host-side checks are cheaper than performing the equivalent validation in Python.
+ +--- + +## Reference Example (Matmul + ReLU) + +```python +@T.prim_func +def matmul_relu_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), +): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + T.copy(C_local, C[by * block_M, bx * block_N]) + +# For debugging, print the host source +print(matmul_relu_kernel.get_host_source()) +``` + +The host will insert all checks described above for this example. + +--- + +## Quick Error Reference (Short List) +- Argument count + - Trigger: missing/extra args; Error: `num_args should be N; expected: , got: N`. +- Pointer kind + - Trigger: scalar passed to tensor arg; Error: `Expect arg[i] to be pointer`. +- Rank (ndim) + - Trigger: runtime rank != compile-time; Error: `ndim ... expected to equal R`. +- Dtype + - Trigger: mismatch and not tolerated; Error: `dtype ... expected to be `. +- Shape + - Trigger: constant/symbol binding violated; Error: `shape[i] ... == `. +- Strides + - Trigger: layout mismatch; Error: `strides[j] ... == `. +- Device type + - Trigger: wrong backend device; Error: `device_type mismatch [expected: ...]`. +- Device id + - Trigger: tensors on different GPUs; Error: `device_id ... == ...`. +- Data pointer + - Trigger: required non-NULL but NULL; Error: `non-NULL data pointer`. +- Scalar types + - Trigger: wrong scalar type; Error: `Expect arg[i] to be int/boolean`. 
+ +--- + +## Host Error Troubleshooting (Minimal Repros) + +Below are minimal repro snippets for common host-side errors, assuming a CUDA-targeted kernel like `matmul_relu_kernel` with: + +```python +# Convention: +# A: float16 [M, K] +# B: float16 [K, N] +# C: float16 [M, N] +# Target: CUDA (device_type=2) +fn = matmul_relu_kernel # your compiled function +M = N = K = 1024 +``` + +Adjust dtype/device if your kernel differs. + +### 0. Tip: print the host source +```python +print(fn.get_host_source()) +``` + +### 1. num_args mismatch +```python +import torch + +A = torch.empty((M, K), device='cuda', dtype=torch.float16) +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +# Missing C +fn(A, B) +``` +Expected: `: num_args should be 3; expected: , got: 3`. + +Fix: pass all arguments per the signature. + +### 2. Expect pointer (tensor) but got scalar +```python +import torch + +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(1, B, C) +``` +Expected: `: Expect arg[0] to be pointer`. + +Fix: pass a DLPack-compatible tensor (e.g., torch.Tensor). + +### 3. ndim mismatch +```python +import torch + +A = torch.empty((M, K, 1), device='cuda', dtype=torch.float16) # rank=3 +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `.A_handle.ndim is expected to equal 2, but got mismatched ndim`. + +Fix: ensure runtime rank equals compiled rank. + +### 4. dtype mismatch +```python +import torch + +A = torch.empty((M, K), device='cuda', dtype=torch.float32) # should be float16 +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `.A_handle.dtype is expected to be float16, but got incompatible dtype`. + +Fix: `A = A.to(torch.float16)` or create with the correct dtype. + +### 5. 
Shape constant/symbol mismatch +```python +import torch + +A = torch.empty((M, K + 1), device='cuda', dtype=torch.float16) # K mismatched +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `Argument .A_handle.shape[i] has an unsatisfied constraint: ... == `. + +Fix: satisfy linear constraints and constants across tensors. + +### 6. Strides check failure (non-contiguous) +```python +import torch + +A = torch.empty((M, K), device='cuda', dtype=torch.float16) +A_nc = A.t() # transpose -> non-contiguous +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A_nc, B, C) +``` +Expected: `Argument .A_handle.strides[1] has an unsatisfied constraint: ... == 1`. + +Fix: pass `A_nc.contiguous()` or align the layout expectation in the kernel. + +### 7. device_type mismatch +```python +import torch + +A = torch.empty((M, K), device='cpu', dtype=torch.float16) +B = torch.empty((K, N), device='cpu', dtype=torch.float16) +C = torch.empty((M, N), device='cpu', dtype=torch.float16) +fn(A, B, C) # CUDA-targeted kernel +``` +Expected: `.A_handle.device_type mismatch [expected: 2 (cuda)] ...`. + +Fix: move tensors to the CUDA device. + +### 8. device_id mismatch (multi-GPU) +```python +import torch + +A = torch.empty((M, K), device='cuda:0', dtype=torch.float16) +B = torch.empty((K, N), device='cuda:1', dtype=torch.float16) +C = torch.empty((M, N), device='cuda:0', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `Argument .B_handle.device_id has an unsatisfied constraint: ... == ...`. + +Fix: place all tensors on the same GPU (e.g., `cuda:0`). + +### 9. NULL data pointer (advanced) +This usually comes from hand-constructed DLTensor/NDArray, or external frameworks passing unallocated/freed storage. Regular `torch.Tensor` allocations rarely hit this. + +Expected: `. 
is expected to have non-NULL data pointer, but got NULL`. + +Fix: ensure valid underlying storage; in PyTorch scenarios, avoid constructing tensors from invalid external handles. + +### 10. Scalar type mismatch (int / bool) +```python +import tilelang.language as T + +@T.prim_func +def scalar_check(x: T.int32, flag: T.bool()): + T.evaluate(0) + +scalar_check(1.0, True) # x is float -> Expect arg[0] to be int +scalar_check(1, 2.5) # flag is float -> Expect arg[1] to be boolean +``` + +Fix: pass correct scalar types, e.g., `scalar_check(1, True)`. + +--- + +## Closing Notes +- Cross-check “shape / strides / device / dtype” against the kernel signature to localize issues efficiently. +- For complex symbolic relations, print the host source to confirm binding/solving order, then adjust runtime shapes/layouts accordingly. + diff --git a/docs/index.md b/docs/index.md index 5d9a158f..9f794776 100644 --- a/docs/index.md +++ b/docs/index.md @@ -42,6 +42,7 @@ deeplearning_operators/deepseek_mla compiler_internals/letstmt_inline compiler_internals/inject_fence_proxy +compiler_internals/tensor_checks ::: :::{toctree} diff --git a/examples/quickstart.py b/examples/quickstart.py index 46a39e0d..39ad348b 100644 --- a/examples/quickstart.py +++ b/examples/quickstart.py @@ -77,7 +77,7 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) print("Kernel output matches PyTorch reference.") # 4. Retrieve and inspect the generated CUDA source (optional) -# cuda_source = jit_kernel.get_kernel_source() +# cuda_source = matmul_relu_kernel.get_kernel_source() # print("Generated CUDA kernel:\n", cuda_source) # 5.Profile latency with kernel diff --git a/maint/host_checks/01_num_args_mismatch.py b/maint/host_checks/01_num_args_mismatch.py new file mode 100644 index 00000000..8ba36646 --- /dev/null +++ b/maint/host_checks/01_num_args_mismatch.py @@ -0,0 +1,21 @@ +"""Reproduce: Argument count mismatch. 
+ +Note: The adapter-level wrapper expects only inputs (A, B) because C is marked as output. +Calling with the wrong number of inputs raises a ValueError before host entry. +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 256 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cuda", dtype=torch.float16) + # Missing b + # Expected: ValueError with message about expected vs. actual inputs + fn(a) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/02_pointer_type_error.py b/maint/host_checks/02_pointer_type_error.py new file mode 100644 index 00000000..fd358540 --- /dev/null +++ b/maint/host_checks/02_pointer_type_error.py @@ -0,0 +1,22 @@ +"""Reproduce: Pointer-type argument expected but scalar provided. + +We pass an integer for A; wrapper forwards it to the host where a pointer is expected. +Expected: error like "Expect buffer A_handle to be pointer or tensor" (exact name depends on kernel param). +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 256 + fn = build_matmul_kernel(M, N, K, target="cuda") + + # Wrong type for A (int instead of tensor) + a = 1 + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/03_ndim_mismatch.py b/maint/host_checks/03_ndim_mismatch.py new file mode 100644 index 00000000..994ce23e --- /dev/null +++ b/maint/host_checks/03_ndim_mismatch.py @@ -0,0 +1,19 @@ +"""Reproduce: ndim (rank) mismatch for A. 
+""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + + # A has rank 3 instead of 2 + a = torch.empty((M, K, 1), device="cuda", dtype=torch.float16) + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/04_dtype_mismatch.py b/maint/host_checks/04_dtype_mismatch.py new file mode 100644 index 00000000..6e6a0503 --- /dev/null +++ b/maint/host_checks/04_dtype_mismatch.py @@ -0,0 +1,19 @@ +"""Reproduce: dtype mismatch for A (float32 vs expected float16). +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + print(fn.get_host_source()) + + a = torch.empty((M, K), device="cuda", dtype=torch.float32) # should be float16 + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/05_shape_mismatch.py b/maint/host_checks/05_shape_mismatch.py new file mode 100644 index 00000000..8b41ae36 --- /dev/null +++ b/maint/host_checks/05_shape_mismatch.py @@ -0,0 +1,19 @@ +"""Reproduce: shape constant/symbol mismatch on A. +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + + # A's second dimension is wrong (K+1 instead of K) + a = torch.empty((M, K + 1), device="cuda", dtype=torch.float16) + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/06_strides_mismatch.py b/maint/host_checks/06_strides_mismatch.py new file mode 100644 index 00000000..477d200b --- /dev/null +++ b/maint/host_checks/06_strides_mismatch.py @@ -0,0 +1,19 @@ +"""Reproduce: strides check failure (non-contiguous A via transpose). 
+""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cuda", dtype=torch.float16) + a_nc = a.t() # non-contiguous after transpose + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a_nc, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/07_device_type_mismatch.py b/maint/host_checks/07_device_type_mismatch.py new file mode 100644 index 00000000..67cb7718 --- /dev/null +++ b/maint/host_checks/07_device_type_mismatch.py @@ -0,0 +1,18 @@ +"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel. +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 64 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cpu", dtype=torch.float16) + b = torch.empty((K, N), device="cpu", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/08_device_id_mismatch.py b/maint/host_checks/08_device_id_mismatch.py new file mode 100644 index 00000000..64910966 --- /dev/null +++ b/maint/host_checks/08_device_id_mismatch.py @@ -0,0 +1,25 @@ +"""Reproduce: device_id mismatch (requires >=2 CUDA devices). 
+""" +import torch +from common import build_matmul_kernel + + +def main(): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + if torch.cuda.device_count() < 2: + print("[SKIP] Need at least 2 CUDA devices to reproduce device_id mismatch.") + return + + M = N = K = 64 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cuda:0", dtype=torch.float16) + b = torch.empty((K, N), device="cuda:1", dtype=torch.float16) + # Output device is derived by the adapter; mismatch occurs in host checks + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/09_null_data_pointer.py b/maint/host_checks/09_null_data_pointer.py new file mode 100644 index 00000000..00bac67d --- /dev/null +++ b/maint/host_checks/09_null_data_pointer.py @@ -0,0 +1,25 @@ +"""Reproduce: NULL data pointer (advanced). + +Passing None for a tensor argument will be forwarded through the adapter. Depending on +FFI handling, this commonly triggers a pointer-type assertion (e.g., "Expect buffer to be pointer or tensor") +or a host-side non-NULL pointer check. + +Note: Constructing a true DLTensor with NULL data in PyTorch is not typical; this script +demonstrates passing None, which still reproduces the intended class of failure. +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 64 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = None # attempt to pass a null-like pointer + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/10_scalar_type_mismatch.py b/maint/host_checks/10_scalar_type_mismatch.py new file mode 100644 index 00000000..f1fcba27 --- /dev/null +++ b/maint/host_checks/10_scalar_type_mismatch.py @@ -0,0 +1,15 @@ +"""Reproduce: scalar parameter type mismatch (int/bool). 
+""" +from common import build_scalar_check_kernel + + +def main(): + fn = build_scalar_check_kernel(target="cuda") + + # Wrong types + fn(1.0, True) # x should be int -> Expect arg[0] to be int + fn(1, 2.5) # flag should be bool -> Expect arg[1] to be boolean + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/README.md b/maint/host_checks/README.md new file mode 100644 index 00000000..ac23d6fd --- /dev/null +++ b/maint/host_checks/README.md @@ -0,0 +1,21 @@ +# Host-Side Check Repro Scripts + +This folder contains standalone scripts that deliberately trigger host-side (and adapter-side) validation errors described in `docs/compiler_internals/tensor_checks.md`. Each script can be run directly and will reproduce the corresponding error with a minimal example. + +Prerequisites +- CUDA-capable environment (most scripts compile a CUDA-targeted kernel) +- Python packages: torch, tilelang + +Usage +- Run any script, e.g.: + - `python 01_num_args_mismatch.py` + - `python 02_pointer_type_error.py` + - ... up to `10_scalar_type_mismatch.py` + +- Or run all at once with a summary: + - `python run_all.py` + - Logs per test are saved under `logs/` as `