Commit 362b3520 authored by Lei Wang, committed by GitHub

[Refactor] Simplify interface by replacing the thread-binding argument of intrinsics with `KernelFrame.Current` (#34)

* installation script fix

* readme typo fix

* doc fix for dequantize gemm

* [Doc] remove CODE_OF_CONDUCT.md and SECURITY.md; update references in CONTRIBUTING.md

* [Doc] add unit tests for AnnotateDeviceRegions transform; remove SUPPORT.md

* update license

* [Enhancement] add tensor supply handling for unsigned integers; improve error message for execution backend assertion

* [Refactor] improve code readability by reformatting function signatures and assertions

* [Refactor] replace torch.manual_seed with tilelang.testing.set_random_seed for consistency in random seed handling

* [Refactor] unify thread binding variable naming across kernel and example files

* [Refactor] remove unused thread binding parameter from matrix multiplication functions

* [Refactor] remove unused thread binding parameter from matrix multiplication functions

* [Refactor] enable main testing function in tilelang kernel gemm test

* bug fix
parent 1b63d3a2
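The hunks below all follow one pattern, so it helps to see the interface change in isolation: call sites stop constructing and passing a `thread_bindings` variable, and each emitter intrinsic instead resolves the binding from the enclosing kernel launch frame. A minimal before/after sketch assembled from fragments of this diff (the buffer names and surrounding kernel context are assumed; this is not a standalone program):

```python
import tilelang.language as T

# Before: the kernel body creates the binding and threads it through every call.
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
mma_emitter.ldmatrix_a(A_local, A_shared, ki, thread_bindings=thread_bindings)
mma_emitter.ldmatrix_b(B_local, B_shared, ki, thread_bindings=thread_bindings)
mma_emitter.stmatrix(C_local, C_shared, thread_bindings=thread_bindings)

# After: call sites pass only the buffers and the k-iteration index; each
# emitter method looks the binding up itself.
mma_emitter.ldmatrix_a(A_local, A_shared, ki)
mma_emitter.ldmatrix_b(B_local, B_shared, ki)
mma_emitter.stmatrix(C_local, C_shared)

# Inside the emitter methods (see the intrin-emitter hunks below):
current_frame = T.KernelLaunchFrame.Current()
thread_binding = current_frame.get_thread_binding()  # threadIdx.x by default
```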
@@ -257,7 +257,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
 B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype)
 C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
 reduced_accum_res = T.alloc_local(0, accum_dtype)
-thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+thread_binding = T.thread_binding(0, threads, "threadIdx.x")
 rk = T.thread_binding(0, reduce_k, "threadIdx.y")
 T.annotate_layout({
@@ -279,7 +279,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
 for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte //
                   (threads * vec_load_qb)):
     for v in T.vectorized(0, vec_load_qb):
-        t = thread_bindings
+        t = thread_binding
         idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v
         vkk = idx % (micro_size_k // num_elems_per_byte)
         vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
@@ -299,7 +299,6 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
     A_local,
     A_shared,
     ki,
-    thread_bindings=thread_bindings,
     rk=rk,
 )
@@ -308,7 +307,6 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
     B_local,
     B_shared,
     ki,
-    thread_bindings=thread_bindings,
     rk=rk,
 )
@@ -343,7 +341,6 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
 mma_emitter.stmatrix(
     C_local,
     C_shared,
-    thread_bindings=thread_bindings,
 )
 for i, j in T.Parallel(block_M, (block_N // reduce_k)):
......
@@ -339,7 +339,7 @@ def tl_matmul(
 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
 C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
-thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+thread_binding = T.thread_binding(0, threads, "threadIdx.x")
 T.annotate_layout({
     A_shared: make_swizzle_layout(A_shared),
@@ -367,16 +367,14 @@ def tl_matmul(
 mma_emitter.ldmatrix_a(
     A_local,
     A_shared,
-    ki,
-    thread_bindings=thread_bindings,
+    ki
 )
 # Load B into fragment
 mma_emitter.ldmatrix_b(
     B_local,
     B_shared,
-    ki,
-    thread_bindings=thread_bindings,
+    ki
 )
 # Perform Matrix Multiplication
@@ -386,7 +384,6 @@ def tl_matmul(
 mma_emitter.stmatrix(
     C_local,
     C_shared,
-    thread_bindings=thread_bindings,
 )
 # Store shared into global
@@ -416,10 +413,10 @@ def tl_matmul(
 ```python
 for ki in T.serial(0, (block_K // micro_size_k)):
     # Warp-synchronous load for A
-    mma_emitter.ldmatrix_a(A_local, A_shared, ki, thread_bindings=thread_bindings)
+    mma_emitter.ldmatrix_a(A_local, A_shared, ki)
     # Warp-synchronous load for B
-    mma_emitter.ldmatrix_b(B_local, B_shared, ki, thread_bindings=thread_bindings)
+    mma_emitter.ldmatrix_b(B_local, B_shared, ki)
 ```
 Internally, these calls orchestrate how each thread in the warp issues the correct load instructions, performs address calculations, and stores the data into registers.
@@ -437,7 +434,7 @@ def tl_matmul(
 5. **Store Results via `stmatrix`**
    Finally, you write the results from the warp-level fragments back to shared memory or global memory. This step might happen multiple times in a loop or just once at the end. The code snippet:
 ```python
-mma_emitter.stmatrix(C_local, C_shared, thread_bindings=thread_bindings)
+mma_emitter.stmatrix(C_local, C_shared)
 ```
    orchestrates the warp-synchronous stores, ensuring each thread places the correct fragment element into the correct location of the shared or global buffer.
......
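The tutorial text above calls these loads and stores warp-synchronous; concretely, the emitter splits the flat `threadIdx.x` binding into a lane index plus warp coordinates. A runnable plain-Python sketch of that decomposition, consistent with the `extract_thread_binding` fragment visible in the MatrixCoreIntrinEmitter hunk further down (the warp size and warp-grid shape here are illustrative assumptions, and the exact layout order is inferred):

```python
WARP_SIZE = 32          # 64 on AMD matrix cores; 32 assumed for this sketch
block_row_warps = 2     # warps along M within a block (illustrative)
block_col_warps = 2     # warps along N within a block (illustrative)

def extract_thread_binding(tx: int):
    """Split a flat thread index into (lane_id, warp_n, warp_m)."""
    lane_id = tx % WARP_SIZE
    warp_m = (tx // WARP_SIZE) % block_row_warps
    warp_n = (tx // (WARP_SIZE * block_row_warps)) % block_col_warps
    return lane_id, warp_n, warp_m

# Thread 70 with 32-wide warps is lane 6 of warp 2 -> warp_m=0, warp_n=1.
print(extract_thread_binding(70))  # (6, 1, 0)
```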
@@ -116,8 +116,6 @@ def tl_matmul(
 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
 C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
-thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
 T.annotate_layout({
     A_shared: make_swizzle_layout(A_shared),
     B_shared: make_swizzle_layout(B_shared),
@@ -141,30 +139,16 @@ def tl_matmul(
 for ki in T.serial(0, (block_K // micro_size_k)):
     # Load A into fragment
-    mma_emitter.ldmatrix_a(
-        A_local,
-        A_shared,
-        ki,
-        thread_bindings=thread_bindings,
-    )
+    mma_emitter.ldmatrix_a(A_local, A_shared, ki)
     # Load B into fragment
-    mma_emitter.ldmatrix_b(
-        B_local,
-        B_shared,
-        ki,
-        thread_bindings=thread_bindings,
-    )
+    mma_emitter.ldmatrix_b(B_local, B_shared, ki)
     # Perform Matrix Multiplication
     mma_emitter.mma(A_local, B_local, C_local)
 # Perform STMatrix
-mma_emitter.stmatrix(
-    C_local,
-    C_shared,
-    thread_bindings=thread_bindings,
-)
+mma_emitter.stmatrix(C_local, C_shared)
 # Store shared into global
 for i, j in T.Parallel(block_M, block_N):
......
@@ -99,8 +99,6 @@ def tl_matmul(
 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
 C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
-thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
 T.annotate_layout({
     A_shared: make_swizzle_layout(A_shared),
     B_shared: make_swizzle_layout(B_shared),
@@ -128,7 +126,6 @@ def tl_matmul(
     A_local,
     A_shared,
     ki,
-    thread_bindings=thread_bindings,
 )
 # Load B into fragment
@@ -136,7 +133,6 @@ def tl_matmul(
     B_local,
     B_shared,
     ki,
-    thread_bindings=thread_bindings,
 )
 # Perform Matrix Multiplication
@@ -147,7 +143,6 @@ def tl_matmul(
 mfma_emitter.stmatrix(
     C_local,
     C_shared,
-    thread_bindings=thread_bindings,
 )
 # Store shared into global
@@ -162,7 +157,6 @@ def tl_matmul(
 mfma_emitter.stmatrix(
     C_local,
     C,
-    thread_bindings=thread_bindings,
     pid_m=by,
     pid_n=bx,
 )
......
@@ -113,8 +113,6 @@ def tl_matmul_macro(
 B_local = T.alloc_local((warp_cols * local_size), in_dtype)
 C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
-thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
 T.annotate_layout({
     A_shared: make_swizzle_layout(A_shared),
     B_shared: make_swizzle_layout(B_shared),
@@ -142,7 +140,6 @@ def tl_matmul_macro(
     A_local,
     A_shared,
     ki,
-    thread_bindings=thread_bindings,
 )
 # Load B into fragment
@@ -150,7 +147,6 @@ def tl_matmul_macro(
     B_local,
     B_shared,
     ki,
-    thread_bindings=thread_bindings,
 )
 # Perform Matrix Multiplication
@@ -160,7 +156,6 @@ def tl_matmul_macro(
 mma_emitter.stmatrix(
     C_local,
     C_shared,
-    thread_bindings=thread_bindings,
 )
 # Store shared into global
......
@@ -457,7 +457,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
 B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype)
 C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
 reduced_accum_res = T.alloc_local(0, accum_dtype)
-thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+thread_binding = T.thread_binding(0, threads, "threadIdx.x")
 rk = T.thread_binding(0, reduce_k, "threadIdx.y")
 T.annotate_layout({
@@ -479,7 +479,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
 for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte //
                   (threads * vec_load_qb)):
     for v in T.vectorized(0, vec_load_qb):
-        t = thread_bindings
+        t = thread_binding
         idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v
         vkk = idx % (micro_size_k // num_elems_per_byte)
         vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
@@ -499,7 +499,6 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
     A_local,
     A_shared,
     ki,
-    thread_bindings=thread_bindings,
     rk=rk,
 )
@@ -508,7 +507,6 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
     B_local,
     B_shared,
     ki,
-    thread_bindings=thread_bindings,
     rk=rk,
 )
@@ -543,7 +541,6 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
 mma_emitter.stmatrix(
     C_local,
     C_shared,
-    thread_bindings=thread_bindings,
 )
 for i, j in T.Parallel(block_M, (block_N // reduce_k)):
......
@@ -119,7 +119,7 @@ def tl_matmul(
 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
 C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
-thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+thread_binding = T.thread_binding(0, threads, "threadIdx.x")
 T.annotate_layout({
     A_shared: make_swizzle_layout(A_shared),
@@ -148,7 +148,6 @@ def tl_matmul(
     A_local,
     A_shared,
     ki,
-    thread_bindings=thread_bindings,
 )
 # Load B into fragment
@@ -156,7 +155,6 @@ def tl_matmul(
     B_local,
     B_shared,
     ki,
-    thread_bindings=thread_bindings,
 )
 # Perform Matrix Multiplication
@@ -166,7 +164,6 @@ def tl_matmul(
 mma_emitter.stmatrix(
     C_local,
     C_shared,
-    thread_bindings=thread_bindings,
 )
 # Store shared into global
......
@@ -109,7 +109,7 @@ def tl_matmul(
 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
 C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
-thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+thread_binding = T.thread_binding(0, threads, "threadIdx.x")
 T.annotate_layout({
     A_shared: make_swizzle_layout(A_shared),
@@ -138,7 +138,6 @@ def tl_matmul(
     A_local,
     A_shared,
     ki,
-    thread_bindings=thread_bindings,
 )
 # Load B into fragment
@@ -146,7 +145,6 @@ def tl_matmul(
     B_local,
     B_shared,
     ki,
-    thread_bindings=thread_bindings,
 )
 # Perform Matrix Multiplication
@@ -156,7 +154,6 @@ def tl_matmul(
 mma_emitter.stmatrix(
     C_local,
     C_shared,
-    thread_bindings=thread_bindings,
 )
 # Store shared into global
@@ -297,7 +294,7 @@ def tl_matmul_weight_only_transform(
 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
 C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
-thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+thread_binding = T.thread_binding(0, threads, "threadIdx.x")
 T.annotate_layout({
     A_shared: make_swizzle_layout(A_shared),
@@ -328,7 +325,6 @@ def tl_matmul_weight_only_transform(
     A_local,
     A_shared,
     ki,
-    thread_bindings=thread_bindings,
 )
 # Load B into fragment
@@ -336,7 +332,6 @@ def tl_matmul_weight_only_transform(
     B_local,
     B_shared,
     ki,
-    thread_bindings=thread_bindings,
 )
 # Perform Matrix Multiplication
@@ -346,7 +341,6 @@ def tl_matmul_weight_only_transform(
 mma_emitter.stmatrix(
     C_local,
     C_shared,
-    thread_bindings=thread_bindings,
 )
 # Store shared into global
......
@@ -205,7 +205,7 @@ class MatrixCoreIntrinEmitter(object):
             (WARP_SIZE * block_row_warps)) % block_col_warps,
         return lane_id, warp_n, warp_m

-    def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0):
+    def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0):
         warp_row_tiles = self.warp_row_tiles
         warp_rows = self.warp_rows
         chunk = self.chunk
@@ -214,7 +214,8 @@ class MatrixCoreIntrinEmitter(object):
         local_size_a = self.local_size_a
         k_pack = self.k_pack
         is_transposed = self.a_transposed
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()
         _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)

         @T.macro
@@ -222,10 +223,10 @@ class MatrixCoreIntrinEmitter(object):
             A_local_buf,
             A_shared_buf,
             ki,
-            thread_bindings,
+            thread_binding,
             rk=0,
         ):
-            tx, _, warp_m = self.extract_thread_binding(thread_bindings)
+            tx, _, warp_m = self.extract_thread_binding(thread_binding)
             if is_transposed:
                 for i in T.serial(warp_rows):
                     for local_id in T.vectorized(k_pack * local_size_a):
@@ -243,9 +244,9 @@ class MatrixCoreIntrinEmitter(object):
                     A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
                                                                                      r + col]

-        return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk)
+        return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)

-    def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0):
+    def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, rk=0):
         warp_col_tiles = self.warp_col_tiles
         warp_cols = self.warp_cols
         chunk = self.chunk
@@ -254,7 +255,8 @@ class MatrixCoreIntrinEmitter(object):
         local_size_b = self.local_size_b
         k_pack = self.k_pack
         is_transposed = self.b_transposed
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()
         _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)

         @T.macro
@@ -262,10 +264,10 @@ class MatrixCoreIntrinEmitter(object):
             B_local_buf,
             B_shared_buf,
             ki,
-            thread_bindings,
+            thread_binding,
             rk=0,
         ):
-            tx, warp_n, _ = self.extract_thread_binding(thread_bindings)
+            tx, warp_n, _ = self.extract_thread_binding(thread_binding)
             if is_transposed:
                 for j in T.serial(warp_cols):
@@ -288,7 +290,7 @@ class MatrixCoreIntrinEmitter(object):
                     B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
                                                                                      r + col]

-        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk)
+        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)

     def mfma(self, A_local_buf, B_local_buf, C_local_buf):
         warp_rows = self.warp_rows
@@ -324,13 +326,14 @@ class MatrixCoreIntrinEmitter(object):
         return _warp_mma(A_local_buf, B_local_buf, C_local_buf)

-    def stmatrix(self, C_local_buf, C_buf, thread_bindings, pid_m=None, pid_n=None):
+    def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None):
         block_row_warps = self.block_row_warps
         block_col_warps = self.block_col_warps
         warp_rows = self.warp_rows
         warp_cols = self.warp_cols
         local_size_out = self.local_size_out
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()
         is_global = pid_m is not None and pid_n is not None
         BLOCK_M = block_row_warps * warp_rows
         BLOCK_N = block_col_warps * warp_cols
@@ -341,8 +344,8 @@ class MatrixCoreIntrinEmitter(object):
         # As TVM Intrins is like a hack that the threadIdx.x should be always
         # equal to the warp_size
         @T.macro
-        def _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings):
-            tx, warp_n, warp_m = self.extract_thread_binding(thread_bindings)
+        def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding):
+            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
             for i, j in T.grid(warp_rows, warp_cols):
                 for local_id in T.serial(local_size_out):
                     row, col = T.meta_var(mfma_store_index_map(tx, local_id))
@@ -351,8 +354,8 @@ class MatrixCoreIntrinEmitter(object):
                                                  local_id]

         @T.macro
-        def _warp_stmatrix_global(C_local_buf, C_buf, thread_bindings):
-            tx, warp_n, warp_m = self.extract_thread_binding(thread_bindings)
+        def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
+            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
             for i, j in T.grid(warp_rows, warp_cols):
                 for local_id in T.serial(local_size_out):
                     row, col = T.meta_var(mfma_store_index_map(tx, local_id))
@@ -362,5 +365,5 @@ class MatrixCoreIntrinEmitter(object):
                                    local_id]

         return _warp_stmatrix_global(C_local_buf, C_buf,
-                                     thread_bindings) if is_global else _warp_stmatrix_shared(
-                                         C_local_buf, C_buf, thread_bindings)
+                                     thread_binding) if is_global else _warp_stmatrix_shared(
+                                         C_local_buf, C_buf, thread_binding)
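A detail worth noting in the emitter hunks above and below: `T.KernelLaunchFrame.Current()` is called in the plain Python method body, outside the `@T.macro`, so the binding is resolved once while the kernel is being traced and then handed to the macro as an ordinary argument; the macro bodies themselves only change by the rename. A condensed sketch of the pattern (method and macro bodies are trimmed placeholders, not the real implementation):

```python
import tilelang.language as T

def ldmatrix_like(self, local_buf, shared_buf, ki, rk=0):
    # Resolved at trace time from the enclosing kernel launch frame,
    # replacing the old explicit `thread_bindings` parameter.
    current_frame = T.KernelLaunchFrame.Current()
    thread_binding = current_frame.get_thread_binding()

    @T.macro
    def _warp_body(local_buf, shared_buf, ki, thread_binding, rk=0):
        tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
        ...  # per-thread address math and load/store emission

    return _warp_body(local_buf, shared_buf, ki, thread_binding, rk)
```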
@@ -159,7 +159,6 @@ class TensorCoreIntrinEmitter(object):
                    A_local_buf: Buffer,
                    A_shared_buf: Buffer,
                    ki: PrimExpr,
-                   thread_bindings: PrimExpr,
                    rk: Optional[PrimExpr] = 0):
         warp_row_tiles = self.warp_row_tiles
         warp_rows = self.warp_rows
@@ -170,16 +169,19 @@ class TensorCoreIntrinEmitter(object):
         a_dtype = self.a_dtype
         a_transposed = self.a_transposed
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()

         @T.macro
         def _warp_ldmatrix_a(
             A_local_buf,
             A_shared_buf,
             ki,
-            thread_bindings,
+            thread_binding,
             rk=0,
         ):
             stride = A_shared_buf.shape[-1]
-            tx, _, warp_m = self.extract_thread_binding(thread_bindings)
+            tx, _, warp_m = self.extract_thread_binding(thread_binding)
             for i in T.serial(warp_rows):
                 T.ptx_ldmatrix(
                     a_dtype,
@@ -195,13 +197,12 @@ class TensorCoreIntrinEmitter(object):
                     get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed),
                 )

-        return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk)
+        return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)

     def ldmatrix_b(self,
                    B_local_buf: Buffer,
                    B_shared_buf: Buffer,
                    ki: PrimExpr,
-                   thread_bindings: PrimExpr,
                    rk: Optional[PrimExpr] = 0):
         warp_col_tiles = self.warp_col_tiles
         warp_cols = self.warp_cols
@@ -211,17 +212,19 @@ class TensorCoreIntrinEmitter(object):
         local_size_b = self.local_size_b
         b_dtype = self.b_dtype
         b_transposed = self.b_transposed
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()

         @T.macro
         def _warp_ldmatrix_b(
             B_local_buf,
             B_shared_buf,
             ki,
-            thread_bindings,
+            thread_binding,
             rk=0,
         ):
             stride = B_shared_buf.shape[-1]
-            tx, warp_n, _ = self.extract_thread_binding(thread_bindings)
+            tx, warp_n, _ = self.extract_thread_binding(thread_binding)
             for j in T.serial(warp_cols):
                 # Assign B_shared_elem
@@ -242,7 +245,7 @@ class TensorCoreIntrinEmitter(object):
                     get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed),
                 )

-        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk)
+        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)

     def mma(self,
             A_local_buf: Buffer,
@@ -304,7 +307,7 @@ class TensorCoreIntrinEmitter(object):
         return _warp_mma(A_local_buf, B_local_buf, C_local_buf)

-    def stmatrix(self, C_local_buf, C_buf, thread_bindings, pid_m=None, pid_n=None):
+    def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None):
         block_row_warps = self.block_row_warps
         block_col_warps = self.block_col_warps
         warp_rows = self.warp_rows
@@ -316,13 +319,16 @@ class TensorCoreIntrinEmitter(object):
         BLOCK_N = block_col_warps * warp_cols
         M_DIM, N_DIM = self.M_DIM, self.N_DIM
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()

         # STS
         # MMA Store must be in simulated instead of TVM Intrins
         # As TVM Intrins is like a hack that the threadIdx.x should be always
         # equal to the warp_size
         @T.macro
-        def _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings):
-            tx, warp_n, warp_m = self.extract_thread_binding(thread_bindings)
+        def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding):
+            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
             for i, j in T.grid(warp_rows, warp_cols):
                 for local_id_o in T.serial(local_size_out // 2):
                     for local_id_i in T.vectorized(2):
@@ -333,8 +339,8 @@ class TensorCoreIntrinEmitter(object):
                                                              j * local_size_out + local_id]

         @T.macro
-        def _warp_stmatrix_global(C_local_buf, C_buf, thread_bindings):
-            tx, warp_n, warp_m = self.extract_thread_binding(thread_bindings)
+        def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
+            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
             for i, j in T.grid(warp_rows, warp_cols):
                 for local_id_o in T.serial(local_size_out // 2):
                     for local_id_i in T.vectorized(2):
@@ -346,8 +352,8 @@ class TensorCoreIntrinEmitter(object):
                             ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
                                             local_id]

-        return (_warp_stmatrix_global(C_local_buf, C_buf, thread_bindings)
-                if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings))
+        return (_warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
+                if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding))
     def make_mma_load_layout(self,
                              local_buf: Buffer,
@@ -610,7 +616,7 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
         assert transform_kind_a in [0, 1, 2, 3], "Input transform stage should be 0, 1, 2, or 3"
         assert transform_kind_b in [0, 1, 2, 3], "Weight transform stage should be 0, 1, 2, or 3"

-    def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0):
+    def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0):
         warp_row_tiles = self.warp_row_tiles
         warp_rows = self.warp_rows
         chunk = self.chunk
@@ -621,16 +627,19 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
         a_transposed = self.a_transposed
         transform_kind_a = self.transform_kind_a
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()

         @T.macro
         def _warp_ldmatrix_a(
             A_local_buf,
             A_shared_buf,
             ki,
-            thread_bindings,
+            thread_binding,
             rk=0,
         ):
             stride = A_shared_buf.shape[-1]
-            tx, _, warp_m = self.extract_thread_binding(thread_bindings)
+            tx, _, warp_m = self.extract_thread_binding(thread_binding)
             if transform_kind_a == TransformKind.NonTransform:
                 for i in T.serial(warp_rows):
                     T.ptx_ldmatrix(
@@ -712,9 +721,9 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
             else:
                 raise ValueError("Unsupported TransformKind for Input A")

-        return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk)
+        return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)

-    def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0):
+    def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, rk=0):
         warp_col_tiles = self.warp_col_tiles
         warp_cols = self.warp_cols
         chunk = self.chunk
@@ -726,16 +735,19 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
         b_transposed = self.b_transposed
         num_elems_per_byte = self.num_elems_per_byte
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()

         @T.macro
         def _warp_ldmatrix_b(
             B_local_buf,
             B_shared_buf,
             ki,
-            thread_bindings,
+            thread_binding,
             rk=0,
         ):
             stride = B_shared_buf.shape[-1]
-            tx, warp_n, _ = self.extract_thread_binding(thread_bindings)
+            tx, warp_n, _ = self.extract_thread_binding(thread_binding)
             if transform_kind_b == TransformKind.NonTransform:
                 for j in T.serial(warp_cols):
@@ -824,7 +836,7 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
             else:
                 raise ValueError("Unsupported TransformKind for Input B")

-        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk)
+        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)

     def mma(self, A_local_buf, B_local_buf, C_local_buf):
         warp_rows = self.warp_rows
......
@@ -125,6 +125,13 @@ class KernelLaunchFrame(TIRFrame):
         iter_var = self.frames[-4 + dim].iter_var
         return int(iter_var.dom.extent)

+    def get_thread_binding(self, dim: int = 0) -> Var:
+        """
+        Returns the thread binding for the given dimension.
+        dim=0 corresponds to threadIdx.x, dim=1 to threadIdx.y, and dim=2 to threadIdx.z.
+        """
+        return self.frames[-4 + dim].iter_var.var
+
     def get_num_threads(self) -> int:
         """
         Returns the thread indices from the topmost frame.
......
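The new accessor pairs with the frame's existing extent helpers, so kernel-construction code can query both the binding variables and the launch width from the current frame. A short usage sketch, assuming it runs inside a kernel launch scope (the variable names are illustrative):

```python
import tilelang.language as T

frame = T.KernelLaunchFrame.Current()
tx = frame.get_thread_binding()        # dim=0 -> threadIdx.x (the default)
ty = frame.get_thread_binding(dim=1)   # dim=1 -> threadIdx.y (e.g. a reduce_k binding)
num_threads = frame.get_num_threads()  # launch extent, used below by GemmPrimitiveMMA
```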
@@ -40,7 +40,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
         local_size_b = mma_emitter.local_size_b
         block_K = mma_emitter.chunk
         micro_size_k = mma_emitter.micro_size_k
-        threads = mma_emitter.threads

         # Check if C is a fragment for applying custom layout
         a_is_fragment = is_fragment(A)
         c_is_fragment = is_fragment(C)
@@ -54,7 +54,6 @@ class GemmPrimitiveMMA(GemmBaseParams):
             """
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
-            thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
             if a_is_fragment:
                 # Annotate layout for A_local if it is a fragment.
                 T.annotate_layout({
@@ -77,7 +76,6 @@ class GemmPrimitiveMMA(GemmBaseParams):
                     B_local,
                     B_shared,
                     ki,
-                    thread_bindings=thread_bindings,
                 )
                 # Perform Matrix Multiplication
                 mma_emitter.mma(
@@ -135,7 +133,6 @@ class GemmPrimitiveMMA(GemmBaseParams):
         local_size_b = mma_emitter.local_size_b
         block_K = mma_emitter.chunk
         micro_size_k = mma_emitter.micro_size_k
-        threads = mma_emitter.threads

         # Check if C is a fragment for applying custom layout
         c_is_fragment = is_fragment(C)
@@ -150,8 +147,6 @@ class GemmPrimitiveMMA(GemmBaseParams):
             A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
-            thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
             if c_is_fragment:
                 # Annotate layout for C_local if it is a fragment.
                 T.annotate_layout({
@@ -164,7 +159,6 @@ class GemmPrimitiveMMA(GemmBaseParams):
                     A_local,
                     A_shared,
                     ki,
-                    thread_bindings=thread_bindings,
                 )
                 # Load B into fragment
@@ -172,7 +166,6 @@ class GemmPrimitiveMMA(GemmBaseParams):
                     B_local,
                     B_shared,
                     ki,
-                    thread_bindings=thread_bindings,
                 )
                 # Perform Matrix Multiplication
@@ -197,7 +190,8 @@ class GemmPrimitiveMMA(GemmBaseParams):
         # Infer block partition if necessary
         current_frame = T.KernelLaunchFrame.Current()
-        threads = current_frame.num_threads
+        threads = current_frame.get_num_threads()
         self.infer_block_partition(threads)

         A, B, C = self.A, self.B, self.C
......
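The same frame query also replaces per-primitive bookkeeping at the consumer end: `GemmPrimitiveMMA` no longer reads `mma_emitter.threads` or builds its own `threadIdx.x` binding, and instead derives the launch width from the frame. Condensed from the hunk above (surrounding logic elided):

```python
current_frame = T.KernelLaunchFrame.Current()
threads = current_frame.get_num_threads()  # replaces the old num_threads property
self.infer_block_partition(threads)
```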