Unverified Commit 283a9a00 authored by botbw, committed by GitHub

[Language] support `T.gemm_sp_v2` on sm80 and sm89 (#1056)

* [misc] add a cpp side wrapper for gemm_sp_py

* [misc] typing

* [IR] bind GemmSPWarpPolicy

* [chore] add wrapper code

* [IR] fix GemmSPWarpPolicy

* [codegen] apply ptxas instructions

* [intrinsic] add typical (unused) mma layout

* [template] add uint16 debug func

* [intrinsic] add b matrix layout

* [gemm_sp] enable fp16/bf16 on sm8x

* [layout] refactor fp16/bf16 layout

* [gemm_sp] enable int8

* [chore] update test case dtype

* [gemm_sp] enable fp32

* [layout] refactor layouts

* [intrinsic] enable ldmatrix for mat A

* [layout] enable ldsm for matrix b

* [layout] add ldmatrix for fp32 and fp8

* [chore] refine

* [chore] refactor

* [chore] add fp8 efactor

* [chore] refactor

* [chore] add remove negative zero util

* [example] add a custom compress kernel

* [chore] minor update

* [test] refactor gemm_sp test

* [refactor] make metadata layout func

* [example] add option for using cutlass layout

* [doc] add a gemm_sp doc

* [doc] minor polish

* [chore] remove unused

* [bugfix] fix non-replicated B case

* [test] refactor

* [chore] add a check

* [bugfix] fix util bug

* [wip] init a new test case for v2

* [chore] minor refactor

* [chore] minor update

* [bugfix] enable 16bit rs

* [language] enable rs

* [language] enable gemm_sp_sr

* [language] enable gemm_sp_rr

* [test] enable more tests

* [tvm] update ffi binding

* [chore] remove print

* [chore] fix benchmark script

* [lint] precommit lint

* [chore] apply feedback

* [test] use arch 8.0

* [chore] rollback ::ordered_metadata for backward compatibility

* [bugfix] fix capitalization

* [example] keep gemm_sp on hopper

* [test] fix no fp8 normal kernel

* [test] reduce matmul size to satisfy accum error

* [test] use cal_diff for assertion

* [bugfix] expand float8 type

* [lib] add make_int4 for short type

* [language] add transpose E

* [bugfix] fix wrong var

* [format] format

* [chore] refactor binding

* [chore] fix wrong passing var
parent b10ef75f
@@ -39,6 +39,19 @@ class GemmWarpPolicy(Node, Scriptable):
         return self.m_warp, self.n_warp
 
 
+@tvm_ffi.register_object("tl.GemmSPWarpPolicy")
+class GemmSPWarpPolicy(Node, Scriptable):
+    policy_type: int
+    m_warp: int
+    n_warp: int
+
+    def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target,
+                               is_wgmma: bool, bits: int):
+        _ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target,
+                                                      is_wgmma, bits)
+        return self.m_warp, self.n_warp
+
+
 @tvm_ffi.register_object("tl.Gemm")
 class Gemm(Node, Scriptable):
 ...
@@ -51,7 +51,7 @@ from .allocate import (
 )
 from .copy import copy, c2d_im2col  # noqa: F401
 from .gemm import GemmWarpPolicy, gemm, gemm_v1, gemm_v2  # noqa: F401
-from .experimental.gemm_sp import gemm_sp  # noqa: F401
+from .experimental.gemm_sp import gemm_sp, gemm_sp_v2  # noqa: F401
 from .fill import fill, clear  # noqa: F401
 from .reduce import (
     reduce,  # noqa: F401
 ...
@@ -3,7 +3,15 @@ from __future__ import annotations
 from tilelang.primitives.gemm.base import GemmWarpPolicy
 import tilelang.language as T
 from tvm import tir
-from tilelang.utils.language import to_buffer_region
+from tilelang.utils.language import (
+    to_buffer_region,
+    retrieve_shape,
+    retrieve_stride,
+    retrieve_offset,
+    prim_expr_equal,
+)
+from tilelang.language.utils import (
+    buffer_region_to_tile_region,)
 
 
 def gemm_sp(
@@ -85,3 +93,128 @@ def gemm_sp(
         k_pack,
         wg_wait,
     )
+
+
+# experimental currently, for fast compilation
+def gemm_sp_v2(
+    A_sparse: tir.Buffer | tir.Var,
+    E: tir.Buffer | tir.Var,
+    B: tir.Buffer | tir.Var,
+    C: tir.Buffer | tir.Var,
+    transpose_A: bool = False,
+    transpose_B: bool = False,
+    transpose_E: bool = False,
+    policy: GemmWarpPolicy = GemmWarpPolicy.Square,
+    clear_accum: bool = False,
+    k_pack: int = 1,
+    wg_wait: int = 0,
+):
+    """Perform a sparse General Matrix Multiplication (GEMM) operation.
+
+    This function computes C = A @ B where A is 2:4 structured-sparse: A_sparse holds
+    the kept non-zero values and E holds the selection metadata. A and B can optionally
+    be transposed. The operation supports various warp policies and accumulation modes.
+
+    Args:
+        A_sparse (Union[tir.Buffer, tir.Var]): First input matrix, contains only non-zero elements
+        E (Union[tir.Buffer, tir.Var]): The metadata of A_sparse, noted as E
+        B (Union[tir.Buffer, tir.Var]): Second input matrix
+        C (Union[tir.Buffer, tir.Var]): Output matrix for results
+        transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
+        transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
+        transpose_E (bool, optional): Whether to transpose the metadata E. Defaults to False.
+        policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
+        clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
+        k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
+        wg_wait (int, optional): Warp group wait count. Defaults to 0.
+
+    Returns:
+        tir.Call: A handle to the GEMM operation
+
+    Raises:
+        AssertionError: If the K dimensions of matrices A and B don't match
+    """
+
+    def legalize_arguments(arg: tir.Buffer | tir.Var):
+        """Convert let-bound variables to their corresponding buffers.
+
+        Args:
+            arg (Union[tir.Buffer, tir.Var]): Input argument to legalize
+
+        Returns:
+            Union[tir.Buffer, tir.Var]: The legalized argument
+        """
+        if isinstance(arg, tir.Var) and T.has_let_value(arg):
+            return T.get_let_value(arg).buffer
+        return arg
+
+    A_sparse = legalize_arguments(A_sparse)
+    E = legalize_arguments(E)
+    B = legalize_arguments(B)
+    C = legalize_arguments(C)
+
+    A_region = to_buffer_region(A_sparse)
+    E_region = to_buffer_region(E)
+    B_region = to_buffer_region(B)
+    C_region = to_buffer_region(C)
+
+    A_shape = retrieve_shape(A_sparse)
+    E_shape = retrieve_shape(E)
+    B_shape = retrieve_shape(B)
+    C_shape = retrieve_shape(C)
+
+    A_stride = retrieve_stride(A_sparse)
+    B_stride = retrieve_stride(B)
+
+    assert len(C_shape) == 2, "current only support C as a 2D tensor"
+    assert len(A_shape) >= 2, "current only support A as a 2D or higher-order tensor"
+    assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor"
+    if len(A_shape) > 2:
+        for i in range(len(A_shape) - 2):
+            assert A_shape[i] == 1, \
+                "current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
+    if len(B_shape) > 2:
+        for i in range(len(B_shape) - 2):
+            assert B_shape[i] == 1, \
+                "current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
+
+    M, N = C_shape
+    K = 2 * (A_shape[-2] if transpose_A else A_shape[-1])
+    K_B = B_shape[-1] if transpose_B else B_shape[-2]
+    assert prim_expr_equal(
+        K, K_B), f"T.gemm_sp K shape check failed: K_A (sparsity expanded) = {K}, K_B = {K_B}"
+
+    stride_a = A_stride[-2]
+    stride_b = B_stride[-2]
+
+    A_offset = retrieve_offset(A_sparse)
+    B_offset = retrieve_offset(B)
+    assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0"
+    assert B_offset[-2] == 0, "The offset of the first dimension of B must be 0"
+    offset_a = A_offset[-1]
+    offset_b = B_offset[-1]
+
+    A_arg = buffer_region_to_tile_region(A_region, "r", list(A_shape))
+    E_arg = buffer_region_to_tile_region(E_region, "r", list(E_shape))
+    B_arg = buffer_region_to_tile_region(B_region, "r", list(B_shape))
+    C_arg = buffer_region_to_tile_region(C_region, "rw", list(C_shape))
+
+    return tir.call_intrin(
+        "handle",
+        tir.op.Op.get("tl.gemm_sp_py"),
+        A_arg,
+        E_arg,
+        B_arg,
+        C_arg,
+        transpose_A,
+        transpose_B,
+        transpose_E,
+        M,
+        N,
+        K,
+        policy,
+        clear_accum,
+        stride_a,
+        stride_b,
+        offset_a,
+        offset_b,
+        k_pack,
+        wg_wait,
+    )
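
For context, a minimal usage sketch of the new `T.gemm_sp_v2` primitive (not part of this diff). The tile sizes, the fp16 metadata packing (one uint16 word per 16 dense K columns), the explicit metadata-layout annotation, and the `T.Tensor` parameter annotations follow common TileLang patterns and are assumptions; only the `T.gemm_sp_v2` call signature and the `make_cutlass_metadata_layout` helper come from the changes in this PR.

```python
import tilelang.language as T
from tilelang.layout import make_cutlass_metadata_layout


def sparse_matmul(M, N, K, block_M=128, block_N=128, block_K=64, dtype="float16"):

    @T.prim_func
    def main(
            A_sparse: T.Tensor((M, K // 2), dtype),   # kept 2-of-4 values of a 2:4-sparse A
            E: T.Tensor((M, K // 16), "uint16"),      # metadata; shape/dtype must match compress()
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), "float32"),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), dtype)
            E_shared = T.alloc_shared((block_M, block_K // 16), "uint16")
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), "float32")

            # Assumption: metadata produced by the cutlass-style compressor is annotated
            # with the matching cutlass metadata layout on its shared-memory tile.
            T.annotate_layout({
                E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=dtype, arch="8.0"),
            })

            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A_sparse[by * block_M, ko * (block_K // 2)], A_shared)
                T.copy(E[by * block_M, ko * (block_K // 16)], E_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```

On the host, `A_sparse` and `E` would come from the `compress()` helper touched later in this diff, and the prim_func would be compiled with TileLang's usual JIT entry point.
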
@@ -13,4 +13,4 @@ from .swizzle import (
     make_quarter_bank_swizzled_layout,  # noqa: F401
     make_linear_layout,  # noqa: F401
 )
-from .gemm_sp import make_metadata_layout  # noqa: F401
+from .gemm_sp import make_cutlass_metadata_layout  # noqa: F401
@@ -17,7 +17,7 @@ def decompose_col_major(index_1d: int, basis: list[int]) -> list[int]:
     return res
 
 
-def _make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, block_k: int):
+def make_cutlass_metadata_layout_sm90(buffer: tvm.tir.Buffer, mma_dtype: str, block_k: int):
     """Make a layout of metadata that is compatible with cutlass sm90 compression kernel. Note that layout atom is the same for smem and gmem.
 
     Args:
@@ -30,7 +30,7 @@ def _make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, b
         block_k = 128
         # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
         warnings.warn(f"block_k {block_k} is too large, set to 128 for {mma_dtype}.", stacklevel=2)
-    if mma_dtype not in ["float16", "bfloat16", "float32", "int8", "float8"]:
+    if mma_dtype not in ["float16", "bfloat16", "float32", "int8", "float8_e4m3", "float8_e5m2"]:
         raise NotImplementedError(f"Unsupported dtype: {mma_dtype}")
 
     if buffer.dtype not in ["uint8", "int8"]:
@@ -41,7 +41,8 @@ def _make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, b
         "bfloat16": 16,
         "float32": 32,
         "int8": 8,
-        "float8": 8,
+        "float8_e4m3": 8,
+        "float8_e5m2": 8,
     }
     # ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl#L108-L117
@@ -75,8 +76,8 @@ def _make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, b
         shape_i, shape_k = shape_ik[:3], shape_ik[3:]
         stride_i, stride_k = stride_ik[:3], stride_ik[3:]
     elif bits_map[mma_dtype] == 8:
-        shape_i, shape_k = [64], [BlockK]
-        stride_i, stride_k = [BlockK], [1]
+        shape_i, shape_k = [64], [block_k // 8]
+        stride_i, stride_k = [block_k // 8], [1]
     else:
         raise NotImplementedError(f"Unknown mma type {mma_dtype}")
@@ -103,54 +104,48 @@ def _make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, b
     return T.Layout(shape, transform)
 
 
-def _make_metadata_layout_sm8x_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str):
+def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str):
     """Make a layout of metadata that is compatible with cutlass sm8x compression kernel. Note that layout atom is the same for smem and gmem.
+    ref: https://github.com/pytorch/pytorch/blob/d0c24b392cbb7b213d22e42c52c6c2d1ac2da1bd/torch/sparse/_semi_structured_conversions.py#L5
 
     Args:
         buffer: metadata buffer shape, for sm80 it should be a 16bit type
     """
-    # ref: https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h#L651
-    # https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/layout/matrix.h#L405
-    # https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/gemm/warp/mma_sparse_tensor_op.h#L172
     if mma_dtype in ["float16", "bfloat16"] and buffer.dtype not in ["uint16", "int16"]:
         raise ValueError(f"metadata should be 16 bit, got {buffer.dtype}")
-    if mma_dtype in ["float8", "int8", "uint8"] and buffer.dtype not in ["uint32", "int32"]:
+    if mma_dtype in ["float8_e4m3", "float8_e5m2", "int8", "uint8"
+                    ] and buffer.dtype not in ["uint32", "int32"]:
         raise ValueError(f"metadata should be 32 bit, got {buffer.dtype}")
-    kInterleaved = 2
-    stride = buffer.shape[0] * kInterleaved
+    m, k = buffer.shape
+    group = 32 if buffer.dtype.bits == 16 else 16
+    interweave = 4 if buffer.dtype.bits == 16 else 2
 
     def ColumnMajorInterleaved(i: int, j: int) -> int:
-        column_major = j // kInterleaved
-        column_minor = j % kInterleaved
-        return column_major * stride + i * kInterleaved + column_minor
+        i = i // group * group + (i % 8) * interweave + (i % group) // 8
+        topright = (1 - (i % 2)) & (j % 2)
+        bottomleft = (i % 2) & (1 - (j % 2))
+        i += topright - bottomleft
+        j -= topright - bottomleft
+        offset = (j // 2) * m * 2 + i * 2 + (j % 2)
+        return offset // k, offset % k
 
     return T.Layout(buffer.shape, ColumnMajorInterleaved)
 
 
-def make_metadata_layout(buffer: tvm.tir.Buffer,
-                         mma_dtype: str = "float16",
-                         backend: str = 'cutlass',
-                         arch: str | None = None,
-                         **extra_args):
+def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer,
+                                 mma_dtype: str = "float16",
+                                 arch: str | None = None,
+                                 **extra_args):
     if arch is None:
         arch = nvcc.get_target_compute_version()
     compute_version = nvcc.parse_compute_version(arch)
 
     if compute_version >= (9, 0):
-        if backend == 'cutlass':
-            return _make_metadata_layout_sm90_cutlass(
-                buffer=buffer, mma_dtype=mma_dtype, **extra_args)
-        else:
-            raise NotImplementedError(f"Arch {arch}, Unsupported backend: {backend}")
+        return make_cutlass_metadata_layout_sm90(buffer=buffer, mma_dtype=mma_dtype, **extra_args)
     elif compute_version >= (8, 0):
-        if backend == 'cutlass':
-            return _make_metadata_layout_sm8x_cutlass(buffer=buffer, mma_dtype=mma_dtype)
-        else:
-            raise NotImplementedError(f"Arch {arch}, Unsupported backend: {backend}")
+        return make_cutlass_metadata_layout_sm8x(buffer=buffer, mma_dtype=mma_dtype)
     else:
         raise NotImplementedError(f"Unsupported architecture: {arch}")
@@ -10,6 +10,7 @@ from tilelang.utils.tensor import (
     get_tensor_supply,
     TensorSupplyType,
     torch_assert_close,
+    is_float8_dtype,
 )
 from tilelang.engine.param import KernelParam
 from tilelang.jit.adapter import BaseKernelAdapter
@@ -125,17 +126,9 @@ class Profiler:
         if lhs is not None and rhs is not None:
             # in case of numsplit template, the ref output may be None
             # which means the value is invalid, so we skip the comparison
-
-            def is_float8(tensor: torch.Tensor) -> bool:
-                return tensor.dtype in {
-                    torch.float8_e5m2,
-                    torch.float8_e5m2fnuz,
-                    torch.float8_e4m3fn,
-                    torch.float8_e4m3fnuz,
-                }
-
             torch_assert_close(
-                lhs if not is_float8(lhs) else lhs.to(torch.float32),
-                rhs if not is_float8(rhs) else rhs.to(torch.float32),
+                lhs if not is_float8_dtype(lhs.dtype) else lhs.to(torch.float32),
+                rhs if not is_float8_dtype(rhs.dtype) else rhs.to(torch.float32),
                 rtol=rtol,
                 atol=atol,
                 max_mismatched_ratio=max_mismatched_ratio,
 ...
 from .gemm import GemmPy  # noqa: F401
+from .gemm_sp import GemmSPPy  # noqa: F401
@@ -3,6 +3,7 @@ from tilelang import tvm as tvm
 from tvm import tir
 from tvm.target import Target
 from tvm.ir.base import Node
+from tvm.ir import Range
 from tvm.runtime import Scriptable
 import tvm_ffi
 from tilelang.ir import GemmWarpPolicy as GemmWarpPolicy
@@ -16,13 +17,14 @@ from tilelang.utils.target import target_is_volta
 @tvm_ffi.register_global_func("tl.gemm_py.infer_layout")
-def gemm_py_infer_layout(gemm_py, target, thread_bounds):
+def gemm_py_infer_layout(gemm_py: GemmMMA, target: Target, thread_bounds: Range):
     thread_nums = thread_bounds.extent
     return gemm_py.infer_layout(target, thread_nums)
 
 
 @tvm_ffi.register_global_func("tl.gemm_py.lower")
-def gemm_py_lower(gemm_py, layout_map, target, thread_bounds, thread_var):
+def gemm_py_lower(gemm_py: GemmMMA, layout_map, target: Target, thread_bounds: Range,
+                  thread_var: tir.Var):
     thread_nums = thread_bounds.extent
     stmt = gemm_py.lower(layout_map, target, thread_nums, thread_var)
     return stmt
 ...
from tilelang import tvm as tvm
from tvm import tir
from tilelang.utils.target import (
    target_is_cuda,)
from tvm.target import Target
from tvm.ir.base import Node
from tvm.ir import Range
from tvm.runtime import Scriptable
import tvm_ffi
from tilelang.ir import GemmWarpPolicy

from .gemm_sp_mma import GemmSPMMA


@tvm_ffi.register_global_func("tl.gemm_sp_py.infer_layout")
def gemm_sp_py_infer_layout(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range):
    thread_nums = thread_bounds.extent
    return gemm_sp_py.infer_layout(target, thread_nums)


@tvm_ffi.register_global_func("tl.gemm_sp_py.lower")
def gemm_sp_py_lower(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range,
                     thread_var: tir.Var):
    thread_nums = thread_bounds.extent
    stmt = gemm_sp_py.lower(target, thread_nums, thread_var)
    return stmt


@tvm_ffi.register_object("tl.GemmSPPy")
class GemmSPPy(Node, Scriptable):
    A: tir.Buffer
    E: tir.Buffer
    B: tir.Buffer
    C: tir.Buffer

    APtr: tir.PrimExpr
    EPtr: tir.PrimExpr
    BPtr: tir.PrimExpr
    CPtr: tir.PrimExpr

    M: int
    N: int
    K: int

    trans_A: bool
    trans_B: bool

    stride_A: int
    stride_B: int
    offset_A: int
    offset_B: int

    clear_accum: bool

    k_pack: int
    wg_wait: int

    policy: GemmWarpPolicy

    def infer_layout(self, target: Target, thread_nums: int):
        if target_is_cuda(target):
            # TODO(lei): Support more cuda architectures, now mma only
            return GemmSPMMA(self).infer_layout(target, thread_nums)
        else:
            raise ValueError(f"Unsupported target: {target}")

    def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
        if target_is_cuda(target):
            # TODO(lei): Support more cuda architectures, now mma only
            # Now only implement ssr layout
            return GemmSPMMA(self).lower(target, thread_nums, thread_var)
        else:
            raise ValueError(f"Unsupported target: {target}")
from dataclasses import dataclass

from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang.utils.language import is_shared, is_fragment
from tilelang.ir import GemmWarpPolicy
from tvm.ir.base import Node


@dataclass
class GemmSPBase:
    gemm_sp_node: Node

    def infer_layout(self, target: Target, thread_nums: int):
        raise NotImplementedError("infer_layout is not implemented")

    def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
        raise NotImplementedError("lower is not implemented")

    def is_gemm_ss(self) -> bool:
        return is_shared(self.A) and is_shared(self.B)

    def is_gemm_sr(self) -> bool:
        return is_shared(self.A) and is_fragment(self.B)

    def is_gemm_rs(self) -> bool:
        return is_fragment(self.A) and is_shared(self.B)

    def is_gemm_rr(self) -> bool:
        return is_fragment(self.A) and is_fragment(self.B)

    @property
    def M(self) -> int:
        return self.gemm_sp_node.M

    @property
    def N(self) -> int:
        return self.gemm_sp_node.N

    @property
    def K(self) -> int:
        return self.gemm_sp_node.K

    @property
    def trans_A(self) -> bool:
        return self.gemm_sp_node.trans_A

    @property
    def trans_B(self) -> bool:
        return self.gemm_sp_node.trans_B

    @property
    def trans_E(self) -> bool:
        return self.gemm_sp_node.trans_E

    @property
    def e_dtype(self) -> str:
        return self.E.dtype

    @property
    def in_dtype(self) -> str:
        assert self.A.dtype == self.B.dtype, "A and B must have the same dtype"
        return self.A.dtype

    @property
    def accum_dtype(self) -> str:
        return self.C.dtype

    @property
    def A(self) -> tir.Buffer:
        return self.gemm_sp_node.A

    @property
    def E(self) -> tir.Buffer:
        return self.gemm_sp_node.E

    @property
    def B(self) -> tir.Buffer:
        return self.gemm_sp_node.B

    @property
    def C(self) -> tir.Buffer:
        return self.gemm_sp_node.C

    @property
    def ARegion(self) -> tir.PrimExpr:
        return self.gemm_sp_node.ARegion

    @property
    def ERegion(self) -> tir.PrimExpr:
        return self.gemm_sp_node.ERegion

    @property
    def BRegion(self) -> tir.PrimExpr:
        return self.gemm_sp_node.BRegion

    @property
    def CRegion(self) -> tir.PrimExpr:
        return self.gemm_sp_node.CRegion

    @property
    def stride_A(self) -> int:
        return self.gemm_sp_node.stride_A

    @property
    def stride_B(self) -> int:
        return self.gemm_sp_node.stride_B

    @property
    def offset_A(self) -> int:
        return self.gemm_sp_node.offset_A

    @property
    def offset_B(self) -> int:
        return self.gemm_sp_node.offset_B

    @property
    def clear_accum(self) -> bool:
        return self.gemm_sp_node.clear_accum

    @property
    def k_pack(self) -> int:
        return self.gemm_sp_node.k_pack

    @property
    def wg_wait(self) -> int:
        return self.gemm_sp_node.wg_wait

    @property
    def policy(self) -> GemmWarpPolicy:
        return self.gemm_sp_node.policy
from .gemm_sp_base import GemmSPBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter
from tilelang.utils.language import is_shared, is_fragment
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang import language as T
from tilelang.transform.simplify import _Simplify


class GemmSPMMA(GemmSPBase):

    def infer_layout(self, target: Target, thread_nums: int):
        m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
                                                            False)
        warp_row_tiles = int(self.M // m_warp)
        warp_col_tiles = int(self.N // n_warp)
        mma_emitter = SparseTensorCoreIntrinEmitter(
            a_dtype=self.in_dtype,
            e_dtype=self.e_dtype,
            b_dtype=self.in_dtype,
            accum_dtype=self.accum_dtype,
            a_transposed=self.trans_A,
            b_transposed=self.trans_B,
            e_transposed=self.trans_E,
            block_row_warps=m_warp,
            block_col_warps=n_warp,
            warp_row_tiles=warp_row_tiles,
            warp_col_tiles=warp_col_tiles,
            warp_k=self.K,
        )

        if self.is_gemm_ss():
            return {
                self.A: make_swizzled_layout(self.A),
                self.B: make_swizzled_layout(self.B),
                self.C: mma_emitter.make_mma_store_layout(self.C),
            }
        elif self.is_gemm_sr():
            return {
                self.A: make_swizzled_layout(self.A),
                self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"),
                self.C: mma_emitter.make_mma_store_layout(self.C),
            }
        elif self.is_gemm_rs():
            return {
                self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"),
                self.B: make_swizzled_layout(self.B),
                self.C: mma_emitter.make_mma_store_layout(self.C),
            }
        elif self.is_gemm_rr():
            return {
                self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"),
                self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"),
                self.C: mma_emitter.make_mma_store_layout(self.C),
            }
        else:
            raise ValueError(
                f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")

    def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
        m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
                                                            False)
        warp_row_tiles = int(self.M // m_warp)
        warp_col_tiles = int(self.N // n_warp)
        mma_emitter = SparseTensorCoreIntrinEmitter(
            a_dtype=self.in_dtype,
            b_dtype=self.in_dtype,
            e_dtype=self.e_dtype,
            accum_dtype=self.accum_dtype,
            a_transposed=self.trans_A,
            b_transposed=self.trans_B,
            e_transposed=self.trans_E,
            block_row_warps=m_warp,
            block_col_warps=n_warp,
            warp_row_tiles=warp_row_tiles,
            warp_col_tiles=warp_col_tiles,
            warp_k=self.K,
            thread_var=thread_var,
        )

        in_dtype = self.in_dtype
        warp_rows = mma_emitter.warp_rows
        warp_cols = mma_emitter.warp_cols
        local_size_a = mma_emitter.local_size_a
        local_size_e = mma_emitter.local_size_e
        local_size_b = mma_emitter.local_size_b
        micro_size_k = mma_emitter.micro_size_k

        A_shared = self.A
        E_shared = self.E
        B_shared = self.B
        C_local = self.C

        assert micro_size_k <= self.K, f"K dimension {self.K} should be >= micro size k {micro_size_k}"

        if self.is_gemm_ss():

            @T.prim_func
            def _gemm_ssr() -> None:
                """
                The inner macro that loads data from shared buffers A_shared and
                B_shared into local fragments, then issues Tensor Core mma ops,
                accumulating into C_local.
                """
                A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
                E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype)
                B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)

                for ki in T.serial(0, (self.K // micro_size_k)):
                    # Load A into fragment
                    mma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )
                    # Load E into fragment
                    mma_emitter.ldmatrix_e(
                        E_local,
                        E_shared,
                        ki,
                    )
                    # Load B into fragment
                    mma_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )
                    # Perform Matrix Multiplication
                    mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki)

            # Simplify to optimize the index computing
            # Must inline let statements to simplify the analysis
            return _Simplify(_gemm_ssr, inline_let=True)
        elif self.is_gemm_sr():
            B_local = self.B

            @T.prim_func
            def _gemm_srr() -> None:
                """
                The inner macro that loads data from shared buffers A_shared and
                B_shared into local fragments, then issues Tensor Core mma ops,
                accumulating into C_local.
                """
                A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
                E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype)

                for ki in T.serial(0, (self.K // micro_size_k)):
                    # Load A into fragment
                    mma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )
                    # Load E into fragment
                    mma_emitter.ldmatrix_e(
                        E_local,
                        E_shared,
                        ki,
                    )
                    # Perform Matrix Multiplication
                    mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki)

            # Simplify to optimize the index computing
            # Must inline let statements to simplify the analysis
            # alloc_buffers body
            # insert into parent block
            return _Simplify(_gemm_srr, inline_let=True)
        elif self.is_gemm_rs():
            A_local = self.A

            @T.prim_func
            def _gemm_rsr() -> None:
                """
                The inner macro that loads data from shared buffers A_shared and
                B_shared into local fragments, then issues Tensor Core mma ops,
                accumulating into C_local.
                """
                E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype)
                B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)

                for ki in T.serial(0, (self.K // micro_size_k)):
                    # Load E into fragment
                    mma_emitter.ldmatrix_e(
                        E_local,
                        E_shared,
                        ki,
                    )
                    # Load B into fragment
                    mma_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )
                    # Perform Matrix Multiplication
                    mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki)

            # Simplify to optimize the index computing
            # Must inline let statements to simplify the analysis
            return _Simplify(_gemm_rsr, inline_let=True)
        elif self.is_gemm_rr():
            A_local = self.A
            B_local = self.B

            @T.prim_func
            def _gemm_rrr() -> None:
                """
                The inner macro that loads data from shared buffers A_shared and
                B_shared into local fragments, then issues Tensor Core mma ops,
                accumulating into C_local.
                """
                E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype)

                for ki in T.serial(0, (self.K // micro_size_k)):
                    # Load E into fragment
                    mma_emitter.ldmatrix_e(
                        E_local,
                        E_shared,
                        ki,
                    )
                    # Perform Matrix Multiplication
                    mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki)

            # Simplify to optimize the index computing
            # Must inline let statements to simplify the analysis
            return _Simplify(_gemm_rrr, inline_let=True)
        else:
            raise ValueError(
                f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")

    def is_gemm_ss(self) -> bool:
        return is_shared(self.A) and is_shared(self.B)

    def is_gemm_sr(self) -> bool:
        return is_shared(self.A) and is_fragment(self.B)

    def is_gemm_rs(self) -> bool:
        return is_fragment(self.A) and is_shared(self.B)

    def is_gemm_rr(self) -> bool:
        return is_fragment(self.A) and is_fragment(self.B)
@@ -3,6 +3,7 @@ import os
 import torch
 import warnings
 from tilelang.contrib import nvcc
+from tilelang.utils.tensor import is_float8_dtype, fp8_remove_negative_zeros_
 from torch.utils.cpp_extension import load, _import_module_from_library
 from tilelang import env
@@ -88,7 +89,18 @@ def compress(A: torch.Tensor,
     if compute_version >= (9, 0):
         return compress_sm90(A, transposed=transposed, **kwargs)
     elif compute_version >= (8, 0):
-        return compress_sm80(A, transposed=transposed)
+        if transposed:
+            A = A.t().contiguous()
+        origin_dtype = A.dtype
+        if is_float8_dtype(origin_dtype):
+            fp8_remove_negative_zeros_(A)
+            A = A.view(torch.int8)
+        A_sp, E = compress_sm80(A, transposed=False)
+        if is_float8_dtype(origin_dtype):
+            A_sp = A_sp.view(origin_dtype)
+        if transposed:
+            A_sp = A_sp.t().contiguous()
+        return A_sp, E
     else:
         raise ValueError(f"Unsupported CUDA compute version: {compute_version}. "
                          "Supported versions are sm_80 and sm_90.")
@@ -105,6 +117,8 @@ def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transp
         transposed (bool): If True, returns a transposed tensor of shape (K, M)
     """
     elem, group = 2, 4
+    if dtype == torch.float32:
+        elem, group = 1, 2
     tensor = torch.randn((M, K), dtype=torch.float, device=device).view(M, -1, group)
     indice = tensor.topk(elem, dim=-1).indices
     tensor.scatter_(-1, indice, 0)
@@ -114,6 +128,36 @@
     return tensor.to(dtype)  # dtype like float8 might not have randn kernel
 
 
+def randint_semi_sparse(M: int,
+                        K: int,
+                        low: int,
+                        high: int,
+                        dtype=torch.int32,
+                        device='cuda',
+                        transposed: bool = False):
+    """
+    Generate a random semi-sparse integer tensor. The generated tensor will have 2:4 sparsity along the K dimension.
+
+    Args:
+        M (int): Number of rows
+        K (int): Number of columns
+        low (int): Lower bound of the random integers
+        high (int): Upper bound of the random integers
+        dtype: Data type of the tensor
+        device: Device to create the tensor on
+        transposed (bool): If True, returns a transposed tensor of shape (K, M)
+    """
+    elem, group = 2, 4
+    if dtype == torch.float32:
+        elem, group = 1, 2
+    tensor = torch.randint(low, high, (M, K), dtype=dtype, device=device).view(M, -1, group)
+    indice = tensor.topk(elem, dim=-1).indices
+    tensor.scatter_(-1, indice, 0)
+    tensor = tensor.view(M, K)
+    if transposed:
+        tensor = tensor.t().contiguous()
+    return tensor
+
+
 def arange_semi_sparse(M: int,
                        K: int,
                        dtype=torch.float16,
@@ -129,6 +173,8 @@ def arange_semi_sparse(M: int,
         transposed (bool): If True, returns a transposed tensor of shape (K, M)
     """
     elem, group = 2, 4
+    if dtype == torch.float32:
+        elem, group = 1, 2
    tensor = torch.arange(M * K, dtype=dtype, device=device).view(M, -1, group)
     indice = tensor.topk(elem, dim=-1).indices
     tensor.scatter_(-1, indice, 0)
 ...
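
A hedged end-to-end illustration of the new sm8x float8 path in `compress()` above (requires a CUDA sm8x device and a PyTorch build with float8 support; the module path of these helpers is assumed, not shown in this diff):

```python
import torch
from tilelang.utils.sparse import compress, randn_semi_sparse  # module path assumed

# Build a 2:4-sparse fp8 matrix along K, then compress it on an sm8x GPU.
A = randn_semi_sparse(256, 128, dtype=torch.float8_e4m3fn, device="cuda")
A_sp, E = compress(A, transposed=False)

# compress() routed the fp8 input through an int8 view: negative zeros were scrubbed
# in place first so the reinterpretation round-trips, and the compressed values were
# viewed back as the original float8 dtype.
print(A_sp.dtype, A_sp.shape)  # torch.float8_e4m3fn; half of K survives (expected (256, 64))
print(E.dtype, E.shape)        # metadata tensor; exact dtype/shape come from compress_sm80
```
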
@@ -5,6 +5,22 @@ from tvm import tir
 import numpy as np
 
 
+def is_float8_dtype(dtype: torch.dtype) -> bool:
+    return dtype in {
+        torch.float8_e5m2,
+        torch.float8_e5m2fnuz,
+        torch.float8_e4m3fn,
+        torch.float8_e4m3fnuz,
+    }
+
+
+def fp8_remove_negative_zeros_(tensor: torch.Tensor):
+    assert is_float8_dtype(tensor.dtype), "Input tensor must be of float8 dtype"
+    bits = tensor.view(torch.uint8)
+    zeros_mask = (tensor == 0)
+    bits[zeros_mask] = 0x00
+
+
 class TensorSupplyType(Enum):
     Integer = 1
     Uniform = 2
 ...
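
A small demonstration of `fp8_remove_negative_zeros_` above (assumes a PyTorch build where float8 casts and comparisons are available): `-0.0` in float8 carries the sign bit, so reinterpreting as int8 would distinguish it from `+0.0`; the helper rewrites every zero to the `0x00` bit pattern in place.

```python
import torch
from tilelang.utils.tensor import fp8_remove_negative_zeros_

t = torch.tensor([0.0, -0.0, 1.5], dtype=torch.float32).to(torch.float8_e4m3fn)
print(t.view(torch.uint8))   # the -0.0 slot holds 0x80 (128)
fp8_remove_negative_zeros_(t)
print(t.view(torch.uint8))   # both zero slots now hold 0x00
```
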