Unverified commit 91a7bb2b authored by Lei Wang, committed by GitHub

[TileOp] Introduce an experimental Python-defined `T.gemm_v2` (#793)

* Refactor GEMM and GEMM-SP operations to enhance clarity and maintainability

- Removed deprecated prime factorization functions from `gemm.cc` and `gemm_sp.cc`.
- Introduced a new `GemmWarpPolicy` class to manage warp policy attributes and methods, improving encapsulation.
- Updated reflection methods to include the new policy structure, ensuring proper registration and introspection capabilities.
- Enhanced `GetArchInt` function in `utils.cc` for better readability and type safety.
- Added new `gemm_v2` function in `gemm.py` for improved GEMM operation with additional parameters and checks.

* Refactor GEMM and frontend legalize operations for improved clarity and functionality

- Updated `gemm_py.h` to include the correct header for GEMM operations.
- Renamed `FrontendLegalizer` class to `LetInliner` and updated related methods to reflect this change, enhancing code clarity.
- Modified the pass function from `FrontendLegalize` to `LetInline` for better alignment with its purpose.
- Updated test cases to utilize the new `gemm_v2` function and adjusted the testing framework for improved output and clarity.
- Removed obsolete test file `test_tilelang_transform_frontend_legalize.py` to streamline the test suite.
- Enhanced the `LowerAndLegalize` function to utilize the new `LetInline` pass, improving the overall transformation process.

* Enhance CUDA code generation and testing for GEMM operations

- Added indentation printing in `codegen_cuda.cc` for improved assembly code formatting.
- Updated `test_tilelang_tilelibrary_gemm.py` to include additional GEMM test cases and shared memory allocation with specified scope.
- Introduced new `matmul_sr` and `run_gemm_sr` functions for GEMM operations with shared and fragment memory layouts.
- Refactored layout inference in `mma_macro_generator.py` to improve clarity and correctness in shared memory handling.
- Enhanced `gemm/__init__.py` to support new GEMM operation combinations and layout inference logic.

These changes improve the clarity, functionality, and testing coverage of GEMM operations in the TileLang framework.

* Refactor GEMM layout and testing for improved clarity and functionality

- Updated `gemm_layouts.cc` to enhance the layout generation logic for transposed and non-transposed GEMM operations.
- Renamed and modified functions in `test_tilelang_tilelibrary_gemm.py` to reflect changes in GEMM function signatures and improve test coverage.
- Introduced new GEMM operation combinations in `gemm/__init__.py` to support additional layouts and configurations.
- Enhanced layout inference in `mma_layout.py` and `mma_macro_generator.py` for better handling of shared memory layouts.

These changes improve the clarity, functionality, and testing coverage of GEMM operations in the TileLang framework.

* Refactor GEMM layout and Python integration for improved functionality

- Updated `gemm_layouts.cc` to correct the order of layout replication and repetition for transposed and non-transposed GEMM operations.
- Enhanced `gemm_py.cc` to handle block realization more robustly, ensuring correct assignment of global symbols and block attributes.
- Refactored `inject_pipeline.cc` to streamline buffer read/write region handling, improving clarity and maintainability.
- Cleaned up test cases in `test_tilelang_tilelibrary_gemm.py` by removing unnecessary print statements and adjusting function calls for better test execution flow.

These changes enhance the clarity, functionality, and robustness of GEMM operations and their testing in the TileLang framework.

* Refactor GEMM layout and testing for improved clarity and functionality

- Updated `gemm_layouts.cc` to enhance layout generation logic for transposed and non-transposed GEMM operations.
- Improved block realization handling in `gemm_py.cc` for better assignment of global symbols.
- Streamlined buffer read/write region handling in `inject_pipeline.cc` for clarity.
- Enhanced test cases in `test_tilelang_tilelibrary_gemm.py` by adjusting function calls and adding new GEMM operation combinations.

These changes improve the clarity, functionality, and robustness of GEMM operations and their testing in the TileLang framework.

* tfloat32 support.

* lint fix

* lint fix

* Refactor shared memory allocation in GEMM tests

- Removed unnecessary scope specification in shared memory allocation for matrices A and B in `test_tilelang_tilelibrary_gemm.py`.
- This change simplifies the allocation process and aligns with the updated GEMM function signatures.
parent 9fd6bb30
from tvm import DataType
from typing import Literal
from .mma_layout import (
ldmatrix_32x4_to_shared_16x8_layout_a,
ldmatrix_32x4_to_shared_16x8_layout_b,
ldmatrix_32x8_to_shared_16x16_layout,
ldmatrix_trans_32x8_to_shared_16x16_layout,
ldmatrix_16x32_to_shared_16x32_layout_a,
ldmatrix_16x32_to_shared_16x32_layout_b,
ldmatrix_32x16_to_shared_16x32_layout_a,
ldmatrix_32x16_to_shared_16x32_layout_b,
mma_store_32x8_to_shared_16x16_layout,
)
from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m)
@@ -26,7 +28,18 @@ def get_ldmatrix_offset(
):
assert matrix in ["A", "B"], "matrix should be either A or B"
dtype_bits = DataType(dtype).bits
if dtype_bits == 16:
if dtype_bits == 32:
if matrix == "B" and transposed:
transform_func = ldmatrix_32x4_to_shared_16x8_layout_b
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
elif matrix == "A" and not transposed:
transform_func = ldmatrix_32x4_to_shared_16x8_layout_a
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
else:
raise ValueError("ldmatrix only supports B transposed and A non-transposed for 32-bit dtypes")
elif dtype_bits == 16:
transform_func = ldmatrix_32x8_to_shared_16x16_layout
transform_func_trans = ldmatrix_trans_32x8_to_shared_16x16_layout
if transposed:
@@ -37,11 +50,11 @@ def get_ldmatrix_offset(
return new_row_idx * stride + new_col_idx
elif dtype_bits == 8:
if matrix == "B" and transposed:
transform_func = ldmatrix_16x32_to_shared_16x32_layout_b
transform_func = ldmatrix_32x16_to_shared_16x32_layout_b
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
elif matrix == "A" and not transposed:
transform_func = ldmatrix_16x32_to_shared_16x32_layout_a
transform_func = ldmatrix_32x16_to_shared_16x32_layout_a
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
else:
......
@@ -2,6 +2,8 @@ from tilelang import tvm as tvm
from tvm.ir.base import Node
from tvm.runtime import Scriptable
import tvm.ffi
from tvm.target import Target
from tilelang import _ffi_api
@tvm.ffi.register_object("tl.Fill")
@@ -26,7 +28,15 @@ class Conv2DIm2ColOp(Node, Scriptable):
@tvm.ffi.register_object("tl.GemmWarpPolicy")
class GemmWarpPolicy(Node, Scriptable):
...
policy_type: int
m_warp: int
n_warp: int
def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target,
is_wgmma: bool):
_ffi_api.GemmWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target,
is_wgmma)
return self.m_warp, self.n_warp
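`compute_warp_partition` dispatches to C++ via FFI, so the actual policy logic is not visible here. As a rough, purely illustrative sketch (hypothetical logic, not the real implementation), a "Square" policy could pick the factorization of the warp count whose per-warp tile is closest to square:

```python
def square_partition(M, N, block_size, warp_size=32):
    """Hypothetical sketch of a Square warp policy: split the warps
    so each warp's M x N tile has sides as close to equal as possible."""
    num_warps = block_size // warp_size
    candidates = [(m, num_warps // m)
                  for m in range(1, num_warps + 1)
                  if num_warps % m == 0 and M % m == 0 and N % (num_warps // m) == 0]
    # prefer the split whose per-warp tile sides differ the least
    return min(candidates, key=lambda mn: abs(M // mn[0] - N // mn[1]))
```

For a 128x128 tile with 4 warps this picks a 2x2 warp grid; the real `GemmWarpPolicy` also accounts for the target and WGMMA, which this sketch ignores.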
@tvm.ffi.register_object("tl.Gemm")
......
@@ -43,7 +43,7 @@ from .allocate import (
alloc_barrier, # noqa: F401
)
from .copy import copy, c2d_im2col # noqa: F401
from .gemm import GemmWarpPolicy, gemm # noqa: F401
from .gemm import GemmWarpPolicy, gemm, gemm_v2 # noqa: F401
from .experimental.gemm_sp import gemm_sp # noqa: F401
from .fill import fill, clear # noqa: F401
from .reduce import (
......
@@ -180,3 +180,180 @@ def gemm(
k_pack,
wg_wait,
)
# experimental currently, for fast compilation
def gemm_v2(
A: Union[tir.Buffer, tir.Var],
B: Union[tir.Buffer, tir.Var],
C: Union[tir.Buffer, tir.Var],
transpose_A: bool = False,
transpose_B: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
clear_accum: bool = False,
k_pack: int = 1,
wg_wait: int = 0,
):
"""Perform a General Matrix Multiplication (GEMM) operation.
This function computes C = A @ B where A and B can optionally be transposed.
The operation supports various warp policies and accumulation modes.
Args:
A (Union[tir.Buffer, tir.Var]): First input matrix
B (Union[tir.Buffer, tir.Var]): Second input matrix
C (Union[tir.Buffer, tir.Var]): Output matrix for results
transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
wg_wait (int, optional): Warp group wait count. Defaults to 0.
Returns:
tir.Call: A handle to the GEMM operation
Raises:
AssertionError: If the K dimensions of matrices A and B don't match
"""
def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
"""Convert let-bound variables to their corresponding buffers.
Args:
arg (Union[tir.Buffer, tir.Var]): Input argument to legalize
Returns:
Union[tir.Buffer, tir.Var]: The legalized argument
"""
if isinstance(arg, tir.Var) and T.has_let_value(arg):
return T.get_let_value(arg).buffer
return arg
A = legalize_arguments(A)
B = legalize_arguments(B)
C = legalize_arguments(C)
def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
if isinstance(object, tir.Buffer):
return object.shape
elif isinstance(object, tir.BufferRegion):
region = object.region
shape = []
for r in region:
shape.append(r.extent)
return shape
else:
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
if isinstance(object, tir.Buffer):
strides = []
stride = 1
for s in reversed(object.shape):
strides.insert(0, stride)
stride *= s
return strides
elif isinstance(object, tir.BufferRegion):
buffer, _ = object.buffer, object.region
strides = []
stride = 1
for s in reversed(buffer.shape):
strides.insert(0, stride)
stride *= s
return strides
else:
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
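Note that `retrieve_stride` derives contiguous row-major strides from the shape rather than reading any stride metadata, so non-compact buffers are implicitly assumed away. The core loop in plain Python:

```python
def row_major_strides(shape):
    """Contiguous row-major strides: the last dim has stride 1, and each
    earlier dim's stride is the product of all dims after it."""
    strides, stride = [], 1
    for s in reversed(shape):
        strides.insert(0, stride)
        stride *= s
    return strides
```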
A_shape = retrieve_shape(A)
B_shape = retrieve_shape(B)
C_shape = retrieve_shape(C)
A_stride = retrieve_stride(A)
B_stride = retrieve_stride(B)
assert len(C_shape) == 2, "currently only supports C as a 2D tensor"
assert len(A_shape) >= 2, "currently only supports A as a 2D or higher-order tensor"
assert len(B_shape) >= 2, "currently only supports B as a 2D or higher-order tensor"
if len(A_shape) > 2:
for i in range(len(A_shape) - 2):
assert A_shape[i] == 1, \
"all leading (batch) dimensions of A must be 1; only the last two dimensions may be matrix dimensions"
if len(B_shape) > 2:
for i in range(len(B_shape) - 2):
assert B_shape[i] == 1, \
"all leading (batch) dimensions of B must be 1; only the last two dimensions may be matrix dimensions"
M, N = C_shape
K = A_shape[-2] if transpose_A else A_shape[-1]
K_B = B_shape[-1] if transpose_B else B_shape[-2]
assert K == K_B, f"T.gemm K shape check failed: K_A = {K}, K_B = {K_B}"
stride_a = A_stride[-2]
stride_b = B_stride[-2]
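The dimension checks above are a small amount of index arithmetic on the transpose flags. A self-contained mirror (the helper name is mine, for illustration only):

```python
def infer_gemm_dims(a_shape, b_shape, c_shape, transpose_A=False, transpose_B=False):
    """Mirror of gemm_v2's dimension derivation: K comes from A's
    reduction axis and must match B's reduction axis."""
    M, N = c_shape
    K = a_shape[-2] if transpose_A else a_shape[-1]
    K_B = b_shape[-1] if transpose_B else b_shape[-2]
    assert K == K_B, f"T.gemm K shape check failed: K_A = {K}, K_B = {K_B}"
    return M, N, K
```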
def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion],
access_type: str = "r") -> tir.PrimExpr:
if isinstance(object, tir.Buffer):
return object.access_ptr(access_type)
elif isinstance(object, tir.BufferRegion):
buffer, region = object.buffer, object.region
indices = []
for r in region:
indices.append(r.min)
strides = []
stride = 1
for s in reversed(buffer.shape):
strides.insert(0, stride)
stride *= s
offset = 0
# not offset the last two dimension
for i in range(len(indices) - 2):
offset += indices[i] * strides[i]
return buffer.access_ptr(access_mask=access_type, offset=offset)
else:
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
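For a `BufferRegion`, the pointer is advanced only by the leading (batch) indices; the last two dimensions are handled separately through `stride_a`/`offset_a`. The offset arithmetic in isolation:

```python
def leading_dim_offset(min_indices, shape):
    """Flat row-major offset contributed by all but the last two dims,
    mirroring retrieve_ptr's handling of BufferRegion minimums."""
    strides, stride = [], 1
    for s in reversed(shape):
        strides.insert(0, stride)
        stride *= s
    # the last two dims are intentionally excluded from the offset
    return sum(i * st for i, st in zip(min_indices[:-2], strides[:-2]))
```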
def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr:
"""Retrieve the offset of the buffer or buffer region."""
if isinstance(object, tir.Buffer):
return [0] * len(object.shape)
elif isinstance(object, tir.BufferRegion):
_, region = object.buffer, object.region
indices = []
for r in region:
indices.append(r.min)
return indices
else:
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
A_offset = retrieve_offset(A)
B_offset = retrieve_offset(B)
assert A_offset[-2] == 0, "The offset of the second-to-last dimension of A must be 0"
assert B_offset[-2] == 0, "The offset of the second-to-last dimension of B must be 0"
offset_a = A_offset[-1]
offset_b = B_offset[-1]
Aptr = retrieve_ptr(A, "r")
Bptr = retrieve_ptr(B, "r")
Cptr = retrieve_ptr(C, "rw")
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.gemm_py"),
Aptr,
Bptr,
Cptr,
transpose_A,
transpose_B,
M,
N,
K,
policy,
clear_accum,
stride_a,
stride_b,
offset_a,
offset_b,
k_pack,
wg_wait,
)
@@ -261,46 +261,54 @@ def Kernel(
def get_thread_binding(dim: int = 0) -> Var:
"""Returns the thread binding for the given dimension.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_binding(dim)
def get_thread_bindings() -> List[Var]:
"""Returns all three thread bindings.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_bindings()
def get_block_binding(dim: int = 0) -> Var:
"""Returns the block binding for the given dimension.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_block_binding(dim)
def get_block_bindings() -> List[Var]:
"""Returns all three block bindings.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_block_bindings()
def get_thread_extent(dim: int = 0) -> int:
"""Returns the thread extent for the given dimension.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_extent(dim)
def get_thread_extents() -> List[int]:
"""Returns all three thread extents.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_extents()
def get_block_extent(dim: int = 0) -> int:
"""Returns the block extent for the given dimension.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_block_extent(dim)
def get_block_extents() -> List[int]:
"""Returns all three block extents.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_block_extents()
@@ -5,6 +5,8 @@ import tvm
from tilelang import _ffi_api
# Use a stable swizzled layout to ensure consistent memory access patterns.
# Swizzling should be enabled or disabled based on whether TMA (Tensor Memory Accelerator) is applied.
def make_swizzled_layout(buffer: tvm.tir.Buffer):
assert len(buffer.shape) == 2
return _ffi_api.make_swizzled_layout(
......
@@ -126,9 +126,17 @@ class Profiler:
if lhs is not None and rhs is not None:
# in case of numsplit template, the ref output may be None
# which means the value is invalid, so we skip the comparison
def is_float8(tensor: torch.Tensor) -> bool:
return tensor.dtype in {
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
}
torch_assert_close(
lhs,
rhs,
lhs if not is_float8(lhs) else lhs.to(torch.float32),
rhs if not is_float8(rhs) else rhs.to(torch.float32),
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio,
......
from .gemm import GemmPy # noqa: F401
from tilelang import tvm as tvm
from tvm import tir
from tilelang.utils.target import (
target_is_cuda,)
from tvm.target import Target
from tvm.ir.base import Node
from tvm.runtime import Scriptable
import tvm.ffi
from tilelang.ir import GemmWarpPolicy
from .gemm_mma import GemmMMA
@tvm.ffi.register_func("tl.gemm_py.infer_layout")
def gemm_py_infer_layout(gemm_py, target, thread_bounds):
thread_nums = thread_bounds.extent
return gemm_py.infer_layout(target, thread_nums)
@tvm.ffi.register_func("tl.gemm_py.lower")
def gemm_py_lower(gemm_py, target, thread_bounds, thread_var):
thread_nums = thread_bounds.extent
stmt = gemm_py.lower(target, thread_nums, thread_var)
return stmt
@tvm.ffi.register_object("tl.GemmPy")
class GemmPy(Node, Scriptable):
A: tir.Buffer
B: tir.Buffer
C: tir.Buffer
APtr: tir.PrimExpr
BPtr: tir.PrimExpr
CPtr: tir.PrimExpr
M: int
N: int
K: int
trans_A: bool
trans_B: bool
stride_A: int
stride_B: int
offset_A: int
offset_B: int
clear_accum: bool
k_pack: int
wg_wait: int
policy: GemmWarpPolicy
def infer_layout(self, target: Target, thread_nums: int):
if target_is_cuda(target):
# TODO(lei): Support more CUDA architectures; currently MMA only
return GemmMMA(self).infer_layout(target, thread_nums)
else:
raise ValueError(f"Unsupported target: {target}")
def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
if target_is_cuda(target):
# TODO(lei): Support more CUDA architectures; currently MMA only
# Only the shared-shared-register (ssr) lowering path is implemented for now
return GemmMMA(self).lower(target, thread_nums, thread_var)
else:
raise ValueError(f"Unsupported target: {target}")
from dataclasses import dataclass
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang.utils.language import is_shared, is_fragment
from tilelang.ir import GemmWarpPolicy
from tvm.ir.base import Node
@dataclass
class GemmBase(object):
gemm_node: Node
def infer_layout(self, target: Target, thread_nums: int):
raise NotImplementedError("infer_layout is not implemented")
def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
raise NotImplementedError("lower is not implemented")
def is_gemm_ss(self) -> bool:
return is_shared(self.A) and is_shared(self.B)
def is_gemm_sr(self) -> bool:
return is_shared(self.A) and is_fragment(self.B)
def is_gemm_rs(self) -> bool:
return is_fragment(self.A) and is_shared(self.B)
def is_gemm_rr(self) -> bool:
return is_fragment(self.A) and is_fragment(self.B)
@property
def M(self) -> int:
return self.gemm_node.M
@property
def N(self) -> int:
return self.gemm_node.N
@property
def K(self) -> int:
return self.gemm_node.K
@property
def trans_A(self) -> bool:
return self.gemm_node.trans_A
@property
def trans_B(self) -> bool:
return self.gemm_node.trans_B
@property
def in_dtype(self) -> str:
assert self.A.dtype == self.B.dtype, "A and B must have the same dtype"
return self.A.dtype
@property
def accum_dtype(self) -> str:
return self.C.dtype
@property
def chunk(self) -> int:
return self.A.shape[-2] if self.trans_A else self.A.shape[-1]
@property
def A(self) -> tir.Buffer:
return self.gemm_node.A
@property
def B(self) -> tir.Buffer:
return self.gemm_node.B
@property
def C(self) -> tir.Buffer:
return self.gemm_node.C
@property
def APtr(self) -> tir.PrimExpr:
return self.gemm_node.APtr
@property
def BPtr(self) -> tir.PrimExpr:
return self.gemm_node.BPtr
@property
def CPtr(self) -> tir.PrimExpr:
return self.gemm_node.CPtr
@property
def stride_A(self) -> int:
return self.gemm_node.stride_A
@property
def stride_B(self) -> int:
return self.gemm_node.stride_B
@property
def offset_A(self) -> int:
return self.gemm_node.offset_A
@property
def offset_B(self) -> int:
return self.gemm_node.offset_B
@property
def clear_accum(self) -> bool:
return self.gemm_node.clear_accum
@property
def k_pack(self) -> int:
return self.gemm_node.k_pack
@property
def wg_wait(self) -> int:
return self.gemm_node.wg_wait
@property
def policy(self) -> GemmWarpPolicy:
return self.gemm_node.policy
from .gemm_base import GemmBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang.utils.language import is_shared, is_fragment
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang import language as T
from tilelang.transform.simplify import _Simplify
class GemmMMA(GemmBase):
def infer_layout(self, target: Target, thread_nums: int):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
)
if self.is_gemm_ss():
return {
self.A: make_swizzled_layout(self.A),
self.B: make_swizzled_layout(self.B),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
elif self.is_gemm_sr():
return {
self.A: make_swizzled_layout(self.A),
self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
elif self.is_gemm_rs():
return {
self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"),
self.B: make_swizzled_layout(self.B),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
elif self.is_gemm_rr():
return {
self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"),
self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
thread_var=thread_var,
)
in_dtype = self.in_dtype
warp_rows = mma_emitter.warp_rows
warp_cols = mma_emitter.warp_cols
local_size_a = mma_emitter.local_size_a
local_size_b = mma_emitter.local_size_b
block_K = mma_emitter.chunk
micro_size_k = mma_emitter.micro_size_k
A_shared = self.A
B_shared = self.B
C_local = self.C
if self.is_gemm_ss():
@T.prim_func
def _gemm_ssr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_ssr, inline_let=True)
elif self.is_gemm_sr():
B_local = self.B
@T.prim_func
def _gemm_srr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_srr, inline_let=True)
elif self.is_gemm_rs():
A_local = self.A
@T.prim_func
def _gemm_rsr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
elif self.is_gemm_rr():
A_local = self.A
B_local = self.B
@T.prim_func
def _gemm_rrr() -> None:
"""
The inner macro that takes fragments A_local and B_local already
resident in registers and issues Tensor Core mma ops,
accumulating into C_local.
"""
for ki in T.serial(0, (block_K // micro_size_k)):
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rrr, inline_let=True)
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def is_gemm_ss(self) -> bool:
return is_shared(self.A) and is_shared(self.B)
def is_gemm_sr(self) -> bool:
return is_shared(self.A) and is_fragment(self.B)
def is_gemm_rs(self) -> bool:
return is_fragment(self.A) and is_shared(self.B)
def is_gemm_rr(self) -> bool:
return is_fragment(self.A) and is_fragment(self.B)
@@ -2,7 +2,7 @@
# pylint: disable=invalid-name, unsupported-binary-operation
from . import _ffi_api
from .simplify import Simplify, simplify_prim_func # noqa: F401
from .simplify import Simplify, simplify_prim_func, LetInline # noqa: F401
from .pass_config import PassConfigKey # noqa: F401
from tilelang import tvm as tvm # noqa: F401
from tvm.ir.transform import PassContext # noqa: F401
@@ -68,17 +68,6 @@ def InjectSoftwarePipeline():
return _ffi_api.InjectSoftwarePipeline() # type: ignore
def FrontendLegalize():
"""FrontendLegalize
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.FrontendLegalize() # type: ignore
def InjectAssumes():
"""Inject Assumes
......
@@ -5,6 +5,17 @@ from typing import Union, Callable
from . import _ffi_api
def LetInline():
"""LetInline
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LetInline() # type: ignore
def Simplify(simplify_arguments: bool = False):
"""Simplify
@@ -16,13 +27,24 @@ def Simplify(simplify_arguments: bool = False):
return _ffi_api.Simplify(simplify_arguments) # type: ignore
def _Simplify(stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]:
def _Simplify(stmt: Union[PrimFunc, IRModule],
inline_let: bool = False) -> Union[PrimFunc, IRModule]:
if isinstance(stmt, PrimFunc):
mod = Simplify(simplify_arguments=True)(IRModule.from_expr(stmt))
if inline_let:
mod = LetInline()(IRModule.from_expr(stmt))
mod = Simplify(simplify_arguments=True)(mod)
else:
mod = Simplify(simplify_arguments=True)(IRModule.from_expr(stmt))
assert len(mod.functions) == 1, "Simplify should return a single function"
return list(mod.functions.values()).pop()
elif isinstance(stmt, IRModule):
return Simplify(simplify_arguments=True)(stmt)
if inline_let:
mod = LetInline()(stmt)
mod = Simplify(simplify_arguments=True)(mod)
else:
mod = Simplify(simplify_arguments=True)(stmt)
assert len(mod.functions) == 1, "Simplify should return a single function"
return list(mod.functions.values()).pop()
else:
raise ValueError(f"Unsupported type: {type(stmt)}")
@@ -37,6 +59,7 @@ def simplify_prim_func(func: Callable) -> Callable:
return wrapper
def apply_simplify(stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]:
def apply_simplify(stmt: Union[PrimFunc, IRModule],
inline_let: bool = False) -> Union[PrimFunc, IRModule]:
"""Apply Simplify pass to a PrimFunc or IRModule."""
return _Simplify(stmt)
return _Simplify(stmt, inline_let)
from typing import Literal, Union
from tilelang import tvm as tvm
from tilelang import _ffi_api
from tvm.target import Target
from tvm.contrib import rocm
from tilelang.contrib import nvcc
@@ -81,3 +82,55 @@ def determine_target(target: Union[str, Target, Literal["auto"]] = "auto",
if return_object:
return Target(return_var)
return return_var
def target_is_cuda(target: Target) -> bool:
return _ffi_api.TargetIsCuda(target)
def target_is_hip(target: Target) -> bool:
return _ffi_api.TargetIsRocm(target)
def target_is_volta(target: Target) -> bool:
return _ffi_api.TargetIsVolta(target)
def target_is_turing(target: Target) -> bool:
return _ffi_api.TargetIsTuring(target)
def target_is_ampere(target: Target) -> bool:
return _ffi_api.TargetIsAmpere(target)
def target_is_hopper(target: Target) -> bool:
return _ffi_api.TargetIsHopper(target)
def target_is_sm120(target: Target) -> bool:
return _ffi_api.TargetIsSM120(target)
def target_is_cdna(target: Target) -> bool:
return _ffi_api.TargetIsCDNA(target)
def target_has_async_copy(target: Target) -> bool:
return _ffi_api.TargetHasAsyncCopy(target)
def target_has_ldmatrix(target: Target) -> bool:
return _ffi_api.TargetHasLdmatrix(target)
def target_has_stmatrix(target: Target) -> bool:
return _ffi_api.TargetHasStmatrix(target)
def target_has_bulk_copy(target: Target) -> bool:
return _ffi_api.TargetHasBulkCopy(target)
def target_get_warp_size(target: Target) -> int:
return _ffi_api.TargetGetWarpSize(target)
@@ -113,9 +113,11 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
else:
return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype)
elif supply_type == TensorSupplyType.Uniform:
return torch.empty(*shape, device=device, dtype=dtype).uniform_(-1.0, 1.0)
return torch.empty(
*shape, device=device, dtype=torch.float32).uniform_(-1.0, 1.0).to(dtype)
elif supply_type == TensorSupplyType.Normal:
return torch.empty(*shape, device=device, dtype=dtype).normal_(-1.0, 1.0)
return torch.empty(
*shape, device=device, dtype=torch.float32).normal_(-1.0, 1.0).to(dtype)
elif supply_type == TensorSupplyType.Randn:
return torch.randn(*shape, device=device).to(dtype)
elif supply_type == TensorSupplyType.Zero:
......