Unverified Commit 60567ba3 authored by Jiaxing Ding, committed by GitHub

[AMD] Support T.gemm_v2 for AMD Backend (#1136)

parent 7d389a43
import tilelang.language as T
from typing import Literal, Callable
from tvm.tir import IndexMap
from tilelang.intrinsics.utils import get_mma_micro_size
from tilelang.intrinsics.mfma_layout import (
shared_16x4_to_local_64x1_layout_A,
shared_16x16_to_local_64x4_layout_A,
shared_16x32_to_local_64x8_layout_A,
shared_16x64_to_local_64x16_layout_A,
)
def make_mfma_load_base_layout(dtype: str = "float16",
matrix: Literal["A", "B"] = "A",
k_dim: int = 16,
transposed: bool = False) -> T.Fragment:
"""
    Create a base layout describing how an MFMA operand tile (A or B) is loaded
    into a fragment buffer, i.e. how fragment indices map to lane ids and
    per-thread local indices.
Parameters
----------
dtype : str
The data type of the matrix.
matrix : Literal["A", "B"]
The mfma operand to be loaded.
k_dim : int
The k dimension of the mfma.
transposed : bool
Whether the matrix is transposed, by default False.
Returns
-------
T.Fragment
        Describes how threads and local indices in the fragment are laid out.
"""
assert matrix in ["A", "B"], "matrix should be either A or B"
# s represents spatial axis
# r represents reduction axis
# sr represents the two dims are spatial + reduction
# rs represents the two dims are reduction + spatial
transform_func_sr_a: Callable = None
transform_func_sr_b: Callable = None
if k_dim == 4:
transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
elif k_dim == 16:
transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
elif k_dim == 32:
transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
elif k_dim == 64:
transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
else:
raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")
is_sr_conditions = [False]
is_sr_conditions.append(matrix == "A" and not transposed)
is_sr_conditions.append(matrix == "B" and transposed)
is_sr_axis_order = any(is_sr_conditions)
micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype)
    # the mfma layout is row.col,
    # so the B matrix expects a transposed basic layout
transform_func: Callable = None
if matrix == "A":
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
micro_size_s, micro_size_r = micro_size_x, micro_size_k
elif matrix == "B":
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
j, i)
micro_size_s, micro_size_r = micro_size_k, micro_size_y
else:
raise ValueError(f"Unsupported matrix {matrix}")
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
def forward_thread(i: int, j: int) -> int:
"""
        Given the row index `i` and column index `j` in the fragment,
        return the lane (thread) id that owns that element.
"""
lane_id, _ = inverse_mma_load_layout.map_indices([i, j])
return lane_id
def forward_index(i: int, j: int) -> int:
"""
        Given the row index `i` and column index `j` in the fragment,
        return the local index of that element within its owning thread.
"""
_, local_id = inverse_mma_load_layout.map_indices([i, j])
return local_id
base_fragment = T.Fragment(
[micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
return base_fragment
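# A minimal illustration (not part of the original script): query the mfma layout helper
# directly to see which lane and local slot a few coordinates of a 16x16 fp16 A tile map
# to. Assumes the helper returns a (thread_id, local_id) pair for plain integer inputs,
# which is how it is consumed through IndexMap in the function above.
for _i, _j in [(0, 0), (0, 3), (1, 0), (15, 15)]:
    _lane, _local = shared_16x16_to_local_64x4_layout_A(_i, _j)
    print(f"A[{_i}, {_j}] -> lane {_lane}, local {_local}")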
block_rows = 2
block_cols = 2
warp_rows = 2
warp_cols = 2
chunk = 2
from tilelang.tools import plot_layout
# base mfma load layout 16x16
base_layout = make_mfma_load_base_layout(dtype="float16", matrix="A", transposed=False)
print(base_layout)
plot_layout(base_layout, name="base_layout")
# warp layout 32x32
warp_layout = base_layout.repeat([warp_rows, warp_cols],
repeat_on_thread=False,
lower_dim_first=False)
print(warp_layout)
plot_layout(warp_layout, name="warp_layout")
# block layout 64x32
block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True,
lower_dim_first=True).replicate(block_cols)
print(block_layout)
plot_layout(block_layout, name="block_layout")
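# The same helpers also visualize the other operand: a short sketch (assuming the same
# make_mfma_load_base_layout / plot_layout API used above) for the B-matrix base layout,
# whose spatial and reduction axes are swapped relative to A.
base_layout_b = make_mfma_load_base_layout(dtype="float16", matrix="B", transposed=False)
print(base_layout_b)
plot_layout(base_layout_b, name="base_layout_b")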
from tilelang import tvm as tvm
import tilelang.testing
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B)
# T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_ss(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=256,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
latency = profiler.do_bench(profiler.func, warmup=100)
print(f"GEMM SS latency: {latency} ms")
def test_gemm_ss():
# GEMM tests for float16
run_gemm_ss(1024, 1024, 1024, False, False, "float16", "float16", "float32", 128, 128, 32)
run_gemm_ss(1024, 1024, 1024, False, True, "float16", "float16", "float32", 128, 128, 32)
run_gemm_ss(1024, 1024, 1024, True, False, "float16", "float16", "float32", 128, 128, 32)
run_gemm_ss(1024, 1024, 1024, True, True, "float16", "float16", "float32", 128, 128, 32)
    # GEMM tests for int8
run_gemm_ss(1024, 1024, 1024, False, True, "int8", "int8", "int32", 128, 128, 32)
run_gemm_ss(1024, 1024, 1024, False, False, "int8", "int8", "int32", 128, 128, 32)
run_gemm_ss(1024, 1024, 1024, True, False, "int8", "int8", "int32", 128, 128, 32)
run_gemm_ss(1024, 1024, 1024, True, True, "int8", "int8", "int32", 128, 128, 32)
def matmul_rs(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_frag_shape = A_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
})
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(A_shared, A_frag)
T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_rs(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=256,
):
program = matmul_rs(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_gemm_rs():
# GEMM tests for float16
run_gemm_rs(1024, 1024, 1024, False, False, "float16", "float16", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, False, True, "float16", "float16", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float16", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, True, "float16", "float16", "float32", 128, 128, 32)
    # GEMM tests for int8
run_gemm_rs(1024, 1024, 1024, False, True, "int8", "int8", "int32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, False, False, "int8", "int8", "int32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, False, "int8", "int8", "int32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, True, "int8", "int8", "int32", 128, 128, 32)
def matmul_sr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
B_frag_shape = B_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
})
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(B_shared, B_frag)
T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_sr(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=256,
):
program = matmul_sr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_gemm_sr():
# GEMM tests for float16
run_gemm_sr(1024, 1024, 1024, False, False, "float16", "float16", "float32", 128, 128, 32)
run_gemm_sr(1024, 1024, 1024, False, True, "float16", "float16", "float32", 128, 128, 32)
run_gemm_sr(1024, 1024, 1024, True, False, "float16", "float16", "float32", 128, 128, 32)
run_gemm_sr(1024, 1024, 1024, True, True, "float16", "float16", "float32", 128, 128, 32)
    # GEMM tests for int8
run_gemm_sr(1024, 1024, 1024, False, True, "int8", "int8", "int32", 128, 128, 32)
run_gemm_sr(1024, 1024, 1024, False, False, "int8", "int8", "int32", 128, 128, 32)
run_gemm_sr(1024, 1024, 1024, True, False, "int8", "int8", "int32", 128, 128, 32)
run_gemm_sr(1024, 1024, 1024, True, True, "int8", "int8", "int32", 128, 128, 32)
def matmul_rr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_frag_shape = A_shared_shape
B_frag_shape = B_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
})
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(A_shared, A_frag)
T.copy(B_shared, B_frag)
T.gemm_v2(A_frag, B_frag, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_rr(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=256,
):
program = matmul_rr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
print(program)
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_gemm_rr():
# GEMM tests for float16
run_gemm_rr(1024, 1024, 1024, False, False, "float16", "float16", "float32", 128, 128, 32)
run_gemm_rr(1024, 1024, 1024, False, True, "float16", "float16", "float32", 128, 128, 32)
run_gemm_rr(1024, 1024, 1024, True, False, "float16", "float16", "float32", 128, 128, 32)
run_gemm_rr(1024, 1024, 1024, True, True, "float16", "float16", "float32", 128, 128, 32)
    # GEMM tests for int8
run_gemm_rr(1024, 1024, 1024, False, True, "int8", "int8", "int32", 128, 128, 32)
run_gemm_rr(1024, 1024, 1024, False, False, "int8", "int8", "int32", 128, 128, 32)
run_gemm_rr(1024, 1024, 1024, True, False, "int8", "int8", "int32", 128, 128, 32)
run_gemm_rr(1024, 1024, 1024, True, True, "int8", "int8", "int32", 128, 128, 32)
if __name__ == "__main__":
tilelang.testing.main()
@@ -2,10 +2,32 @@ from __future__ import annotations
from tilelang import tvm as tvm
import tilelang.language as T
from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var
from tvm.runtime import convert
from .utils import (
    mfma_store_index_map,)
from typing import Literal, Callable
from tilelang.utils import is_fragment
from .mfma_layout import (
shared_16x4_to_local_64x1_layout_A,
shared_4x16_to_local_64x1_layout_B,
shared_16x16_to_local_64x4_layout_A,
shared_16x16_to_local_64x4_layout_B,
shared_16x32_to_local_64x8_layout_A,
shared_16x32_to_local_64x8_layout_B,
shared_16x64_to_local_64x16_layout_A,
shared_16x64_to_local_64x16_layout_B,
thread_id_shared_access_64x1_to_16x4_layout_A,
thread_id_shared_access_64x1_to_4x16_layout_B,
thread_id_shared_access_64x4_to_16x16_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_B,
thread_id_shared_access_64x8_to_16x32_layout_A,
thread_id_shared_access_64x8_to_16x32_layout_B,
thread_id_shared_access_64x16_to_16x64_layout_A,
thread_id_shared_access_64x16_to_16x64_layout_B,
)
lift = convert
@@ -53,6 +75,7 @@ class MatrixCoreIntrinEmitter:
        k_pack: int | None = None,
        is_m_first: bool | None = False,
        b_preshuffle: bool | None = False,
thread_var: Var | None = None,
    ):
        self.a_dtype = a_dtype
        self.b_dtype = b_dtype
@@ -79,6 +102,7 @@ class MatrixCoreIntrinEmitter:
        self.reduce_k = reduce_k
        self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
        self.num_elems_per_byte = num_elems_per_byte
self.thread_var = thread_var
    def _initialize_k_dim(self, a_dtype="float16"):
        if isinstance(a_dtype, str):
@@ -147,24 +171,6 @@ class MatrixCoreIntrinEmitter:
        self.b_preshuffle = b_preshuffle

    def get_ldmatrix_index_map(self, is_b=False):
        k_dim = self.k_dim * self.k_pack
        transposed = self.a_transposed if not is_b else self.b_transposed
@@ -200,6 +206,22 @@ class MatrixCoreIntrinEmitter:
        return index_map, reverse_index_map
def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(mfma_store_index_map, index_dtype="int32")
if not inverse:
return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c])
return inverse_index_map
def get_thread_binding(self):
if self.thread_var is None:
current_frame = T.KernelLaunchFrame.Current()
assert current_frame is not None, "Must be called in a T.Kernel Frame"
return current_frame.get_thread_binding()
else:
return self.thread_var
    def extract_thread_binding(self,
                               thread_id,
                               is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
@@ -238,8 +260,7 @@ class MatrixCoreIntrinEmitter:
        local_size_a = self.local_size_a
        k_pack = self.k_pack
        is_transposed = self.a_transposed
        thread_binding = self.get_thread_binding()
        _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)

        @T.macro
@@ -279,8 +300,7 @@ class MatrixCoreIntrinEmitter:
        local_size_b = self.local_size_b
        k_pack = self.k_pack
        is_transposed = self.b_transposed
        thread_binding = self.get_thread_binding()
        _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)

        @T.macro
@@ -316,7 +336,11 @@ class MatrixCoreIntrinEmitter:
        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)

    def mfma(self,
A_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
k_inner: PrimExpr | None = 0):
        warp_rows = self.warp_rows
        warp_cols = self.warp_cols
        local_size_a = self.local_size_a
@@ -329,8 +353,15 @@ class MatrixCoreIntrinEmitter:
        compute_b_dtype = b_dtype if local_size_b == 1 else f"{b_dtype}x{local_size_b}"
        compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}"
a_is_fragment = is_fragment(A_local_buf)
b_is_fragment = is_fragment(B_local_buf)
a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0
print(a_local_stride, b_local_stride)
        @T.macro
        def _warp_mfma(A_local_buf, B_local_buf, C_local_buf):
            for kp, i, j in T.grid(k_pack, warp_rows, warp_cols):
                T.tvm_mfma(
                    mfma_suffix,
@@ -340,15 +371,15 @@ class MatrixCoreIntrinEmitter:
                    compute_b_dtype,
                    compute_out_dtype,
                    B_local_buf.data,
                    (b_local_stride + (j * k_pack + kp) * local_size_b) // local_size_b,
                    A_local_buf.data,
                    (a_local_stride + (i * k_pack + kp) * local_size_a) // local_size_a,
                    C_local_buf.data,
                    (i * warp_cols * local_size_out + j * local_size_out) // local_size_out,
                    dtype=compute_out_dtype,
                )

        return _warp_mfma(A_local_buf, B_local_buf, C_local_buf)

    def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None):
        block_row_warps = self.block_row_warps
@@ -356,8 +387,7 @@ class MatrixCoreIntrinEmitter:
        warp_rows = self.warp_rows
        warp_cols = self.warp_cols
        local_size_out = self.local_size_out
        thread_binding = self.get_thread_binding()
        is_global = pid_m is not None and pid_n is not None
        BLOCK_M = block_row_warps * warp_rows
        BLOCK_N = block_col_warps * warp_cols
@@ -366,7 +396,7 @@ class MatrixCoreIntrinEmitter:
        assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D"

        # STS
        # The MFMA store must be emitted manually (simulated) instead of using the TVM
        # intrinsic, since the intrinsic is effectively a hack that assumes threadIdx.x
        # always equals the warp size
        @T.macro
@@ -400,6 +430,217 @@ class MatrixCoreIntrinEmitter:
                    thread_binding) if is_global else _warp_stmatrix_shared(
                        C_local_buf, C_buf, thread_binding)
def make_mfma_load_layout(self,
local_buf: Buffer,
matrix: Literal["A", "B"] = "A") -> T.Fragment:
"""
        Create a layout describing how an MFMA operand (A or B) is loaded into a fragment buffer.
Parameters
----------
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
assert matrix in ["A", "B"], "matrix should be either A or B"
matrix_is_a: bool = matrix == "A"
matrix_is_b: bool = matrix == "B"
transposed = self.a_transposed if matrix_is_a else self.b_transposed
# s represents spatial axis
# r represents reduction axis
# sr represents the two dims are spatial + reduction
# rs represents the two dims are reduction + spatial
        # sr can also be read as the non-transposed basic layout,
        # and rs as the transposed basic layout
transform_func_sr_a: Callable = None
transform_func_sr_b: Callable = None
k_dim = self.k_dim * self.k_pack
if k_dim == 4:
transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
elif k_dim == 16:
transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
elif k_dim == 32:
transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
elif k_dim == 64:
transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
else:
raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")
is_sr_conditions = [False]
is_sr_conditions.append(matrix_is_a and not transposed)
is_sr_conditions.append(matrix_is_b and transposed)
is_sr_axis_order = any(is_sr_conditions)
transform_func: Callable = None
if matrix_is_a:
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
elif matrix_is_b:
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
j, i)
else:
raise ValueError(f"Unsupported matrix {matrix}")
assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
if matrix_is_a:
micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
else:
micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y
block_row_warps, block_col_warps = (
self.block_row_warps,
self.block_col_warps,
)
inverse_mfma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
def forward_thread(i: int, j: int) -> int:
"""
            Given the row index `i` and column index `j` in the fragment,
            return the lane (thread) id that owns that element.
"""
lane_id, _ = inverse_mfma_load_layout.map_indices([i, j])
return lane_id
def forward_index(i: int, j: int) -> int:
"""
            Given the row index `i` and column index `j` in the fragment,
            return the local index of that element within its owning thread.
"""
_, local_id = inverse_mfma_load_layout.map_indices([i, j])
return local_id
base_fragment = T.Fragment(
[micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
warp_rows, warp_cols = self.warp_rows, self.warp_cols
chunk = self.chunk
warp_s = warp_rows if matrix_is_a else warp_cols
warp_r = chunk // micro_size_r
block_s = block_row_warps if matrix_is_a else block_col_warps
replicate = block_col_warps if matrix_is_a else block_row_warps
if is_sr_axis_order:
warp_fragment = base_fragment.repeat([warp_s, warp_r],
repeat_on_thread=False,
lower_dim_first=False)
if matrix_is_a:
block_fragment = warp_fragment.repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
else:
warp_fragment = base_fragment.repeat([warp_r, warp_s],
repeat_on_thread=False,
lower_dim_first=True)
if matrix_is_a:
block_fragment = warp_fragment.repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
return block_fragment
def make_mfma_store_layout(self, local_buf: Buffer) -> T.Fragment:
"""
Create a layout function for storing MFMA results into a fragment buffer.
Parameters
----------
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
shape = local_buf.shape
inverse_mfma_store_layout = self.get_store_index_map(inverse=True)
assert is_fragment(local_buf), "local_buf must be a fragment"
micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
local_size_out = self.local_size_out
block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
warp_rows, warp_cols = self.warp_rows, self.warp_cols
warp_size = self.WARP_SIZE
is_m_first = self.is_m_first
def forward_thread(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a thread index according to `inverse_mfma_store_layout`.
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
            # the upper bounds of block_i and block_j are block_row_warps and block_col_warps
block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
# upper bounds of mfma_i and mfma_j are micro_size_x and micro_size_y
mfma_i, mfma_j = i % micro_size_x, j % micro_size_y
lane_id, _ = inverse_mfma_store_layout.map_indices([mfma_i, mfma_j])
if is_m_first:
thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_size + lane_id
else:
thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id
return thread_id
def forward_index(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a local index in a single thread according
to `inverse_mfma_store_layout`.
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
# upper bounds of mfma_i and mfma_j are micro_size_x and micro_size_y
mfma_i, mfma_j = i % micro_size_x, j % micro_size_y
_, local_id = inverse_mfma_store_layout.map_indices([mfma_i, mfma_j])
return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id
return T.Fragment(
shape,
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
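    # Worked example (illustrative only, not part of the diff): with a 16x16 fp16 tile
    # (micro_size_x = micro_size_y = 16, local_size_out = 4) and warp_rows = warp_cols = 2,
    # fragment coordinates (i, j) = (17, 5) give warp_i = (17 // 16) % 2 = 1,
    # warp_j = (5 // 16) % 2 = 0, mfma_i = 1, mfma_j = 5, so forward_index places the
    # element at local offset 1 * (2 * 4) + 0 * 4 + local_id = 8 + local_id within the
    # lane chosen by mfma_store_index_map for (mfma_i, mfma_j) = (1, 5).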
class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
    ...
@@ -8,6 +8,7 @@ import tvm.ffi
from tilelang.ir import GemmWarpPolicy
from .gemm_mma import GemmMMA
from .gemm_wgmma import GemmWGMMA
from .gemm_mfma import GemmMFMA
from tilelang import _ffi_api
@@ -28,14 +29,18 @@ def gemm_py_lower(gemm_py, layout_map, target, thread_bounds, thread_var):
# Same definition as in src/op/gemm_py.h
class GemmInst(IntEnum):
    MMA = 0
    WGMMA = 1
    TCGEN5MMA = 2
MFMA = 3
    def is_mma(self) -> bool:
        return self == GemmInst.MMA

    def is_wgmma(self) -> bool:
        return self == GemmInst.WGMMA
def is_tcgen5mma(self) -> bool:
return self == GemmInst.TCGEN5MMA
    def is_mfma(self) -> bool:
        return self == GemmInst.MFMA
@@ -115,6 +120,8 @@ class GemmPy(Node, Scriptable):
        elif gemm_inst.is_wgmma():
            return GemmWGMMA
        elif gemm_inst.is_mfma():
            return GemmMFMA
elif gemm_inst.is_tcgen5mma():
raise NotImplementedError("TCGEN5MMA is not implemented")
        else:
            raise ValueError(f"Unsupported GEMM instruction: {gemm_inst}")
from .gemm_base import GemmBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mfma_macro_generator import (
MatrixCoreIntrinEmitter,)
from tilelang.utils.language import is_shared, is_fragment
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang import language as T
from tilelang.transform.simplify import _Simplify
class GemmMFMA(GemmBase):
def infer_layout(self, target: Target, thread_nums: int):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mfma_emitter = MatrixCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
)
if self.is_gemm_ss():
return {
self.A: make_swizzled_layout(self.A),
self.B: make_swizzled_layout(self.B),
self.C: mfma_emitter.make_mfma_store_layout(self.C),
}
elif self.is_gemm_sr():
return {
self.A: make_swizzled_layout(self.A),
self.B: mfma_emitter.make_mfma_load_layout(self.B, matrix="B"),
self.C: mfma_emitter.make_mfma_store_layout(self.C),
}
elif self.is_gemm_rs():
return {
self.A: mfma_emitter.make_mfma_load_layout(self.A, matrix="A"),
self.B: make_swizzled_layout(self.B),
self.C: mfma_emitter.make_mfma_store_layout(self.C),
}
elif self.is_gemm_rr():
return {
self.A: mfma_emitter.make_mfma_load_layout(self.A, matrix="A"),
self.B: mfma_emitter.make_mfma_load_layout(self.B, matrix="B"),
self.C: mfma_emitter.make_mfma_store_layout(self.C),
}
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mfma_emitter = MatrixCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
thread_var=thread_var,
)
in_dtype = self.in_dtype
warp_rows = mfma_emitter.warp_rows
warp_cols = mfma_emitter.warp_cols
local_size_a = mfma_emitter.local_size_a
local_size_b = mfma_emitter.local_size_b
block_K = mfma_emitter.chunk
micro_size_k = mfma_emitter.micro_size_k
A_shared = self.A
B_shared = self.B
C_local = self.C
assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
if self.is_gemm_ss():
@T.prim_func
def _gemm_ssr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Matrix Core mfma ops,
accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mfma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)
# Load B into fragment
mfma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mfma_emitter.mfma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_ssr, inline_let=True)
elif self.is_gemm_sr():
B_local = self.B
@T.prim_func
def _gemm_srr() -> None:
"""
                The inner macro that loads A from the shared buffer A_shared into a
                local fragment (B is already a register fragment), then issues Matrix
                Core mfma ops, accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mfma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)
# Perform Matrix Multiplication
mfma_emitter.mfma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_srr, inline_let=True)
elif self.is_gemm_rs():
A_local = self.A
@T.prim_func
def _gemm_rsr() -> None:
"""
                The inner macro that loads B from the shared buffer B_shared into a
                local fragment (A is already a register fragment), then issues Matrix
                Core mfma ops, accumulating into C_local.
"""
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load B into fragment
mfma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mfma_emitter.mfma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
elif self.is_gemm_rr():
A_local = self.A
B_local = self.B
@T.prim_func
            def _gemm_rrr() -> None:
"""
                Both operands already live in register fragments, so no shared-memory
                loads are needed; directly issue Matrix Core mfma ops, accumulating
                into C_local.
"""
for ki in T.serial(0, (block_K // micro_size_k)):
# Perform Matrix Multiplication
mfma_emitter.mfma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
            return _Simplify(_gemm_rrr, inline_let=True)
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def is_gemm_ss(self) -> bool:
return is_shared(self.A) and is_shared(self.B)
def is_gemm_sr(self) -> bool:
return is_shared(self.A) and is_fragment(self.B)
def is_gemm_rs(self) -> bool:
return is_fragment(self.A) and is_shared(self.B)
def is_gemm_rr(self) -> bool:
return is_fragment(self.A) and is_fragment(self.B)
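# Illustrative only (not part of this file's API): how the scope predicates above decide
# which of the four lowering paths a T.gemm_v2 call takes. Assumes tvm.tir.decl_buffer's
# `scope` argument and tilelang's "local.fragment" scope for register fragments.
def _example_operand_classification():
    A_s = tir.decl_buffer((128, 32), "float16", name="A_s", scope="shared.dyn")
    B_f = tir.decl_buffer((128, 32), "float16", name="B_f", scope="local.fragment")
    assert is_shared(A_s) and not is_fragment(A_s)  # the "s" side of ss / sr
    assert is_fragment(B_f) and not is_shared(B_f)  # the "r" side of sr / rr
    # shared + shared     -> is_gemm_ss -> swizzled A/B layouts
    # shared + fragment   -> is_gemm_sr -> B uses make_mfma_load_layout(matrix="B")
    # fragment + shared   -> is_gemm_rs -> A uses make_mfma_load_layout(matrix="A")
    # fragment + fragment -> is_gemm_rr -> both operands use MFMA load layouts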