Commit 362b3520 authored by Lei Wang, committed by GitHub

[Refactor] Simplify interface via replacing argument thread binding of intrinsics with `KernelFrame.Current` (#34)

* installation script fix

* readme typo fix

* doc fix for dequantize gemm

* [Doc] remove CODE_OF_CONDUCT.md and SECURITY.md; update references in CONTRIBUTING.md

* [Doc] add unit tests for AnnotateDeviceRegions transform; remove SUPPORT.md

* update license

* [Enhancement] add tensor supply handling for unsigned integers; improve error message for execution backend assertion

* [Refactor] improve code readability by reformatting function signatures and assertions

* [Refactor] replace torch.manual_seed with tilelang.testing.set_random_seed for consistency in random seed handling

* [Refactor] unify thread binding variable naming across kernel and example files

* [Refactor] remove unused thread binding parameter from matrix multiplication functions

* [Refactor] remove unused thread binding parameter from matrix multiplication functions

* [Refactor] enable main testing function in tilelang kernel gemm test

* bug fix
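
The net effect, visible throughout the diffs below, is a narrower call-site contract for the warp-level intrinsics: the explicit `thread_bindings` argument disappears because each intrinsic now resolves the binding from the enclosing kernel launch frame. A minimal before/after sketch (buffer and emitter names follow the examples in this commit):

```python
# Before this change: callers created the binding and passed it explicitly.
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
mma_emitter.ldmatrix_a(A_local, A_shared, ki, thread_bindings=thread_bindings)

# After: the emitter resolves the binding internally from the enclosing
# kernel launch frame, so the argument disappears from every call site.
mma_emitter.ldmatrix_a(A_local, A_shared, ki)
```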
parent 1b63d3a2
@@ -257,7 +257,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
         B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype)
         C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
         reduced_accum_res = T.alloc_local(0, accum_dtype)
-        thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+        thread_binding = T.thread_binding(0, threads, "threadIdx.x")
         rk = T.thread_binding(0, reduce_k, "threadIdx.y")
         T.annotate_layout({
@@ -279,7 +279,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
         for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte //
                           (threads * vec_load_qb)):
             for v in T.vectorized(0, vec_load_qb):
-                t = thread_bindings
+                t = thread_binding
                 idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v
                 vkk = idx % (micro_size_k // num_elems_per_byte)
                 vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
@@ -299,7 +299,6 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
                 A_local,
                 A_shared,
                 ki,
-                thread_bindings=thread_bindings,
                 rk=rk,
             )
@@ -308,7 +307,6 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
                 B_local,
                 B_shared,
                 ki,
-                thread_bindings=thread_bindings,
                 rk=rk,
             )
@@ -343,7 +341,6 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
         mma_emitter.stmatrix(
             C_local,
             C_shared,
-            thread_bindings=thread_bindings,
         )
         for i, j in T.Parallel(block_M, (block_N // reduce_k)):
...
@@ -339,7 +339,7 @@ def tl_matmul(
         B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
         C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
-        thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+        thread_binding = T.thread_binding(0, threads, "threadIdx.x")
         T.annotate_layout({
             A_shared: make_swizzle_layout(A_shared),
@@ -367,16 +367,14 @@ def tl_matmul(
             mma_emitter.ldmatrix_a(
                 A_local,
                 A_shared,
-                ki,
-                thread_bindings=thread_bindings,
+                ki
             )
             # Load B into fragment
             mma_emitter.ldmatrix_b(
                 B_local,
                 B_shared,
-                ki,
-                thread_bindings=thread_bindings,
+                ki
             )
             # Perform Matrix Multiplication
@@ -386,7 +384,6 @@ def tl_matmul(
         mma_emitter.stmatrix(
             C_local,
             C_shared,
-            thread_bindings=thread_bindings,
         )
         # Store shared into global
@@ -416,10 +413,10 @@ def tl_matmul(
 ```python
 for ki in T.serial(0, (block_K // micro_size_k)):
     # Warp-synchronous load for A
-    mma_emitter.ldmatrix_a(A_local, A_shared, ki, thread_bindings=thread_bindings)
+    mma_emitter.ldmatrix_a(A_local, A_shared, ki)
     # Warp-synchronous load for B
-    mma_emitter.ldmatrix_b(B_local, B_shared, ki, thread_bindings=thread_bindings)
+    mma_emitter.ldmatrix_b(B_local, B_shared, ki)
 ```
 Internally, these calls orchestrate how each thread in the warp issues the correct load instructions, performs address calculations, and stores the data into registers.
@@ -437,7 +434,7 @@ def tl_matmul(
 5. **Store Results via `stmatrix`**
    Finally, you write the results from the warp-level fragments back to shared memory or global memory. This step might happen multiple times in a loop or just once at the end. The code snippet:
 ```python
-mma_emitter.stmatrix(C_local, C_shared, thread_bindings=thread_bindings)
+mma_emitter.stmatrix(C_local, C_shared)
 ```
 orchestrates the warp-synchronous stores, ensuring each thread places the correct fragment element into the correct location of the shared or global buffer.
...
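
Taken together, the documented steps above (fragment loads, `mma`, then `stmatrix`) assemble into the following inner loop under the new interface. This is a condensed restatement of the snippets in this commit, with illustrative buffer names:

```python
for ki in T.serial(0, (block_K // micro_size_k)):
    # Warp-synchronous fragment loads for A and B (binding resolved internally)
    mma_emitter.ldmatrix_a(A_local, A_shared, ki)
    mma_emitter.ldmatrix_b(B_local, B_shared, ki)
    # Accumulate into the C fragments
    mma_emitter.mma(A_local, B_local, C_local)
# Write the accumulated fragments back to shared memory
mma_emitter.stmatrix(C_local, C_shared)
```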
@@ -116,8 +116,6 @@ def tl_matmul(
         B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
         C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
-        thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
-
         T.annotate_layout({
             A_shared: make_swizzle_layout(A_shared),
             B_shared: make_swizzle_layout(B_shared),
@@ -141,30 +139,16 @@ def tl_matmul(
         for ki in T.serial(0, (block_K // micro_size_k)):
             # Load A into fragment
-            mma_emitter.ldmatrix_a(
-                A_local,
-                A_shared,
-                ki,
-                thread_bindings=thread_bindings,
-            )
+            mma_emitter.ldmatrix_a(A_local, A_shared, ki)
             # Load B into fragment
-            mma_emitter.ldmatrix_b(
-                B_local,
-                B_shared,
-                ki,
-                thread_bindings=thread_bindings,
-            )
+            mma_emitter.ldmatrix_b(B_local, B_shared, ki)
             # Perform Matrix Multiplication
             mma_emitter.mma(A_local, B_local, C_local)
         # Perform STMatrix
-        mma_emitter.stmatrix(
-            C_local,
-            C_shared,
-            thread_bindings=thread_bindings,
-        )
+        mma_emitter.stmatrix(C_local, C_shared)
         # Store shared into global
         for i, j in T.Parallel(block_M, block_N):
...
@@ -99,8 +99,6 @@ def tl_matmul(
         B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
         C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
-        thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
-
         T.annotate_layout({
             A_shared: make_swizzle_layout(A_shared),
             B_shared: make_swizzle_layout(B_shared),
@@ -128,7 +126,6 @@ def tl_matmul(
                 A_local,
                 A_shared,
                 ki,
-                thread_bindings=thread_bindings,
             )
             # Load B into fragment
@@ -136,7 +133,6 @@ def tl_matmul(
                 B_local,
                 B_shared,
                 ki,
-                thread_bindings=thread_bindings,
             )
             # Perform Matrix Multiplication
@@ -147,7 +143,6 @@ def tl_matmul(
         mfma_emitter.stmatrix(
             C_local,
             C_shared,
-            thread_bindings=thread_bindings,
         )
         # Store shared into global
@@ -162,7 +157,6 @@ def tl_matmul(
         mfma_emitter.stmatrix(
             C_local,
             C,
-            thread_bindings=thread_bindings,
             pid_m=by,
             pid_n=bx,
         )
...
@@ -113,8 +113,6 @@ def tl_matmul_macro(
         B_local = T.alloc_local((warp_cols * local_size), in_dtype)
         C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
-        thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
-
         T.annotate_layout({
             A_shared: make_swizzle_layout(A_shared),
             B_shared: make_swizzle_layout(B_shared),
@@ -142,7 +140,6 @@ def tl_matmul_macro(
                 A_local,
                 A_shared,
                 ki,
-                thread_bindings=thread_bindings,
             )
             # Load B into fragment
@@ -150,7 +147,6 @@ def tl_matmul_macro(
                 B_local,
                 B_shared,
                 ki,
-                thread_bindings=thread_bindings,
             )
             # Perform Matrix Multiplication
@@ -160,7 +156,6 @@ def tl_matmul_macro(
         mma_emitter.stmatrix(
             C_local,
             C_shared,
-            thread_bindings=thread_bindings,
         )
         # Store shared into global
...
@@ -457,7 +457,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
         B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype)
         C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
         reduced_accum_res = T.alloc_local(0, accum_dtype)
-        thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+        thread_binding = T.thread_binding(0, threads, "threadIdx.x")
         rk = T.thread_binding(0, reduce_k, "threadIdx.y")
         T.annotate_layout({
@@ -479,7 +479,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
         for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte //
                           (threads * vec_load_qb)):
             for v in T.vectorized(0, vec_load_qb):
-                t = thread_bindings
+                t = thread_binding
                 idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v
                 vkk = idx % (micro_size_k // num_elems_per_byte)
                 vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
@@ -499,7 +499,6 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
                 A_local,
                 A_shared,
                 ki,
-                thread_bindings=thread_bindings,
                 rk=rk,
             )
@@ -508,7 +507,6 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
                 B_local,
                 B_shared,
                 ki,
-                thread_bindings=thread_bindings,
                 rk=rk,
             )
@@ -543,7 +541,6 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
         mma_emitter.stmatrix(
             C_local,
             C_shared,
-            thread_bindings=thread_bindings,
         )
         for i, j in T.Parallel(block_M, (block_N // reduce_k)):
...
@@ -119,7 +119,7 @@ def tl_matmul(
         B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
         C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
-        thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+        thread_binding = T.thread_binding(0, threads, "threadIdx.x")
         T.annotate_layout({
             A_shared: make_swizzle_layout(A_shared),
@@ -148,7 +148,6 @@ def tl_matmul(
                 A_local,
                 A_shared,
                 ki,
-                thread_bindings=thread_bindings,
             )
             # Load B into fragment
@@ -156,7 +155,6 @@ def tl_matmul(
                 B_local,
                 B_shared,
                 ki,
-                thread_bindings=thread_bindings,
             )
             # Perform Matrix Multiplication
@@ -166,7 +164,6 @@ def tl_matmul(
         mma_emitter.stmatrix(
             C_local,
             C_shared,
-            thread_bindings=thread_bindings,
         )
         # Store shared into global
...
@@ -109,7 +109,7 @@ def tl_matmul(
         B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
         C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
-        thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+        thread_binding = T.thread_binding(0, threads, "threadIdx.x")
         T.annotate_layout({
             A_shared: make_swizzle_layout(A_shared),
@@ -138,16 +138,14 @@ def tl_matmul(
                 A_local,
                 A_shared,
                 ki,
-                thread_bindings=thread_bindings,
             )
             # Load B into fragment
             mma_emitter.ldmatrix_b(
                 B_local,
                 B_shared,
                 ki,
-                thread_bindings=thread_bindings,
             )
             # Perform Matrix Multiplication
             mma_emitter.mma(A_local, B_local, C_local)
@@ -156,7 +154,6 @@ def tl_matmul(
         mma_emitter.stmatrix(
             C_local,
             C_shared,
-            thread_bindings=thread_bindings,
         )
         # Store shared into global
@@ -297,7 +294,7 @@ def tl_matmul_weight_only_transform(
         B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
         C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
-        thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+        thread_binding = T.thread_binding(0, threads, "threadIdx.x")
         T.annotate_layout({
             A_shared: make_swizzle_layout(A_shared),
@@ -328,16 +325,14 @@ def tl_matmul_weight_only_transform(
                 A_local,
                 A_shared,
                 ki,
-                thread_bindings=thread_bindings,
             )
             # Load B into fragment
             mma_emitter.ldmatrix_b(
                 B_local,
                 B_shared,
                 ki,
-                thread_bindings=thread_bindings,
             )
             # Perform Matrix Multiplication
             mma_emitter.mma(A_local, B_local, C_local)
@@ -346,7 +341,6 @@ def tl_matmul_weight_only_transform(
         mma_emitter.stmatrix(
             C_local,
             C_shared,
-            thread_bindings=thread_bindings,
         )
         # Store shared into global
...
@@ -205,7 +205,7 @@ class MatrixCoreIntrinEmitter(object):
                     (WARP_SIZE * block_row_warps)) % block_col_warps,
         return lane_id, warp_n, warp_m
-    def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0):
+    def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0):
         warp_row_tiles = self.warp_row_tiles
         warp_rows = self.warp_rows
         chunk = self.chunk
@@ -214,7 +214,8 @@ class MatrixCoreIntrinEmitter(object):
         local_size_a = self.local_size_a
         k_pack = self.k_pack
         is_transposed = self.a_transposed
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()
         _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)
         @T.macro
@@ -222,10 +223,10 @@ class MatrixCoreIntrinEmitter(object):
             A_local_buf,
             A_shared_buf,
             ki,
-            thread_bindings,
+            thread_binding,
             rk=0,
         ):
-            tx, _, warp_m = self.extract_thread_binding(thread_bindings)
+            tx, _, warp_m = self.extract_thread_binding(thread_binding)
             if is_transposed:
                 for i in T.serial(warp_rows):
                     for local_id in T.vectorized(k_pack * local_size_a):
@@ -243,9 +244,9 @@ class MatrixCoreIntrinEmitter(object):
                         A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
                                                                                          r + col]
-        return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk)
+        return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
-    def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0):
+    def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, rk=0):
         warp_col_tiles = self.warp_col_tiles
         warp_cols = self.warp_cols
         chunk = self.chunk
@@ -254,7 +255,8 @@ class MatrixCoreIntrinEmitter(object):
         local_size_b = self.local_size_b
         k_pack = self.k_pack
         is_transposed = self.b_transposed
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()
         _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)
         @T.macro
@@ -262,10 +264,10 @@ class MatrixCoreIntrinEmitter(object):
             B_local_buf,
             B_shared_buf,
             ki,
-            thread_bindings,
+            thread_binding,
             rk=0,
         ):
-            tx, warp_n, _ = self.extract_thread_binding(thread_bindings)
+            tx, warp_n, _ = self.extract_thread_binding(thread_binding)
             if is_transposed:
                 for j in T.serial(warp_cols):
@@ -288,7 +290,7 @@ class MatrixCoreIntrinEmitter(object):
                         B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
                                                                                          r + col]
-        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk)
+        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
     def mfma(self, A_local_buf, B_local_buf, C_local_buf):
         warp_rows = self.warp_rows
@@ -324,13 +326,14 @@ class MatrixCoreIntrinEmitter(object):
         return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
-    def stmatrix(self, C_local_buf, C_buf, thread_bindings, pid_m=None, pid_n=None):
+    def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None):
         block_row_warps = self.block_row_warps
         block_col_warps = self.block_col_warps
         warp_rows = self.warp_rows
         warp_cols = self.warp_cols
         local_size_out = self.local_size_out
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()
         is_global = pid_m is not None and pid_n is not None
         BLOCK_M = block_row_warps * warp_rows
         BLOCK_N = block_col_warps * warp_cols
@@ -341,8 +344,8 @@ class MatrixCoreIntrinEmitter(object):
         # As TVM Intrins is like a hack that the threadIdx.x should be always
         # equal to the warp_size
         @T.macro
-        def _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings):
-            tx, warp_n, warp_m = self.extract_thread_binding(thread_bindings)
+        def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding):
+            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
             for i, j in T.grid(warp_rows, warp_cols):
                 for local_id in T.serial(local_size_out):
                     row, col = T.meta_var(mfma_store_index_map(tx, local_id))
@@ -351,8 +354,8 @@ class MatrixCoreIntrinEmitter(object):
                                                                            local_id]
         @T.macro
-        def _warp_stmatrix_global(C_local_buf, C_buf, thread_bindings):
-            tx, warp_n, warp_m = self.extract_thread_binding(thread_bindings)
+        def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
+            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
             for i, j in T.grid(warp_rows, warp_cols):
                 for local_id in T.serial(local_size_out):
                     row, col = T.meta_var(mfma_store_index_map(tx, local_id))
@@ -362,5 +365,5 @@ class MatrixCoreIntrinEmitter(object):
                                     local_id]
-        return _warp_stmatrix_global(C_local_buf, C_buf,
-                                     thread_bindings) if is_global else _warp_stmatrix_shared(
-                                         C_local_buf, C_buf, thread_bindings)
+        return _warp_stmatrix_global(C_local_buf, C_buf,
+                                     thread_binding) if is_global else _warp_stmatrix_shared(
+                                         C_local_buf, C_buf, thread_binding)
@@ -159,7 +159,6 @@ class TensorCoreIntrinEmitter(object):
                    A_local_buf: Buffer,
                    A_shared_buf: Buffer,
                    ki: PrimExpr,
-                   thread_bindings: PrimExpr,
                    rk: Optional[PrimExpr] = 0):
         warp_row_tiles = self.warp_row_tiles
         warp_rows = self.warp_rows
@@ -170,16 +169,19 @@ class TensorCoreIntrinEmitter(object):
         a_dtype = self.a_dtype
         a_transposed = self.a_transposed
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()
         @T.macro
         def _warp_ldmatrix_a(
             A_local_buf,
             A_shared_buf,
             ki,
-            thread_bindings,
+            thread_binding,
             rk=0,
         ):
             stride = A_shared_buf.shape[-1]
-            tx, _, warp_m = self.extract_thread_binding(thread_bindings)
+            tx, _, warp_m = self.extract_thread_binding(thread_binding)
             for i in T.serial(warp_rows):
                 T.ptx_ldmatrix(
                     a_dtype,
@@ -195,13 +197,12 @@ class TensorCoreIntrinEmitter(object):
                     get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed),
                 )
-        return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk)
+        return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
     def ldmatrix_b(self,
                    B_local_buf: Buffer,
                    B_shared_buf: Buffer,
                    ki: PrimExpr,
-                   thread_bindings: PrimExpr,
                    rk: Optional[PrimExpr] = 0):
         warp_col_tiles = self.warp_col_tiles
         warp_cols = self.warp_cols
@@ -211,17 +212,19 @@ class TensorCoreIntrinEmitter(object):
         local_size_b = self.local_size_b
         b_dtype = self.b_dtype
         b_transposed = self.b_transposed
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()
         @T.macro
         def _warp_ldmatrix_b(
             B_local_buf,
             B_shared_buf,
             ki,
-            thread_bindings,
+            thread_binding,
             rk=0,
         ):
             stride = B_shared_buf.shape[-1]
-            tx, warp_n, _ = self.extract_thread_binding(thread_bindings)
+            tx, warp_n, _ = self.extract_thread_binding(thread_binding)
             for j in T.serial(warp_cols):
                 # Assign B_shared_elem
@@ -242,7 +245,7 @@ class TensorCoreIntrinEmitter(object):
                     get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed),
                 )
-        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk)
+        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
     def mma(self,
             A_local_buf: Buffer,
@@ -304,7 +307,7 @@ class TensorCoreIntrinEmitter(object):
         return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
-    def stmatrix(self, C_local_buf, C_buf, thread_bindings, pid_m=None, pid_n=None):
+    def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None):
         block_row_warps = self.block_row_warps
         block_col_warps = self.block_col_warps
         warp_rows = self.warp_rows
@@ -316,13 +319,16 @@ class TensorCoreIntrinEmitter(object):
         BLOCK_N = block_col_warps * warp_cols
         M_DIM, N_DIM = self.M_DIM, self.N_DIM
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()
         # STS
         # MMA Store must be in simulated instead of TVM Intrins
         # As TVM Intrins is like a hack that the threadIdx.x should be always
         # equal to the warp_size
         @T.macro
-        def _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings):
-            tx, warp_n, warp_m = self.extract_thread_binding(thread_bindings)
+        def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding):
+            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
             for i, j in T.grid(warp_rows, warp_cols):
                 for local_id_o in T.serial(local_size_out // 2):
                     for local_id_i in T.vectorized(2):
@@ -333,8 +339,8 @@ class TensorCoreIntrinEmitter(object):
                                         j * local_size_out + local_id]
         @T.macro
-        def _warp_stmatrix_global(C_local_buf, C_buf, thread_bindings):
-            tx, warp_n, warp_m = self.extract_thread_binding(thread_bindings)
+        def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
+            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
             for i, j in T.grid(warp_rows, warp_cols):
                 for local_id_o in T.serial(local_size_out // 2):
                     for local_id_i in T.vectorized(2):
@@ -346,8 +352,8 @@ class TensorCoreIntrinEmitter(object):
                     ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
                                     local_id]
-        return (_warp_stmatrix_global(C_local_buf, C_buf, thread_bindings)
-                if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings))
+        return (_warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
+                if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding))
     def make_mma_load_layout(self,
                              local_buf: Buffer,
@@ -610,7 +616,7 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
         assert transform_kind_a in [0, 1, 2, 3], "Input transform stage should be 0, 1, 2, or 3"
         assert transform_kind_b in [0, 1, 2, 3], "Weight transform stage should be 0, 1, 2, or 3"
-    def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0):
+    def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0):
         warp_row_tiles = self.warp_row_tiles
         warp_rows = self.warp_rows
         chunk = self.chunk
@@ -621,16 +627,19 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
         a_transposed = self.a_transposed
         transform_kind_a = self.transform_kind_a
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()
         @T.macro
         def _warp_ldmatrix_a(
             A_local_buf,
             A_shared_buf,
             ki,
-            thread_bindings,
+            thread_binding,
             rk=0,
         ):
             stride = A_shared_buf.shape[-1]
-            tx, _, warp_m = self.extract_thread_binding(thread_bindings)
+            tx, _, warp_m = self.extract_thread_binding(thread_binding)
             if transform_kind_a == TransformKind.NonTransform:
                 for i in T.serial(warp_rows):
                     T.ptx_ldmatrix(
@@ -712,9 +721,9 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
             else:
                 raise ValueError("Unsupported TransformKind for Input A")
-        return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk)
+        return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
-    def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0):
+    def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, rk=0):
         warp_col_tiles = self.warp_col_tiles
         warp_cols = self.warp_cols
         chunk = self.chunk
@@ -726,16 +735,19 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
         b_transposed = self.b_transposed
         num_elems_per_byte = self.num_elems_per_byte
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()
         @T.macro
         def _warp_ldmatrix_b(
             B_local_buf,
             B_shared_buf,
             ki,
-            thread_bindings,
+            thread_binding,
             rk=0,
         ):
             stride = B_shared_buf.shape[-1]
-            tx, warp_n, _ = self.extract_thread_binding(thread_bindings)
+            tx, warp_n, _ = self.extract_thread_binding(thread_binding)
             if transform_kind_b == TransformKind.NonTransform:
                 for j in T.serial(warp_cols):
@@ -824,7 +836,7 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
             else:
                 raise ValueError("Unsupported TransformKind for Input B")
-        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk)
+        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
     def mma(self, A_local_buf, B_local_buf, C_local_buf):
         warp_rows = self.warp_rows
...
@@ -125,6 +125,13 @@ class KernelLaunchFrame(TIRFrame):
         iter_var = self.frames[-4 + dim].iter_var
         return int(iter_var.dom.extent)
+    def get_thread_binding(self, dim: int = 0) -> Var:
+        """
+        Returns the thread binding for the given dimension.
+        dim=0 corresponds to threadIdx.x, dim=1 to threadIdx.y, and dim=2 to threadIdx.z.
+        """
+        return self.frames[-4 + dim].iter_var.var
+
     def get_num_threads(self) -> int:
         """
         Returns the thread indices from the topmost frame.
...
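
This new accessor is what the emitter changes above build on: each intrinsic method fetches the binding from the active frame instead of receiving it as a parameter. A condensed sketch of the pattern as it recurs throughout the emitters in this commit (body abbreviated):

```python
def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0):
    # Resolve threadIdx.x from the enclosing kernel launch frame rather
    # than requiring every caller to pass it in explicitly.
    current_frame = T.KernelLaunchFrame.Current()
    thread_binding = current_frame.get_thread_binding()
    tx, _, warp_m = self.extract_thread_binding(thread_binding)
    ...  # issue the warp-synchronous loads using tx / warp_m
```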
@@ -40,7 +40,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
         local_size_b = mma_emitter.local_size_b
         block_K = mma_emitter.chunk
         micro_size_k = mma_emitter.micro_size_k
-        threads = mma_emitter.threads
         # Check if C is a fragment for applying custom layout
         a_is_fragment = is_fragment(A)
         c_is_fragment = is_fragment(C)
@@ -54,7 +54,6 @@ class GemmPrimitiveMMA(GemmBaseParams):
             """
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
-            thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
             if a_is_fragment:
                 # Annotate layout for A_local if it is a fragment.
                 T.annotate_layout({
@@ -77,7 +76,6 @@ class GemmPrimitiveMMA(GemmBaseParams):
                     B_local,
                     B_shared,
                     ki,
-                    thread_bindings=thread_bindings,
                 )
                 # Perform Matrix Multiplication
                 mma_emitter.mma(
@@ -135,7 +133,6 @@ class GemmPrimitiveMMA(GemmBaseParams):
         local_size_b = mma_emitter.local_size_b
         block_K = mma_emitter.chunk
         micro_size_k = mma_emitter.micro_size_k
-        threads = mma_emitter.threads
         # Check if C is a fragment for applying custom layout
         c_is_fragment = is_fragment(C)
@@ -150,8 +147,6 @@ class GemmPrimitiveMMA(GemmBaseParams):
             A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
-            thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
-
             if c_is_fragment:
                 # Annotate layout for C_local if it is a fragment.
                 T.annotate_layout({
@@ -164,7 +159,6 @@ class GemmPrimitiveMMA(GemmBaseParams):
                     A_local,
                     A_shared,
                     ki,
-                    thread_bindings=thread_bindings,
                 )
                 # Load B into fragment
@@ -172,7 +166,6 @@ class GemmPrimitiveMMA(GemmBaseParams):
                     B_local,
                     B_shared,
                     ki,
-                    thread_bindings=thread_bindings,
                 )
                 # Perform Matrix Multiplication
@@ -197,7 +190,8 @@ class GemmPrimitiveMMA(GemmBaseParams):
         # Infer block partition if necessary
         current_frame = T.KernelLaunchFrame.Current()
-        threads = current_frame.num_threads
+        threads = current_frame.get_num_threads()
         self.infer_block_partition(threads)
         A, B, C = self.A, self.B, self.C
...