Unverified Commit 91a7bb2b authored by Lei Wang's avatar Lei Wang Committed by GitHub

[TileOp] Introduce an experimental Python-defined `T.gemm_v2` (#793)

* Refactor GEMM and GEMM-SP operations to enhance clarity and maintainability

- Removed deprecated prime factorization functions from `gemm.cc` and `gemm_sp.cc`.
- Introduced a new `GemmWarpPolicy` class to manage warp policy attributes and methods, improving encapsulation.
- Updated reflection methods to include the new policy structure, ensuring proper registration and introspection capabilities.
- Enhanced `GetArchInt` function in `utils.cc` for better readability and type safety.
- Added a new `gemm_v2` function in `gemm.py` for an improved GEMM operation with additional parameters and checks.
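Among the checks added in `gemm_v2` is matching the reduction extents of A and B. A minimal sketch of that K-dimension inference, factored into a standalone helper (the helper name is hypothetical; the logic mirrors the shape handling in this diff):

```python
def infer_k_dims(a_shape, b_shape, transpose_A=False, transpose_B=False):
    """Return (K_A, K_B), the reduction extents implied by each operand.

    Mirrors gemm_v2's check: K comes from A's second-to-last dimension when
    A is transposed, otherwise its last dimension (and symmetrically for B).
    """
    k_a = a_shape[-2] if transpose_A else a_shape[-1]
    k_b = b_shape[-1] if transpose_B else b_shape[-2]
    return k_a, k_b

# A: 128x64 (M x K), B: 64x256 (K x N) -> reduction extents match
k_a, k_b = infer_k_dims((128, 64), (64, 256))
assert k_a == k_b == 64
```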

* Refactor GEMM and frontend legalize operations for improved clarity and functionality

- Updated `gemm_py.h` to include the correct header for GEMM operations.
- Renamed `FrontendLegalizer` class to `LetInliner` and updated related methods to reflect this change, enhancing code clarity.
- Modified the pass function from `FrontendLegalize` to `LetInline` for better alignment with its purpose.
- Updated test cases to utilize the new `gemm_v2` function and adjusted the testing framework for improved output and clarity.
- Removed obsolete test file `test_tilelang_transform_frontend_legalize.py` to streamline the test suite.
- Enhanced the `LowerAndLegalize` function to utilize the new `LetInline` pass, improving the overall transformation process.

* Enhance CUDA code generation and testing for GEMM operations

- Added indentation printing in `codegen_cuda.cc` for improved assembly code formatting.
- Updated `test_tilelang_tilelibrary_gemm.py` to include additional GEMM test cases and shared memory allocation with specified scope.
- Introduced new `matmul_sr` and `run_gemm_sr` functions for GEMM operations with shared and fragment memory layouts.
- Refactored layout inference in `mma_macro_generator.py` to improve clarity and correctness in shared memory handling.
- Enhanced `gemm/__init__.py` to support new GEMM operation combinations and layout inference logic.

These changes improve the clarity, functionality, and testing coverage of GEMM operations in the TileLang framework.

* Refactor GEMM layout and testing for improved clarity and functionality

- Updated `gemm_layouts.cc` to enhance the layout generation logic for transposed and non-transposed GEMM operations.
- Renamed and modified functions in `test_tilelang_tilelibrary_gemm.py` to reflect changes in GEMM function signatures and improve test coverage.
- Introduced new GEMM operation combinations in `gemm/__init__.py` to support additional layouts and configurations.
- Enhanced layout inference in `mma_layout.py` and `mma_macro_generator.py` for better handling of shared memory layouts.

These changes improve the clarity, functionality, and testing coverage of GEMM operations in the TileLang framework.
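Part of the layout handling above hinges on choosing the instruction n-dimension from the warp tile. A minimal sketch of that choice, assuming the decomposition shown in `_initialize_micro_size` in this diff (the function name here is hypothetical):

```python
def pick_n_dim(warp_col_tiles: int):
    """Choose the MMA n-dim and warp_cols as _initialize_micro_size does.

    Prefer the 16-wide micro tile when warp_col_tiles divides evenly by 16;
    otherwise fall back to the minimal m16n8k* width of 8.
    """
    assert warp_col_tiles >= 8 and warp_col_tiles % 8 == 0
    if warp_col_tiles % 16 == 0:
        return 16, warp_col_tiles // 16
    return 8, warp_col_tiles // 8

assert pick_n_dim(32) == (16, 2)
assert pick_n_dim(24) == (8, 3)
```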

* Refactor GEMM layout and Python integration for improved functionality

- Updated `gemm_layouts.cc` to correct the order of layout replication and repetition for transposed and non-transposed GEMM operations.
- Enhanced `gemm_py.cc` to handle block realization more robustly, ensuring correct assignment of global symbols and block attributes.
- Refactored `inject_pipeline.cc` to streamline buffer read/write region handling, improving clarity and maintainability.
- Cleaned up test cases in `test_tilelang_tilelibrary_gemm.py` by removing unnecessary print statements and adjusting function calls for better test execution flow.

These changes enhance the clarity, functionality, and robustness of GEMM operations and their testing in the TileLang framework.

* Refactor GEMM layout and testing for improved clarity and functionality

- Updated `gemm_layouts.cc` to enhance layout generation logic for transposed and non-transposed GEMM operations.
- Improved block realization handling in `gemm_py.cc` for better assignment of global symbols.
- Streamlined buffer read/write region handling in `inject_pipeline.cc` for clarity.
- Enhanced test cases in `test_tilelang_tilelibrary_gemm.py` by adjusting function calls and adding new GEMM operation combinations.

These changes improve the clarity, functionality, and robustness of GEMM operations and their testing in the TileLang framework.

* tfloat32 support.
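The tfloat32 path maps the reduction width to a PTX `mma.sync` shape. A sketch of that mapping, mirroring `_initialize_mma_prefix` in this diff (standalone helper name is hypothetical):

```python
def mma_prefix_for_k_dim(k_dim: int) -> str:
    """Pick the PTX mma.sync shape for a given reduction width.

    k_dim=8 -> tfloat32 (m16n8k8), k_dim=16 -> fp16/bf16 (m16n8k16),
    k_dim=32 -> int8/fp8 (m16n8k32).
    """
    prefixes = {8: "m16n8k8", 16: "m16n8k16", 32: "m16n8k32"}
    if k_dim not in prefixes:
        raise ValueError(f"Unsupported k_dim: {k_dim}")
    return prefixes[k_dim]

assert mma_prefix_for_k_dim(8) == "m16n8k8"
```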

* lint fix

* lint fix

* Refactor shared memory allocation in GEMM tests

- Removed unnecessary scope specification in shared memory allocation for matrices A and B in `test_tilelang_tilelibrary_gemm.py`.
- This change simplifies the allocation process and aligns with the updated GEMM function signatures.
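For reference, `gemm_v2` derives operand strides as contiguous row-major strides from the buffer shape. A minimal standalone sketch of that computation (the helper name is hypothetical; the loop mirrors `retrieve_stride` in this diff):

```python
def row_major_strides(shape):
    """Compute contiguous row-major strides, innermost dimension last."""
    strides, stride = [], 1
    for extent in reversed(shape):
        strides.insert(0, stride)
        stride *= extent
    return strides

# For a (2, 3, 4) buffer: moving one step along dim 0 skips 12 elements.
assert row_major_strides((2, 3, 4)) == [12, 4, 1]
```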
parent 9fd6bb30
......@@ -2,25 +2,38 @@ import tilelang.language as T
from typing import Union, Tuple, Optional, Literal, Callable
from tilelang.common import TransformKind
from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer
from tvm.tir import PrimExpr, IndexMap, Buffer, Var
from tvm.runtime import convert
from .utils import (
mma_store_index_map,
get_ldmatrix_offset,
)
from tilelang.utils import is_fragment
from tilelang.intrinsics.mma_layout import (
shared_16x8_to_mma_32x4_layout_sr_a,
shared_16x8_to_mma_32x4_layout_sr_b,
shared_16x16_to_mma_32x8_layout_sr_a,
shared_16x16_to_mma_32x8_layout_sr_b,
shared_16x32_to_mma_32x16_layout_sr_a,
shared_16x32_to_mma_32x16_layout_sr_b,
mma_load_a_32x4_to_shared_16x8_layout,
mma_load_b_32x4_to_shared_16x8_layout,
mma_load_a_32x16_to_shared_16x32_layout,
mma_load_b_32x16_to_shared_16x32_layout,
)
lift = convert
# TODO(lei): Add Typing for this file
class TensorCoreIntrinEmitter(object):
"""
To eliminate Python syntax within TIR Macro.
"""
M_DIM = 16
N_DIM = 16
# use lowercase as n_dim can be dynamic
# the smallest instructions can be m16n8k16, so the n_dim can also be 8
n_dim = 16
WARP_SIZE = 32
dtype_abbrv = {
"float16": "fp16",
......@@ -50,6 +63,7 @@ class TensorCoreIntrinEmitter(object):
reduce_k: int = 1,
num_elems_per_byte: int = 1,
is_m_first: Optional[bool] = False,
thread_var: Optional[Var] = None,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
......@@ -64,16 +78,15 @@ class TensorCoreIntrinEmitter(object):
self.chunk = chunk
self._initialize_k_dim(a_dtype)
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
self._initialize_micro_size(self.M_DIM, self.k_dim)
self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim, self.WARP_SIZE)
self._initialize_mma_prefix(self.k_dim)
self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
self._initialize_is_m_first(is_m_first)
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
self.reduce_k = reduce_k
self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
self.num_elems_per_byte = num_elems_per_byte
self.thread_var = thread_var
if self.warp_rows == 0 or self.warp_cols == 0:
raise ValueError(
......@@ -96,22 +109,53 @@ class TensorCoreIntrinEmitter(object):
self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype]
def _initialize_mma_prefix(self, k_dim: int = 16):
if k_dim == 16:
if k_dim == 8:
# typically used for tfloat32
self.mma_prefix = "m16n8k8"
elif k_dim == 16:
# typically used for float16/bfloat16
self.mma_prefix = "m16n8k16"
elif k_dim == 32:
# typically used for int8/fp8
self.mma_prefix = "m16n8k32"
else:
raise ValueError("Unsupported k_dim")
def _initialize_micro_size(self, m_dim: int = 16, n_dim: int = 16, k_dim: int = 16):
def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
warp_row_tiles = self.warp_row_tiles
warp_col_tiles = self.warp_col_tiles
assert warp_row_tiles >= 16, f"warp_row_tiles must be at least 16, got {warp_row_tiles}"
assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}"
assert warp_col_tiles >= 8, f"warp_col_tiles must be at least 8, got {warp_col_tiles}"
assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}"
self.warp_rows = warp_row_tiles // m_dim
if warp_col_tiles % 16 == 0:
self.n_dim = 16
self.micro_size_y = 16
self.warp_cols = warp_col_tiles // 16
else:
# must be divisible by 8
self.n_dim = 8
self.micro_size_y = 8
self.warp_cols = warp_col_tiles // 8
self.micro_size_x = m_dim
self.micro_size_y = n_dim
self.micro_size_k = k_dim
def _initialize_is_m_first(self, is_m_first: Optional[bool] = False):
if is_m_first is not None:
self.is_m_first = is_m_first
def get_thread_binding(self):
if self.thread_var is None:
current_frame = T.KernelLaunchFrame.Current()
assert current_frame is not None, "Must be called in a T.Kernel Frame"
return current_frame.get_thread_binding()
else:
return self.thread_var
def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32")
......@@ -165,9 +209,21 @@ class TensorCoreIntrinEmitter(object):
local_size_a = self.local_size_a
a_dtype = self.a_dtype
a_transposed = self.a_transposed
# ldmatrix cannot be used for non-16-bit dtypes (e.g. int8/tf32) in the transposed case.
ldmatrix_available = not (DataType(a_dtype).bits != 16 and a_transposed)
def mma_load_layout(i, j):
return i, j
current_frame = T.KernelLaunchFrame.Current()
thread_binding = current_frame.get_thread_binding()
if not ldmatrix_available:
if DataType(a_dtype).bits == 8:
mma_load_layout = mma_load_a_32x16_to_shared_16x32_layout
elif DataType(a_dtype).bits == 32:
mma_load_layout = mma_load_a_32x4_to_shared_16x8_layout
else:
raise ValueError(f"Unsupported dtype: {a_dtype}")
thread_binding = self.get_thread_binding()
@T.macro
def _warp_ldmatrix_a(
......@@ -179,20 +235,28 @@ class TensorCoreIntrinEmitter(object):
):
stride = A_shared_buf.shape[-1]
tx, _, warp_m = self.extract_thread_binding(thread_binding)
trans = self.a_transposed
for i in T.serial(warp_rows):
T.ptx_ldmatrix(
a_dtype,
T.bool(False),
4,
".b16",
A_local_buf.data,
i * local_size_a,
T.address_of(A_shared_buf[
warp_m * warp_row_tiles + i * micro_size_x,
rk * chunk + ki * micro_size_k,
]),
get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed),
)
# Assign A_shared_buf_elem
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k
A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk]
if ldmatrix_available:
T.ptx_ldmatrix(
a_dtype,
T.bool(trans),
4,
".b16",
A_local_buf.data,
i * local_size_a,
T.address_of(A_shared_buf_elem),
get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed),
)
else:
for j in T.serial(local_size_a):
mi, mk = mma_load_layout(tx, j)
A_local_buf[i * local_size_a + j] = A_shared_buf[wk + mk, wi + mi]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
......@@ -209,8 +273,21 @@ class TensorCoreIntrinEmitter(object):
local_size_b = self.local_size_b
b_dtype = self.b_dtype
b_transposed = self.b_transposed
current_frame = T.KernelLaunchFrame.Current()
thread_binding = current_frame.get_thread_binding()
thread_binding = self.get_thread_binding()
replicate_b = (self.n_dim == 16)
# ldmatrix cannot be used for non-16-bit dtypes (e.g. int8/tf32) when B is not transposed.
ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
def mma_load_layout(i, j):
return i, j
if not ldmatrix_available:
if DataType(b_dtype).bits == 8:
mma_load_layout = mma_load_b_32x16_to_shared_16x32_layout
elif DataType(b_dtype).bits == 32:
mma_load_layout = mma_load_b_32x4_to_shared_16x8_layout
else:
raise ValueError(f"Unsupported dtype: {b_dtype}")
@T.macro
def _warp_ldmatrix_b(
......@@ -222,25 +299,36 @@ class TensorCoreIntrinEmitter(object):
):
stride = B_shared_buf.shape[-1]
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
trans = not b_transposed
for j in T.serial(warp_cols):
for i in T.serial(warp_cols):
# Assign B_shared_elem
ri, rj = (
warp_n * warp_col_tiles + j * micro_size_y,
wi, wk = (
warp_n * warp_col_tiles + i * micro_size_y,
rk * chunk + ki * micro_size_k,
)
B_shared_elem = B_shared_buf[ri, rj]
T.ptx_ldmatrix(
b_dtype,
T.bool(False), # TODO(lei): should be optimized
4,
".b16",
B_local_buf.data,
j * local_size_b,
T.address_of(B_shared_elem),
get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed),
)
if ldmatrix_available:
B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk,
wi]
T.ptx_ldmatrix(
b_dtype,
T.bool(trans),
4 if replicate_b else 2,
".b16",
B_local_buf.data,
i * local_size_b,
T.address_of(B_shared_buf_elem),
get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed),
)
else:
# fall back to per-element loads from the shared buffer to the local
# buffer via mma_load_layout; the access must be transposed.
for j in T.serial(local_size_b):
mi, mk = mma_load_layout(tx, j)
B_local_buf[i * local_size_b + j] = B_shared_buf[wk + mk, wi + mi]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
......@@ -259,6 +347,7 @@ class TensorCoreIntrinEmitter(object):
accum_dtype = self.accum_dtype
accum_dtype_abbrv = self.accum_dtype_abbrv
mma_prefix = self.mma_prefix
replicate_b = (self.n_dim == 16)
a_is_fragment = is_fragment(A_local_buf)
b_is_fragment = is_fragment(B_local_buf)
......@@ -282,25 +371,26 @@ class TensorCoreIntrinEmitter(object):
b_local_stride + j * local_size_b,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out,
T.bool(False),
)
T.ptx_mma(
accum_dtype,
mma_prefix,
"row",
"col",
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
a_local_stride + i * local_size_a,
B_local_buf.data,
b_local_stride + j * local_size_b + lift(local_size_b) // 2,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
T.bool(False),
T.bool(False), # saturate
)
if replicate_b:
T.ptx_mma(
accum_dtype,
mma_prefix,
"row",
"col",
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
a_local_stride + i * local_size_a,
B_local_buf.data,
b_local_stride + j * local_size_b + lift(local_size_b) // 2,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out +
lift(local_size_out) // 2,
T.bool(False), # saturate
)
return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
......@@ -314,12 +404,11 @@ class TensorCoreIntrinEmitter(object):
is_global = pid_m is not None and pid_n is not None
BLOCK_M = block_row_warps * warp_rows
BLOCK_N = block_col_warps * warp_cols
M_DIM, N_DIM = self.M_DIM, self.N_DIM
M_DIM, n_dim = self.M_DIM, self.n_dim
C_buf_dims = len(C_buf.shape)
assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D"
current_frame = T.KernelLaunchFrame.Current()
thread_binding = current_frame.get_thread_binding()
thread_binding = self.get_thread_binding()
# STS
# The MMA store must be simulated rather than expressed with TVM intrinsics
......@@ -335,7 +424,7 @@ class TensorCoreIntrinEmitter(object):
row, col = T.meta_var(mma_store_index_map(tx, local_id))
if C_buf_dims == 2:
C_buf[(warp_m * warp_rows + i) * M_DIM + row,
(warp_n * warp_cols + j) * N_DIM +
(warp_n * warp_cols + j) * n_dim +
col] = C_local_buf[i * (warp_cols * local_size_out) +
j * local_size_out + local_id]
else:
......@@ -353,7 +442,7 @@ class TensorCoreIntrinEmitter(object):
row, col = T.meta_var(mma_store_index_map(tx, local_id))
C_buf[
(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
(pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col,
(pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col,
] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
local_id]
......@@ -385,42 +474,55 @@ class TensorCoreIntrinEmitter(object):
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
from tilelang.intrinsics.mma_layout import (
shared_16x16_to_mma_32x8_layout_sr,
shared_16x16_to_mma_32x8_layout_rs,
shared_16x32_to_mma_32x16_layout,
shared_32x16_to_mma_32x16_layout,
)
assert matrix in ["A", "B"], "matrix should be either A or B"
dtype = self.a_dtype if matrix == "A" else self.b_dtype
matrix_is_a: bool = matrix == "A"
matrix_is_b: bool = matrix == "B"
dtype = self.a_dtype if matrix_is_a else self.b_dtype
dtype_bits = DataType(dtype).bits
transposed = self.a_transposed
assert transposed is False, "transposed is not supported yet"
transposed = self.a_transposed if matrix_is_a else self.b_transposed
# s represents spatial axis
# r represents reduction axis
# sr represents the two dims are spatial + reduction
# rs represents the two dims are reduction + spatial
transform_func_sr: Callable = None
transform_func_rs: Callable = None
if dtype_bits == 16:
transform_func_sr = shared_16x16_to_mma_32x8_layout_sr
transform_func_rs = shared_16x16_to_mma_32x8_layout_rs
# sr can also represent a non-transposed basic layout,
# and rs a transposed basic layout
transform_func_sr_a: Callable = None
transform_func_sr_b: Callable = None
if dtype_bits == 32:
...
transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a
transform_func_sr_b = shared_16x8_to_mma_32x4_layout_sr_b
elif dtype_bits == 16:
transform_func_sr_a = shared_16x16_to_mma_32x8_layout_sr_a
transform_func_sr_b = shared_16x16_to_mma_32x8_layout_sr_b
elif dtype_bits == 8:
transform_func_sr = shared_16x32_to_mma_32x16_layout
transform_func_rs = shared_32x16_to_mma_32x16_layout
transform_func_sr_a = shared_16x32_to_mma_32x16_layout_sr_a
transform_func_sr_b = shared_16x32_to_mma_32x16_layout_sr_b
else:
raise ValueError(f"Unsupported dtype {dtype}")
is_sr_conditions = [False]
is_sr_conditions.append(matrix == "A" and not transposed)
is_sr_conditions.append(matrix == "B" and transposed)
is_sr_conditions.append(matrix_is_a and not transposed)
is_sr_conditions.append(matrix_is_b and transposed)
is_sr_axis_order = any(is_sr_conditions)
transform_func: Callable = transform_func_sr if is_sr_axis_order else transform_func_rs
# the layout of mma.sync is row.col,
# so the B matrix expects a transposed basic layout
transform_func: Callable = None
if matrix_is_a:
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
elif matrix_is_b:
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
j, i)
else:
raise ValueError(f"Unsupported matrix {matrix}")
assert is_fragment(local_buf), "local_buf must be a fragment, but got {}".format(
local_buf.scope())
if matrix == "A":
if matrix_is_a:
micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
else:
micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y
......@@ -429,10 +531,7 @@ class TensorCoreIntrinEmitter(object):
self.block_row_warps,
self.block_col_warps,
)
warp_rows, warp_cols = self.warp_rows, self.warp_cols
warp_s = warp_rows if matrix == "A" else warp_cols
chunk = self.chunk
transform_func = transform_func
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
def forward_thread(i: int, j: int) -> int:
......@@ -450,18 +549,48 @@ class TensorCoreIntrinEmitter(object):
return local_id
base_fragment = T.Fragment(
[micro_size_r, micro_size_s],
[micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
warp_fragment = base_fragment.repeat([block_row_warps, 1],
repeat_on_thread=True).replicate(block_col_warps)
block_fragment = warp_fragment.repeat([warp_s, chunk // micro_size_r],
repeat_on_thread=False,
lower_dim_first=False)
print(f"base_fragment: {base_fragment}")
print(f"warp_fragment: {warp_fragment}")
print(f"block_fragment: {block_fragment}")
warp_rows, warp_cols = self.warp_rows, self.warp_cols
chunk = self.chunk
warp_s = warp_rows if matrix_is_a else warp_cols
warp_r = chunk // micro_size_r
block_s = block_row_warps if matrix_is_a else block_col_warps
replicate = block_col_warps if matrix_is_a else block_row_warps
if is_sr_axis_order:
warp_fragment = base_fragment.repeat([warp_s, warp_r],
repeat_on_thread=False,
lower_dim_first=False)
if matrix_is_a:
block_fragment = warp_fragment.repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
else:
warp_fragment = base_fragment.repeat([warp_r, warp_s],
repeat_on_thread=False,
lower_dim_first=True)
if matrix_is_a:
block_fragment = warp_fragment.repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
return block_fragment
def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment:
......@@ -632,8 +761,7 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
a_transposed = self.a_transposed
transform_kind_a = self.transform_kind_a
current_frame = T.KernelLaunchFrame.Current()
thread_binding = current_frame.get_thread_binding()
thread_binding = self.get_thread_binding()
@T.macro
def _warp_ldmatrix_a(
......@@ -740,8 +868,7 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
b_transposed = self.b_transposed
num_elems_per_byte = self.num_elems_per_byte
current_frame = T.KernelLaunchFrame.Current()
thread_binding = current_frame.get_thread_binding()
thread_binding = self.get_thread_binding()
@T.macro
def _warp_ldmatrix_b(
......
from tvm import DataType
from typing import Literal
from .mma_layout import (
ldmatrix_32x4_to_shared_16x8_layout_a,
ldmatrix_32x4_to_shared_16x8_layout_b,
ldmatrix_32x8_to_shared_16x16_layout,
ldmatrix_trans_32x8_to_shared_16x16_layout,
ldmatrix_16x32_to_shared_16x32_layout_a,
ldmatrix_16x32_to_shared_16x32_layout_b,
ldmatrix_32x16_to_shared_16x32_layout_a,
ldmatrix_32x16_to_shared_16x32_layout_b,
mma_store_32x8_to_shared_16x16_layout,
)
from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m)
......@@ -26,7 +28,18 @@ def get_ldmatrix_offset(
):
assert matrix in ["A", "B"], "matrix should be either A or B"
dtype_bits = DataType(dtype).bits
if dtype_bits == 16:
if dtype_bits == 32:
if matrix == "B" and transposed:
transform_func = ldmatrix_32x4_to_shared_16x8_layout_b
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
elif matrix == "A" and not transposed:
transform_func = ldmatrix_32x4_to_shared_16x8_layout_a
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
else:
raise ValueError("ldmatrix only supports B transposed and A non-transposed for 32-bit dtypes")
elif dtype_bits == 16:
transform_func = ldmatrix_32x8_to_shared_16x16_layout
transform_func_trans = ldmatrix_trans_32x8_to_shared_16x16_layout
if transposed:
......@@ -37,11 +50,11 @@ def get_ldmatrix_offset(
return new_row_idx * stride + new_col_idx
elif dtype_bits == 8:
if matrix == "B" and transposed:
transform_func = ldmatrix_16x32_to_shared_16x32_layout_b
transform_func = ldmatrix_32x16_to_shared_16x32_layout_b
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
elif matrix == "A" and not transposed:
transform_func = ldmatrix_16x32_to_shared_16x32_layout_a
transform_func = ldmatrix_32x16_to_shared_16x32_layout_a
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
else:
......
......@@ -2,6 +2,8 @@ from tilelang import tvm as tvm
from tvm.ir.base import Node
from tvm.runtime import Scriptable
import tvm.ffi
from tvm.target import Target
from tilelang import _ffi_api
@tvm.ffi.register_object("tl.Fill")
......@@ -26,7 +28,15 @@ class Conv2DIm2ColOp(Node, Scriptable):
@tvm.ffi.register_object("tl.GemmWarpPolicy")
class GemmWarpPolicy(Node, Scriptable):
...
policy_type: int
m_warp: int
n_warp: int
def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target,
is_wgmma: bool):
_ffi_api.GemmWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target,
is_wgmma)
return self.m_warp, self.n_warp
@tvm.ffi.register_object("tl.Gemm")
......
......@@ -43,7 +43,7 @@ from .allocate import (
alloc_barrier, # noqa: F401
)
from .copy import copy, c2d_im2col # noqa: F401
from .gemm import GemmWarpPolicy, gemm # noqa: F401
from .gemm import GemmWarpPolicy, gemm, gemm_v2 # noqa: F401
from .experimental.gemm_sp import gemm_sp # noqa: F401
from .fill import fill, clear # noqa: F401
from .reduce import (
......
......@@ -180,3 +180,180 @@ def gemm(
k_pack,
wg_wait,
)
# experimental currently, for fast compilation
def gemm_v2(
A: Union[tir.Buffer, tir.Var],
B: Union[tir.Buffer, tir.Var],
C: Union[tir.Buffer, tir.Var],
transpose_A: bool = False,
transpose_B: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
clear_accum: bool = False,
k_pack: int = 1,
wg_wait: int = 0,
):
"""Perform a General Matrix Multiplication (GEMM) operation.
This function computes C = A @ B where A and B can optionally be transposed.
The operation supports various warp policies and accumulation modes.
Args:
A (Union[tir.Buffer, tir.Var]): First input matrix
B (Union[tir.Buffer, tir.Var]): Second input matrix
C (Union[tir.Buffer, tir.Var]): Output matrix for results
transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
wg_wait (int, optional): Warp group wait count. Defaults to 0.
Returns:
tir.Call: A handle to the GEMM operation
Raises:
AssertionError: If the K dimensions of matrices A and B don't match
"""
def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
"""Convert let-bound variables to their corresponding buffers.
Args:
arg (Union[tir.Buffer, tir.Var]): Input argument to legalize
Returns:
Union[tir.Buffer, tir.Var]: The legalized argument
"""
if isinstance(arg, tir.Var) and T.has_let_value(arg):
return T.get_let_value(arg).buffer
return arg
A = legalize_arguments(A)
B = legalize_arguments(B)
C = legalize_arguments(C)
def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
if isinstance(object, tir.Buffer):
return object.shape
elif isinstance(object, tir.BufferRegion):
region = object.region
shape = []
for r in region:
shape.append(r.extent)
return shape
else:
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
if isinstance(object, tir.Buffer):
strides = []
stride = 1
for s in reversed(object.shape):
strides.insert(0, stride)
stride *= s
return strides
elif isinstance(object, tir.BufferRegion):
buffer, _ = object.buffer, object.region
strides = []
stride = 1
for s in reversed(buffer.shape):
strides.insert(0, stride)
stride *= s
return strides
else:
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
A_shape = retrieve_shape(A)
B_shape = retrieve_shape(B)
C_shape = retrieve_shape(C)
A_stride = retrieve_stride(A)
B_stride = retrieve_stride(B)
assert len(C_shape) == 2, "currently only supports C as a 2D tensor"
assert len(A_shape) >= 2, "currently only supports A as a 2D or higher-order tensor"
assert len(B_shape) >= 2, "currently only supports B as a 2D or higher-order tensor"
if len(A_shape) > 2:
for i in range(len(A_shape) - 2):
assert A_shape[i] == 1, \
"currently only supports A as a 2D tensor, or a higher-order tensor whose leading dimensions are all 1 and whose last two dimensions are the matrix dimensions"
if len(B_shape) > 2:
for i in range(len(B_shape) - 2):
assert B_shape[i] == 1, \
"currently only supports B as a 2D tensor, or a higher-order tensor whose leading dimensions are all 1 and whose last two dimensions are the matrix dimensions"
M, N = C_shape
K = A_shape[-2] if transpose_A else A_shape[-1]
K_B = B_shape[-1] if transpose_B else B_shape[-2]
assert K == K_B, f"T.gemm K shape check failed: K_A = {K}, K_B = {K_B}"
stride_a = A_stride[-2]
stride_b = B_stride[-2]
def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion],
access_type: str = "r") -> tir.PrimExpr:
if isinstance(object, tir.Buffer):
return object.access_ptr(access_type)
elif isinstance(object, tir.BufferRegion):
buffer, region = object.buffer, object.region
indices = []
for r in region:
indices.append(r.min)
strides = []
stride = 1
for s in reversed(buffer.shape):
strides.insert(0, stride)
stride *= s
offset = 0
# do not offset the last two dimensions
for i in range(len(indices) - 2):
offset += indices[i] * strides[i]
return buffer.access_ptr(access_mask=access_type, offset=offset)
else:
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr:
"""Retrieve the offset of the buffer or buffer region."""
if isinstance(object, tir.Buffer):
return [0] * len(object.shape)
elif isinstance(object, tir.BufferRegion):
_, region = object.buffer, object.region
indices = []
for r in region:
indices.append(r.min)
return indices
else:
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
A_offset = retrieve_offset(A)
B_offset = retrieve_offset(B)
assert A_offset[-2] == 0, "The offset of the second-to-last dimension of A must be 0"
assert B_offset[-2] == 0, "The offset of the second-to-last dimension of B must be 0"
offset_a = A_offset[-1]
offset_b = B_offset[-1]
Aptr = retrieve_ptr(A, "r")
Bptr = retrieve_ptr(B, "r")
Cptr = retrieve_ptr(C, "rw")
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.gemm_py"),
Aptr,
Bptr,
Cptr,
transpose_A,
transpose_B,
M,
N,
K,
policy,
clear_accum,
stride_a,
stride_b,
offset_a,
offset_b,
k_pack,
wg_wait,
)
......@@ -261,46 +261,54 @@ def Kernel(
def get_thread_binding(dim: int = 0) -> Var:
"""Returns the thread binding for the given dimension.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_binding(dim)
def get_thread_bindings() -> List[Var]:
"""Returns all three thread bindings.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_bindings()
def get_block_binding(dim: int = 0) -> Var:
"""Returns the block binding for the given dimension.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_block_binding(dim)
def get_block_bindings() -> List[Var]:
"""Returns all three block bindings.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_block_bindings()
def get_thread_extent(dim: int = 0) -> int:
"""Returns the thread extent for the given dimension.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_extent(dim)
def get_thread_extents() -> List[int]:
"""Returns all three thread extents.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_extents()
def get_block_extent(dim: int = 0) -> int:
"""Returns the block extent for the given dimension.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_block_extent(dim)
def get_block_extents() -> List[int]:
"""Returns all three block extents.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_block_extents()
......@@ -5,6 +5,8 @@ import tvm
from tilelang import _ffi_api
# Use a stable swizzled layout to ensure consistent memory access patterns.
# Swizzling should be enabled or disabled based on whether TMA (the Tensor Memory Accelerator) is applied.
def make_swizzled_layout(buffer: tvm.tir.Buffer):
assert len(buffer.shape) == 2
return _ffi_api.make_swizzled_layout(
......
@@ -126,9 +126,17 @@ class Profiler:
if lhs is not None and rhs is not None:
# in case of numsplit template, the ref output may be None
# which means the value is invalid, so we skip the comparison
def is_float8(tensor: torch.Tensor) -> bool:
return tensor.dtype in {
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
}
torch_assert_close(
lhs,
rhs,
lhs if not is_float8(lhs) else lhs.to(torch.float32),
rhs if not is_float8(rhs) else rhs.to(torch.float32),
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio,
......
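The float8 branch above upcasts both sides before `torch_assert_close`, since comparison kernels are not implemented for 8-bit float dtypes. A minimal dtype-level sketch of that promotion rule, without torch (dtype names as strings):

```python
# The four float8 variants checked by is_float8 above.
FLOAT8_DTYPES = {
    "float8_e5m2",
    "float8_e5m2fnuz",
    "float8_e4m3fn",
    "float8_e4m3fnuz",
}


def promote_for_compare(dtype: str) -> str:
    """Return the dtype a tensor should be cast to before numeric comparison."""
    return "float32" if dtype in FLOAT8_DTYPES else dtype
```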
from .gemm import GemmPy # noqa: F401
from tilelang import tvm as tvm
from tvm import tir
from tilelang.utils.target import (
target_is_cuda,)
from tvm.target import Target
from tvm.ir.base import Node
from tvm.runtime import Scriptable
import tvm.ffi
from tilelang.ir import GemmWarpPolicy
from .gemm_mma import GemmMMA
@tvm.ffi.register_func("tl.gemm_py.infer_layout")
def gemm_py_infer_layout(gemm_py, target, thread_bounds):
"""FFI hook: infer buffer layouts for a GemmPy node given the target and thread bounds."""
thread_nums = thread_bounds.extent
return gemm_py.infer_layout(target, thread_nums)
@tvm.ffi.register_func("tl.gemm_py.lower")
def gemm_py_lower(gemm_py, target, thread_bounds, thread_var):
"""FFI hook: lower a GemmPy node to a TIR statement."""
thread_nums = thread_bounds.extent
return gemm_py.lower(target, thread_nums, thread_var)
@tvm.ffi.register_object("tl.GemmPy")
class GemmPy(Node, Scriptable):
A: tir.Buffer
B: tir.Buffer
C: tir.Buffer
APtr: tir.PrimExpr
BPtr: tir.PrimExpr
CPtr: tir.PrimExpr
M: int
N: int
K: int
trans_A: bool
trans_B: bool
stride_A: int
stride_B: int
offset_A: int
offset_B: int
clear_accum: bool
k_pack: int
wg_wait: int
policy: GemmWarpPolicy
def infer_layout(self, target: Target, thread_nums: int):
if target_is_cuda(target):
# TODO(lei): Support more cuda architectures, now mma only
return GemmMMA(self).infer_layout(target, thread_nums)
else:
raise ValueError(f"Unsupported target: {target}")
def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
if target_is_cuda(target):
# TODO(lei): Support more CUDA architectures; currently only the MMA path is implemented
return GemmMMA(self).lower(target, thread_nums, thread_var)
else:
raise ValueError(f"Unsupported target: {target}")
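Both `infer_layout` and `lower` dispatch on the target and currently route every CUDA target to `GemmMMA`. One way the TODO could be addressed is a registry keyed by target kind; the sketch below is purely illustrative (the class and function names are hypothetical, not part of this PR):

```python
_BACKENDS = {}


def register_backend(kind):
    """Class decorator: map a target kind (e.g. 'cuda') to a backend class."""
    def deco(cls):
        _BACKENDS[kind] = cls
        return cls
    return deco


@register_backend("cuda")
class MMABackend:
    def lower(self):
        return "mma lowering"


def select_backend(kind):
    # Mirrors the if/else in GemmPy.lower: unknown targets raise ValueError.
    if kind not in _BACKENDS:
        raise ValueError(f"Unsupported target: {kind}")
    return _BACKENDS[kind]
```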
from dataclasses import dataclass
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang.utils.language import is_shared, is_fragment
from tilelang.ir import GemmWarpPolicy
from tvm.ir.base import Node
@dataclass
class GemmBase(object):
gemm_node: Node
def infer_layout(self, target: Target, thread_nums: int):
raise NotImplementedError("infer_layout is not implemented")
def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
raise NotImplementedError("lower is not implemented")
def is_gemm_ss(self) -> bool:
return is_shared(self.A) and is_shared(self.B)
def is_gemm_sr(self) -> bool:
return is_shared(self.A) and is_fragment(self.B)
def is_gemm_rs(self) -> bool:
return is_fragment(self.A) and is_shared(self.B)
def is_gemm_rr(self) -> bool:
return is_fragment(self.A) and is_fragment(self.B)
@property
def M(self) -> int:
return self.gemm_node.M
@property
def N(self) -> int:
return self.gemm_node.N
@property
def K(self) -> int:
return self.gemm_node.K
@property
def trans_A(self) -> bool:
return self.gemm_node.trans_A
@property
def trans_B(self) -> bool:
return self.gemm_node.trans_B
@property
def in_dtype(self) -> str:
assert self.A.dtype == self.B.dtype, "A and B must have the same dtype"
return self.A.dtype
@property
def accum_dtype(self) -> str:
return self.C.dtype
@property
def chunk(self) -> int:
return self.A.shape[-2] if self.trans_A else self.A.shape[-1]
@property
def A(self) -> tir.Buffer:
return self.gemm_node.A
@property
def B(self) -> tir.Buffer:
return self.gemm_node.B
@property
def C(self) -> tir.Buffer:
return self.gemm_node.C
@property
def APtr(self) -> tir.PrimExpr:
return self.gemm_node.APtr
@property
def BPtr(self) -> tir.PrimExpr:
return self.gemm_node.BPtr
@property
def CPtr(self) -> tir.PrimExpr:
return self.gemm_node.CPtr
@property
def stride_A(self) -> int:
return self.gemm_node.stride_A
@property
def stride_B(self) -> int:
return self.gemm_node.stride_B
@property
def offset_A(self) -> int:
return self.gemm_node.offset_A
@property
def offset_B(self) -> int:
return self.gemm_node.offset_B
@property
def clear_accum(self) -> bool:
return self.gemm_node.clear_accum
@property
def k_pack(self) -> int:
return self.gemm_node.k_pack
@property
def wg_wait(self) -> int:
return self.gemm_node.wg_wait
@property
def policy(self) -> GemmWarpPolicy:
return self.gemm_node.policy
from .gemm_base import GemmBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang.utils.language import is_shared, is_fragment
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang import language as T
from tilelang.transform.simplify import _Simplify
class GemmMMA(GemmBase):
def infer_layout(self, target: Target, thread_nums: int):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
)
if self.is_gemm_ss():
return {
self.A: make_swizzled_layout(self.A),
self.B: make_swizzled_layout(self.B),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
elif self.is_gemm_sr():
return {
self.A: make_swizzled_layout(self.A),
self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
elif self.is_gemm_rs():
return {
self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"),
self.B: make_swizzled_layout(self.B),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
elif self.is_gemm_rr():
return {
self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"),
self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
thread_var=thread_var,
)
in_dtype = self.in_dtype
warp_rows = mma_emitter.warp_rows
warp_cols = mma_emitter.warp_cols
local_size_a = mma_emitter.local_size_a
local_size_b = mma_emitter.local_size_b
block_K = mma_emitter.chunk
micro_size_k = mma_emitter.micro_size_k
A_shared = self.A
B_shared = self.B
C_local = self.C
if self.is_gemm_ss():
@T.prim_func
def _gemm_ssr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_ssr, inline_let=True)
elif self.is_gemm_sr():
B_local = self.B
@T.prim_func
def _gemm_srr() -> None:
"""
The inner macro that loads A from the shared buffer A_shared into a
local fragment (B already resides in fragments), then issues Tensor
Core mma ops, accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_srr, inline_let=True)
elif self.is_gemm_rs():
A_local = self.A
@T.prim_func
def _gemm_rsr() -> None:
"""
The inner macro that loads B from the shared buffer B_shared into a
local fragment (A already resides in fragments), then issues Tensor
Core mma ops, accumulating into C_local.
"""
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
elif self.is_gemm_rr():
A_local = self.A
B_local = self.B
@T.prim_func
def _gemm_rrr() -> None:
"""
The inner macro for the fragment-fragment case: A and B already
reside in local fragments, so no ldmatrix loads are needed and the
Tensor Core mma ops are issued directly, accumulating into C_local.
"""
for ki in T.serial(0, (block_K // micro_size_k)):
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rrr, inline_let=True)
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def is_gemm_ss(self) -> bool:
return is_shared(self.A) and is_shared(self.B)
def is_gemm_sr(self) -> bool:
return is_shared(self.A) and is_fragment(self.B)
def is_gemm_rs(self) -> bool:
return is_fragment(self.A) and is_shared(self.B)
def is_gemm_rr(self) -> bool:
return is_fragment(self.A) and is_fragment(self.B)
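Both `infer_layout` and `lower` start by splitting `thread_nums` into an `m_warp x n_warp` grid, from which the per-warp tile sizes follow. The real `GemmWarpPolicy.compute_warp_partition` is more sophisticated; the pure-Python sketch below only illustrates the arithmetic (it greedily picks the largest valid `m_warp`, which the actual policy need not do):

```python
WARP_SIZE = 32


def split_warps(M, N, thread_nums):
    """Illustrative warp partition: pick m_warp dividing both the warp
    count and M, with the remaining warps along N."""
    warps = thread_nums // WARP_SIZE
    m_warp = 1
    for cand in range(1, warps + 1):
        if warps % cand == 0 and M % cand == 0 and N % (warps // cand) == 0:
            m_warp = cand  # keep the largest valid candidate
    n_warp = warps // m_warp
    # warp_row_tiles / warp_col_tiles as computed in infer_layout above.
    return m_warp, n_warp, M // m_warp, N // n_warp
```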
@@ -2,7 +2,7 @@
# pylint: disable=invalid-name, unsupported-binary-operation
from . import _ffi_api
from .simplify import Simplify, simplify_prim_func # noqa: F401
from .simplify import Simplify, simplify_prim_func, LetInline # noqa: F401
from .pass_config import PassConfigKey # noqa: F401
from tilelang import tvm as tvm # noqa: F401
from tvm.ir.transform import PassContext # noqa: F401
@@ -68,17 +68,6 @@ def InjectSoftwarePipeline():
return _ffi_api.InjectSoftwarePipeline() # type: ignore
def FrontendLegalize():
"""FrontendLegalize
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.FrontendLegalize() # type: ignore
def InjectAssumes():
"""Inject Assumes
......
@@ -5,6 +5,17 @@ from typing import Union, Callable
from . import _ffi_api
def LetInline():
"""Inline let-bound variables into their use sites.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LetInline() # type: ignore
def Simplify(simplify_arguments: bool = False):
"""Simplify
@@ -16,13 +27,24 @@ def Simplify(simplify_arguments: bool = False):
return _ffi_api.Simplify(simplify_arguments) # type: ignore
def _Simplify(stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]:
def _Simplify(stmt: Union[PrimFunc, IRModule],
inline_let: bool = False) -> Union[PrimFunc, IRModule]:
if isinstance(stmt, PrimFunc):
mod = Simplify(simplify_arguments=True)(IRModule.from_expr(stmt))
if inline_let:
mod = LetInline()(IRModule.from_expr(stmt))
mod = Simplify(simplify_arguments=True)(mod)
else:
mod = Simplify(simplify_arguments=True)(IRModule.from_expr(stmt))
assert len(mod.functions) == 1, "Simplify should return a single function"
return list(mod.functions.values()).pop()
elif isinstance(stmt, IRModule):
return Simplify(simplify_arguments=True)(stmt)
if inline_let:
mod = LetInline()(stmt)
mod = Simplify(simplify_arguments=True)(mod)
else:
mod = Simplify(simplify_arguments=True)(stmt)
assert len(mod.functions) == 1, "Simplify should return a single function"
return list(mod.functions.values()).pop()
else:
raise ValueError(f"Unsupported type: {type(stmt)}")
@@ -37,6 +59,7 @@ def simplify_prim_func(func: Callable) -> Callable:
return wrapper
def apply_simplify(stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]:
def apply_simplify(stmt: Union[PrimFunc, IRModule],
inline_let: bool = False) -> Union[PrimFunc, IRModule]:
"""Apply Simplify pass to a PrimFunc or IRModule."""
return _Simplify(stmt)
return _Simplify(stmt, inline_let)
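The ordering in `_Simplify` matters: `LetInline` runs first so that let-bound indices become visible to the simplifier. A toy model of that interaction, with statements represented as `(lets, body)` pairs instead of TIR (all names here are illustrative, not the real pass API):

```python
def let_inline(stmt):
    """Substitute let bindings into the body: ({'x': '4'}, 'x + 8') -> ({}, '4 + 8')."""
    lets, body = stmt
    for name, value in lets.items():
        body = body.replace(name, value)
    return {}, body


def simplify(stmt):
    """Fold a constant pattern the simplifier can only see after inlining."""
    lets, body = stmt
    return lets, body.replace("4 + 8", "12")


def run(stmt, inline_let=False):
    # Same shape as _Simplify: optionally LetInline, then Simplify.
    if inline_let:
        stmt = let_inline(stmt)
    return simplify(stmt)
```

With `inline_let=True` the constant folds; without it, the let binding hides the pattern and nothing simplifies.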
from typing import Literal, Union
from tilelang import tvm as tvm
from tilelang import _ffi_api
from tvm.target import Target
from tvm.contrib import rocm
from tilelang.contrib import nvcc
@@ -81,3 +82,55 @@ def determine_target(target: Union[str, Target, Literal["auto"]] = "auto",
if return_object:
return Target(return_var)
return return_var
def target_is_cuda(target: Target) -> bool:
return _ffi_api.TargetIsCuda(target)
def target_is_hip(target: Target) -> bool:
return _ffi_api.TargetIsRocm(target)
def target_is_volta(target: Target) -> bool:
return _ffi_api.TargetIsVolta(target)
def target_is_turing(target: Target) -> bool:
return _ffi_api.TargetIsTuring(target)
def target_is_ampere(target: Target) -> bool:
return _ffi_api.TargetIsAmpere(target)
def target_is_hopper(target: Target) -> bool:
return _ffi_api.TargetIsHopper(target)
def target_is_sm120(target: Target) -> bool:
return _ffi_api.TargetIsSM120(target)
def target_is_cdna(target: Target) -> bool:
return _ffi_api.TargetIsCDNA(target)
def target_has_async_copy(target: Target) -> bool:
return _ffi_api.TargetHasAsyncCopy(target)
def target_has_ldmatrix(target: Target) -> bool:
return _ffi_api.TargetHasLdmatrix(target)
def target_has_stmatrix(target: Target) -> bool:
return _ffi_api.TargetHasStmatrix(target)
def target_has_bulk_copy(target: Target) -> bool:
return _ffi_api.TargetHasBulkCopy(target)
def target_get_warp_size(target: Target) -> int:
return _ffi_api.TargetGetWarpSize(target)
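For NVIDIA targets, the capability predicates above roughly track compute-capability thresholds. The authoritative answers come from the C++ side via `_ffi_api`; the table below is only an illustrative sketch of the usual sm cutoffs:

```python
def has_ldmatrix(sm: int) -> bool:
    return sm >= 75   # ldmatrix: Turing (sm_75) and newer


def has_async_copy(sm: int) -> bool:
    return sm >= 80   # cp.async: Ampere (sm_80) and newer


def has_bulk_copy(sm: int) -> bool:
    return sm >= 90   # TMA bulk copy: Hopper (sm_90) and newer
```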
@@ -113,9 +113,11 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
else:
return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype)
elif supply_type == TensorSupplyType.Uniform:
return torch.empty(*shape, device=device, dtype=dtype).uniform_(-1.0, 1.0)
return torch.empty(
*shape, device=device, dtype=torch.float32).uniform_(-1.0, 1.0).to(dtype)
elif supply_type == TensorSupplyType.Normal:
return torch.empty(*shape, device=device, dtype=dtype).normal_(-1.0, 1.0)
return torch.empty(
*shape, device=device, dtype=torch.float32).normal_(-1.0, 1.0).to(dtype)
elif supply_type == TensorSupplyType.Randn:
return torch.randn(*shape, device=device).to(dtype)
elif supply_type == TensorSupplyType.Zero:
......
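The change above stages `uniform_`/`normal_` sampling in float32 and only then casts to the requested dtype, since the in-place samplers are not implemented for every dtype (e.g. float8). A pure-Python sketch of the same stage-then-cast idea, where "cast to low precision" is mimicked by snapping to a coarse grid:

```python
import random


def to_coarse(v, step=0.25):
    """Stand-in for a low-precision cast: snap to the nearest multiple of step."""
    return round(v / step) * step


def sample_uniform(n, cast):
    # Draw in a wide type first, then cast, mirroring the torch change above.
    wide = [random.uniform(-1.0, 1.0) for _ in range(n)]
    return [cast(v) for v in wide]
```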