Commit de1ba1e4 authored by Lei Wang, committed by GitHub

[Refactor] Replace `T.thread_binding` with `T.get_thread_binding` in examples and test cases (#163)

* [Refactor] Update BitBLAS Benchmark with TileLang Carver Imports and Roller Hints Generation

- Replace BitBLAS imports with TileLang Carver imports in benchmark_matmul.py
- Modify roller hints generation using new TileLang Carver template and utility functions
- Update get_roller_hints_from_func to handle None cases and improve return logic
- Adjust DefaultPolicy to handle different codegen dictionary formats

* [Refactor] Update Thread Binding and Import Statements in TileLang Kernels

- Replace T.thread_binding() with T.get_thread_binding() across multiple kernel test files (see the condensed before/after sketch below)
- Update import statements for MMA layout and macro generator in dequantize GEMM and FP8 examples
- Move map_torch_type utility function to tilelang.utils.tensor
- Remove unnecessary imports and improve code organization
parent 901deae1
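The pattern of the thread-binding change is condensed below. This is an illustrative sketch, not code from the commit: the kernel name `copy_row` and the shapes/launch sizes are hypothetical, while the two call forms are taken verbatim from the hunks that follow (the no-argument form and index 0 replace `threadIdx.x` bindings; index 1 replaces `threadIdx.y` bindings).

```python
import tilelang.language as T


def copy_row(M=1024, N=1024, threads=128, dtype="float16"):

    @T.prim_func
    def main(A: T.Buffer((M, N), dtype), B: T.Buffer((M, N), dtype)):
        # One block per row, `threads` threads per block (illustrative sizes).
        with T.Kernel(M, threads=threads) as bx:
            # Old API (removed in this commit):
            #   tx = T.thread_binding(0, threads, thread="threadIdx.x")
            # New API: query the binding of the already-launched thread axis.
            tx = T.get_thread_binding()
            for ko in T.serial(T.ceildiv(N, threads)):
                B[bx, ko * threads + tx] = A[bx, ko * threads + tx]

    return main
```

For kernels launched over two thread axes (the block-reduce GEMM and GEMV hunks below), the axis is selected by index: `T.get_thread_binding(0)` for `threadIdx.x` and `T.get_thread_binding(1)` for `threadIdx.y`.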
@@ -39,8 +39,6 @@ def matmul(
     local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
     local_size_compressed = local_size // num_elems_per_byte
-    import tvm.tl.language as T
     @T.prim_func
     def main(
             A: T.Buffer(A_shape, in_dtype),
@@ -55,7 +53,7 @@ def matmul(
             B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
-            tx = T.thread_binding(0, threads, thread="threadIdx.x")
+            tx = T.get_thread_binding()
             T.clear(C_local)
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
@@ -149,8 +147,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
     accum_dtype,
     transform_b,
 ):
-    from bitblas.tl.utils import make_mma_swizzle_layout as make_swizzle_layout
-    from bitblas.tl.mma_macro_generator import (
+    from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout as make_swizzle_layout
+    from tilelang.intrinsics.mma_macro_generator import (
         TensorCoreIntrinEmitterWithLadderTransform,)
     from bitblas.gpu.intrin.lop3 import decode_i4_to_f16
@@ -257,8 +255,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
             B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
             reduced_accum_res = T.alloc_local(0, accum_dtype)
-            thread_binding = T.thread_binding(0, threads, "threadIdx.x")
-            rk = T.thread_binding(0, reduce_k, "threadIdx.y")
+            thread_binding = T.get_thread_binding(0)
+            rk = T.get_thread_binding(1)
             T.annotate_layout({
                 A_shared: make_swizzle_layout(A_shared),
......
@@ -12,6 +12,7 @@ from tilelang.intrinsics import get_swizzle_layout
 from tilelang.intrinsics.mma_macro_generator import (
     TensorCoreIntrinEmitter,)
 from tilelang.transform import simplify_prim_func
+from tilelang.utils.tensor import map_torch_type
 tilelang.testing.set_random_seed(0)
@@ -186,16 +187,6 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     # src_code is the generated cuda source
     assert src_code is not None
-    def map_torch_type(intype):
-        typemap = {
-            'e4m3_float8': torch.float8_e4m3fn,
-            'e5m2_float8': torch.float8_e5m2,
-        }
-        if intype in typemap:
-            return typemap[intype]
-        else:
-            return getattr(torch, intype)
     in_dtype = map_torch_type(in_dtype)
     out_dtype = map_torch_type(out_dtype)
     accum_dtype = map_torch_type(accum_dtype)
......
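For reference, the relocated `map_torch_type` helper keeps the behavior of the removed local copy shown above: the two FP8 string names map to the corresponding `torch.float8_*` dtypes, and any other name falls back to `getattr(torch, name)`. A minimal usage sketch (assumes a PyTorch build that provides the FP8 dtypes):

```python
import torch

from tilelang.utils.tensor import map_torch_type

# FP8 names are special-cased; other names resolve via getattr(torch, name).
assert map_torch_type("e4m3_float8") is torch.float8_e4m3fn
assert map_torch_type("e5m2_float8") is torch.float8_e5m2
assert map_torch_type("float16") is torch.float16
```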
@@ -256,7 +256,7 @@ def matmul(
             B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
-            tx = T.thread_binding(0, threads, thread="threadIdx.x")
+            tx = T.get_thread_binding()
             T.clear(C_local)
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
@@ -458,8 +458,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
             B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
             reduced_accum_res = T.alloc_local(0, accum_dtype)
-            thread_binding = T.thread_binding(0, threads, "threadIdx.x")
-            rk = T.thread_binding(0, reduce_k, "threadIdx.y")
+            thread_binding = T.get_thread_binding(0)
+            rk = T.get_thread_binding(1)
             T.annotate_layout({
                 A_shared: make_swizzle_layout(A_shared),
......
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import torch
 import torch.backends
 import tilelang.testing
@@ -67,8 +65,8 @@ def gemv_simt(
             accum_res = T.alloc_local((1,), accum_dtype)
             reduced_accum_res = T.alloc_local((1,), accum_dtype)
-            kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x")
-            ni = T.thread_binding(0, n_partition, thread="threadIdx.y")
+            kr = T.get_thread_binding(0)
+            ni = T.get_thread_binding(1)
             T.clear(accum_res)
             for ko in T.serial(T.ceildiv(K, block_K)):
......
@@ -93,10 +93,10 @@ def tl_matmul_simt(
             B_local = T.alloc_local((local_size_b, micro_size_k), in_dtype)
             C_local = T.alloc_local((local_size_c,), accum_dtype)
-            thread_binding = T.thread_binding(threads, "threadIdx.x")
-            warp_m = thread_binding % thread_row_tiles
-            warp_n = thread_binding // thread_row_tiles
+            tid = T.get_thread_binding()
+            warp_m = tid % thread_row_tiles
+            warp_n = tid // thread_row_tiles
             T.clear(C_local)
......
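In the `tl_matmul_simt` hunk above, the renamed `tid` is decomposed into 2D warp coordinates exactly as before; a plain-Python sketch of that indexing (the `thread_row_tiles` value here is a placeholder, not taken from the kernel config):

```python
# Flat thread id -> (warp_m, warp_n); warp_m varies fastest along threadIdx.x.
thread_row_tiles = 16  # placeholder tile count for illustration
for tid in (0, 1, 15, 16, 37):
    warp_m = tid % thread_row_tiles
    warp_n = tid // thread_row_tiles
    print(tid, "->", (warp_m, warp_n))  # e.g. 37 -> (5, 2)
```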
@@ -67,8 +67,8 @@ def gemv_simt(
             accum_res = T.alloc_local((1,), accum_dtype)
             reduced_accum_res = T.alloc_local((1,), accum_dtype)
-            kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x")
-            ni = T.thread_binding(0, n_partition, thread="threadIdx.y")
+            kr = T.get_thread_binding(0)
+            ni = T.get_thread_binding(1)
             T.clear(accum_res)
             for ko in T.serial(T.ceildiv(K, block_K)):
......
@@ -5,7 +5,6 @@ import torch
 import torch.nn.functional as F
 import tilelang
 from tilelang.profiler import cached
-from tilelang.autotuner import *
 import tilelang.language as T
 import tilelang.testing
@@ -303,7 +302,6 @@ def assert_mha_equal(batch, h, n_ctx, d_head, causal):
     dQ_ref, Q.grad = Q.grad.clone(), None
     dK_ref, K.grad = K.grad.clone(), None
     dV_ref, V.grad = V.grad.clone(), None
     torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
......