Unverified Commit f58bcd43 authored by Zhiwen Mo, committed by GitHub

[SM100] Add sm100 GEMM layouts and tcgen05 support (#887)

* update sm100 related utcmma, tmem, ld/st256 in src
* update sm100 related utcmma, tmem, ld/st256 in tilelang
* Remove deprecated GEMM examples and related README documentation for SM100 architecture support
* Update GEMM implementation to replace UTCMMA with TCGEN5MMA across relevant files
* Remove gemm_umma.py example and update README to reflect TCGEN5MMA terminology changes
* Update README.md for gemm_sm100 example by removing outdated API sections and streamlining documentation
* Update README and source files to reflect TCGEN5.MMA terminology changes
* Refactor CUDA GEMM header for improved readability
parent c382dcbc
@@ -44,7 +44,7 @@ def assert_gemm_codegen(
 ):
     func = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype)
     # Because the current pass context have been polluted by previous testing.
-    with tvm.transform.PassContext():
+    with tvm.transform.PassContext(), tvm.target.Target("webgpu"):
         artifact = tilelang.lower(func, target="webgpu")
         src_code = artifact.kernel_source
...
@@ -449,6 +449,14 @@ def have_tma(target):
     return any(conditions)
 
 
+def is_hopper(target):
+    if target.kind.name != "cuda":
+        return False
+    compute_version = get_target_compute_version(target)
+    major, minor = parse_compute_version(compute_version)
+    return major == 9 and minor == 0
+
+
 def get_nvcc_compiler() -> str:
     """Get the path to the nvcc compiler"""
     return os.path.join(find_cuda_path(), "bin", "nvcc")
...
@@ -2,7 +2,7 @@ from tvm import tir, IRModule
 from tvm.target import Target
 import tilelang
 from tilelang.transform import PassContext
-from tilelang.contrib.nvcc import have_tma
+from tilelang.contrib.nvcc import have_tma, is_hopper
 from typing import Optional
@@ -120,7 +120,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     pass_ctx = tilelang.transform.get_pass_context()
     # Lower the barrier.arrive into specific initialization slot
     mod = tilelang.transform.LowerSharedBarrier()(mod)
+    # Lower the shared.tmem into specific initialization slot
+    mod = tilelang.transform.LowerSharedTmem()(mod)
     # which may be introduced by the LegalizeSafeMemoryAccess
     if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
         mod = tilelang.transform.IfStmtBinding()(mod)
@@ -136,6 +137,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
         # so we need to lower the opaque block first
         mod = tilelang.transform.LowerOpaqueBlock()(mod)
         mod = tilelang.transform.MergeIfStmt()(mod)
+        if is_hopper(target):
             mod = tilelang.transform.RewriteWgmmaSync()(mod)
             mod = tilelang.transform.InjectFenceProxy()(mod)
     else:
...
@@ -42,6 +42,7 @@ from .allocate import (
     alloc_shared,  # noqa: F401
     alloc_fragment,  # noqa: F401
     alloc_barrier,  # noqa: F401
+    alloc_tmem,  # noqa: F401
     alloc_reducer,  # noqa: F401
 )
 from .copy import copy, c2d_im2col  # noqa: F401
...
@@ -89,6 +89,35 @@ def alloc_barrier(arrive_count: int):
     return T.alloc_buffer([arrive_count], "uint64", scope="shared.barrier")
 
 
+def alloc_tmem(shape, dtype):
+    """
+    Allocate a Tensor Memory (TMEM) buffer for use with 5th-generation Tensor Core operations (e.g., TCGEN5.MMA).
+
+    TMEM is a dedicated on-chip memory introduced with the Blackwell (SM100) architecture, designed to reduce
+    register pressure and enable asynchronous, single-threaded MMA operations. It is organized as a 2D array of
+    512 columns by 128 rows (lanes), with each cell being 32 bits. Allocation is performed in units of columns,
+    and every lane of a column is allocated together.
+
+    Key properties and requirements:
+    - The number of columns allocated must be a power of 2 and at least 32.
+    - TMEM allocations are dynamic and must be explicitly deallocated.
+    - Both allocation and deallocation must be performed by the same warp.
+    - The base address of the TMEM allocation is stored in shared memory and used as the offset for TCGEN5.MMA accumulator tensors.
+    - Only TCGEN5.MMA and specific TMEM load/store instructions can access TMEM; all pre-processing must occur before data is loaded into TMEM, and all post-processing after data is retrieved.
+    - The number of columns allocated must not increase between any two allocations in execution order within the CTA.
+
+    Args:
+        shape (list): 2D shape [num_lanes, num_cols] of the buffer. The number of columns must be a power of 2,
+            at least 32, and at most 512.
+        dtype (str): Element data type of the buffer.
+
+    Returns:
+        T.Buffer: A TVM buffer object allocated in TMEM scope ("shared.tmem"), suitable for use as an accumulator
+            or operand in TCGEN5.MMA operations.
+
+    Note:
+        - TMEM is only available on supported architectures (Blackwell/SM100 and later).
+        - The returned buffer must be used according to TMEM access restrictions and deallocated appropriately.
+    """
+    assert len(shape) == 2, "shape must be a 2D tensor for TMEM allocation"
+    return T.alloc_buffer(shape, dtype, scope="shared.tmem")
+
+
 def alloc_reducer(shape, dtype, op="sum", replication=None):
     """
     Allocate a reducer buffer.
...
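To illustrate the constraints listed in the docstring, here is a minimal, hypothetical sketch of the allocation itself; the surrounding kernel scope is omitted, and only `alloc_tmem` comes from this diff. A fuller kernel-level sketch follows the `gemm()` frontend changes further below.

```python
import tilelang.language as T

# Hypothetical: inside a tilelang kernel body, allocate a 128-lane x 64-column
# fp32 accumulator. 64 columns is a power of 2 and >= 32, as required above;
# a [128, 48] request, for example, would violate the power-of-2 constraint.
C_tmem = T.alloc_tmem([128, 64], "float32")  # buffer in scope "shared.tmem"
```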
@@ -3,7 +3,7 @@
 from tilelang.primitives.gemm.base import GemmWarpPolicy
 import tilelang.language as T
 from tvm import tir
-from typing import Union, List
+from typing import Union, List, Optional
 from tilelang.utils.language import get_buffer_region_from_load
@@ -17,6 +17,7 @@ def gemm(
     clear_accum: bool = False,
     k_pack: int = 1,
     wg_wait: int = 0,
+    mbar: Optional[tir.Buffer] = None,
 ):
     """Perform a General Matrix Multiplication (GEMM) operation.
@@ -33,6 +34,9 @@ def gemm(
         clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
         k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
         wg_wait (int, optional): Warp group wait count. Defaults to 0.
+            On Hopper, this is equivalent to `wgmma.wait_group.sync.aligned <wg_wait>` when wg_wait is not -1.
+            On sm100, wg_wait can only be 0 or -1; an `mbarrier_wait` on the TCGEN5MMA barrier is appended when wg_wait is 0.
+        mbar (tir.Buffer, optional): mbarrier used for TCGEN5MMA synchronization. Defaults to None.
 
     Returns:
         tir.Call: A handle to the GEMM operation
@@ -57,6 +61,7 @@ def gemm(
     A = legalize_arguments(A)
     B = legalize_arguments(B)
     C = legalize_arguments(C)
+    mbar = legalize_arguments(mbar) if mbar is not None else None
 
     def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
         if isinstance(object, tir.Buffer):
@@ -200,26 +205,11 @@ def gemm(
     Aptr = retrieve_ptr(A, "r")
     Bptr = retrieve_ptr(B, "r")
     Cptr = retrieve_ptr(C, "rw")
-    return tir.call_intrin(
-        "handle",
-        tir.op.Op.get("tl.gemm"),
-        Aptr,
-        Bptr,
-        Cptr,
-        transpose_A,
-        transpose_B,
-        M,
-        N,
-        K,
-        policy,
-        clear_accum,
-        stride_a,
-        stride_b,
-        offset_a,
-        offset_b,
-        k_pack,
-        wg_wait,
-    )
+    mbarptr = retrieve_ptr(mbar, "rw") if mbar is not None else tir.const(0, "uint32")
+    C_coords = [r.min for r in C.region] if isinstance(C, tir.BufferRegion) else [0, 0]
+    return tir.call_intrin("handle", tir.op.Op.get("tl.gemm"), Aptr, Bptr, Cptr, transpose_A,
+                           transpose_B, M, N, K, policy, clear_accum, stride_a, stride_b, offset_a,
+                           offset_b, k_pack, wg_wait, mbarptr, C_coords[0], C_coords[1])
 
 
 # experimental currently, for fast compilation
...
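As a hedged illustration of how the new `mbar` argument and the sm100 `wg_wait` semantics fit together at the frontend, here is a minimal kernel sketch. The tile shapes, the `T.Kernel`/`T.Pipelined` structure, and the copy-out path through a register fragment are assumptions for illustration, not part of this commit; only `alloc_tmem`, `alloc_barrier`, and `gemm(..., mbar=..., wg_wait=0)` come from this diff.

```python
import tilelang.language as T


def matmul_tcgen5(M, N, K, block_M=128, block_N=128, block_K=64,
                  dtype="float16", accum_dtype="float32"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            # TCGEN5.MMA accumulator lives in Tensor Memory (see alloc_tmem above).
            C_tmem = T.alloc_tmem((block_M, block_N), accum_dtype)
            # mbarrier signalling completion of the asynchronous TCGEN5.MMA.
            mbar = T.alloc_barrier(arrive_count=1)
            # Register fragment used to stage results out of TMEM.
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                # On sm100, wg_wait=0 appends a wait on the TCGEN5MMA mbarrier.
                T.gemm(A_shared, B_shared, C_tmem, mbar=mbar, wg_wait=0)
            # Only TCGEN5.MMA and TMEM load/store instructions can touch TMEM,
            # so results are staged through registers before the global store.
            T.copy(C_tmem, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```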
@@ -69,6 +69,17 @@ def InjectSoftwarePipeline():
     return _ffi_api.InjectSoftwarePipeline()  # type: ignore
 
 
+def FrontendLegalize():
+    """FrontendLegalize
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.FrontendLegalize()  # type: ignore
+
+
 def InjectAssumes():
     """Inject Assumes
@@ -429,6 +440,12 @@ def LowerDeviceKernelLaunch():
     return _ffi_api.LowerDeviceKernelLaunch()  # type: ignore
 
 
+def LowerSharedTmem():
+    """LowerSharedTmem
+    """
+    return _ffi_api.LowerSharedTmem()  # type: ignore
+
+
 def LayoutReducer():
     """
     Return a TVM transform pass that performs layout reduction/normalization.
...
@@ -45,6 +45,8 @@ class PassConfigKey(str, Enum):
     TL_DISABLE_SAFE_MEMORY_ACCESS = "tl.disable_safe_memory_legalize"
     """Disable safe memory access optimization. Default: False"""
 
+    TL_DISABLE_VECTORIZE_256 = "tl.disable_vectorize_256"
+    """Disable usage of LDG/STG 256. Default: False"""
+
     TL_DISABLE_WGMMA = "tl.disable_wgmma"
     """Disable usage of Hopper WGMMA. Default: False"""
...
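For context, a hedged sketch of how the new key could be supplied. The config-dict route mirrors the `PassContext` usage in the test at the top of this diff and the `get_pass_context()` call in `OptimizeForTarget`; `func` is a placeholder for any tilelang-compatible PrimFunc.

```python
import tvm
import tilelang

# Hypothetical: disable 256-bit vectorized loads/stores (LDG/STG.256) for one lowering run.
with tvm.transform.PassContext(config={"tl.disable_vectorize_256": True}):
    artifact = tilelang.lower(func, target="cuda")  # "func" stands in for a PrimFunc
    print(artifact.kernel_source)
```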
@@ -62,6 +62,9 @@ def determine_target(target: Union[str, Target, Literal["auto"]] = "auto",
     return_var: Union[str, Target] = target
 
     if target == "auto":
+        target = tvm.target.Target.current(allow_none=True)
+        if target is not None:
+            return target
         # Check for CUDA and HIP availability
         is_cuda_available = check_cuda_availability()
         is_hip_available = check_hip_availability()
...
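A short, hedged sketch of the behavioural change above: when `target="auto"`, an enclosing `tvm.target.Target` context now takes precedence over CUDA/HIP auto-detection, mirroring the webgpu test at the top of this diff. `func` again stands in for any PrimFunc, and the exact entry point that ends up calling `determine_target` is an assumption.

```python
import tvm
import tilelang

# Hypothetical: the enclosing Target context is returned by determine_target()
# before any CUDA/HIP availability probing happens.
with tvm.target.Target("webgpu"):
    artifact = tilelang.lower(func, target="auto")  # resolves to the webgpu target
```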