Unverified Commit 283a9a00 authored by botbw, committed by GitHub

[Language] support `T.gemm_sp_v2` on sm80 and sm89 (#1056)

* [misc] add a cpp side wrapper for gemm_sp_py

* [misc] typing

* [IR] bind GemmSPWarpPolicy

* [chore] add wrapper code

* [IR] fix GemmSPWarpPolicy

* [codegen] apply ptxas instructions

* [intrinsic] add typical (unused) mma layout

* [template] add uint16 debug func

* [intrinsic] add b matrix layout

* [gemm_sp] enable fp16/bf16 on sm8x

* [layout] refactor fp16/bf16 layout

* [gemm_sp] enable int8

* [chore] update test case dtype

* [gemm_sp] enable fp32

* [layout] refactor layouts

* [intrinsic] enable ldmatrix for mat A

* [layout] enable ldsm for matrix b

* [layout] add ldmatrix for fp32 and fp8

* [chore] refine

* [chore] refactor

* [chore] add fp8 e_factor

* [chore] refactor

* [chore] add remove negative zero util

* [example] add a custom compress kernel

* [chore] minor update

* [test] refactor gemm_sp test

* [refactor] make metadata layout func

* [example] add option for using cutlass layout

* [doc] add a gemm_sp doc

* [doc] minor polish

* [chore] remove unused

* [bugfix] fix non replicate b case

* [test] refactor

* [chore] add a check

* [bugfix] fix util bug

* [wip] init a new test case for v2

* [chore] minor refactor

* [chore] minor update

* [bugfix] enable 16bit rs

* [language] enable rs

* [language] enable gemm_sp_sr

* [language] enable gemm_sp_rr

* [test] enable more tests

* [tvm] update ffi binding

* [chore] remove print

* [chore] fix benchmark script

* [lint] precommit lint

* [chore] apply feedback

* [test] use arch 8.0

* [chore] rollback ::ordered_metadata for backward compatibility

* [bugfix] fix capitalization

* [example] keep gemm_sp on hopper

* [test] fix no fp8 normal kernel

* [test] reduce matmul size to satisfy accum error

* [test] use cal_diff for assertion

* [bugfix] expand float8 type

* [lib] add make_int4 for short type

* [language] add transpose E

* [bugfix] fix wrong var

* [format] format

* [chore] refactor binding

* [chore] fix wrong passing var
parent b10ef75f
......@@ -39,6 +39,19 @@ class GemmWarpPolicy(Node, Scriptable):
return self.m_warp, self.n_warp
@tvm_ffi.register_object("tl.GemmSPWarpPolicy")
class GemmSPWarpPolicy(Node, Scriptable):
policy_type: int
m_warp: int
n_warp: int
def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target,
is_wgmma: bool, bits: int):
_ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target,
is_wgmma, bits)
return self.m_warp, self.n_warp
@tvm_ffi.register_object("tl.Gemm")
class Gemm(Node, Scriptable):
...
......
......@@ -51,7 +51,7 @@ from .allocate import (
)
from .copy import copy, c2d_im2col # noqa: F401
from .gemm import GemmWarpPolicy, gemm, gemm_v1, gemm_v2 # noqa: F401
from .experimental.gemm_sp import gemm_sp # noqa: F401
from .experimental.gemm_sp import gemm_sp, gemm_sp_v2 # noqa: F401
from .fill import fill, clear # noqa: F401
from .reduce import (
reduce, # noqa: F401
......
......@@ -3,7 +3,15 @@ from __future__ import annotations
from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T
from tvm import tir
from tilelang.utils.language import to_buffer_region
from tilelang.utils.language import (
to_buffer_region,
retrieve_shape,
retrieve_stride,
retrieve_offset,
prim_expr_equal,
)
from tilelang.language.utils import (
buffer_region_to_tile_region,)
def gemm_sp(
......@@ -85,3 +93,128 @@ def gemm_sp(
k_pack,
wg_wait,
)
# experimental currently, for fast compilation
def gemm_sp_v2(
A_sparse: tir.Buffer | tir.Var,
E: tir.Buffer | tir.Var,
B: tir.Buffer | tir.Var,
C: tir.Buffer | tir.Var,
transpose_A: bool = False,
transpose_B: bool = False,
transpose_E: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
clear_accum: bool = False,
k_pack: int = 1,
wg_wait: int = 0,
):
"""Perform a General Matrix Multiplication (GEMM) operation.
This function computes C = A @ B where A and B can optionally be transposed.
The operation supports various warp policies and accumulation modes.
Args:
A_sparse (Union[tir.Buffer, tir.Var]): First input matrix, contains only non-zero elements
E (Union[tir.Buffer, tir.Var]): The metadata of A_sparse, noted as E
B (Union[tir.Buffer, tir.Var]): Second input matrix
C (Union[tir.Buffer, tir.Var]): Output matrix for results
transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
wg_wait (int, optional): Warp group wait count. Defaults to 0.
Returns:
tir.Call: A handle to the GEMM operation
Raises:
AssertionError: If the K dimensions of matrices A and B don't match
"""
def legalize_arguments(arg: tir.Buffer | tir.Var):
"""Convert let-bound variables to their corresponding buffers.
Args:
arg (Union[tir.Buffer, tir.Var]): Input argument to legalize
Returns:
Union[tir.Buffer, tir.Var]: The legalized argument
"""
if isinstance(arg, tir.Var) and T.has_let_value(arg):
return T.get_let_value(arg).buffer
return arg
A_sparse = legalize_arguments(A_sparse)
E = legalize_arguments(E)
B = legalize_arguments(B)
C = legalize_arguments(C)
A_region = to_buffer_region(A_sparse)
E_region = to_buffer_region(E)
B_region = to_buffer_region(B)
C_region = to_buffer_region(C)
A_shape = retrieve_shape(A_sparse)
E_shape = retrieve_shape(E)
B_shape = retrieve_shape(B)
C_shape = retrieve_shape(C)
A_stride = retrieve_stride(A_sparse)
B_stride = retrieve_stride(B)
assert len(C_shape) == 2, "currently only supports C as a 2D tensor"
assert len(A_shape) >= 2, "currently only supports A as a 2D or higher-order tensor"
assert len(B_shape) >= 2, "currently only supports B as a 2D or higher-order tensor"
if len(A_shape) > 2:
for i in range(len(A_shape) - 2):
assert A_shape[i] == 1, \
"currently A must be 2D, or a higher-order tensor whose leading dimensions are all 1 and whose last two dimensions are the matrix dimensions"
if len(B_shape) > 2:
for i in range(len(B_shape) - 2):
assert B_shape[i] == 1, \
"currently B must be 2D, or a higher-order tensor whose leading dimensions are all 1 and whose last two dimensions are the matrix dimensions"
M, N = C_shape
K = 2 * (A_shape[-2] if transpose_A else A_shape[-1])
K_B = B_shape[-1] if transpose_B else B_shape[-2]
assert prim_expr_equal(
K, K_B), f"T.gemm_sp_v2 K shape check failed: K_A (after decompression) = {K}, K_B = {K_B}"
stride_a = A_stride[-2]
stride_b = B_stride[-2]
A_offset = retrieve_offset(A_sparse)
B_offset = retrieve_offset(B)
assert A_offset[-2] == 0, "The offset of the second-to-last (row) dimension of A must be 0"
assert B_offset[-2] == 0, "The offset of the second-to-last (row) dimension of B must be 0"
offset_a = A_offset[-1]
offset_b = B_offset[-1]
A_arg = buffer_region_to_tile_region(A_region, "r", list(A_shape))
E_arg = buffer_region_to_tile_region(E_region, "r", list(E_shape))
B_arg = buffer_region_to_tile_region(B_region, "r", list(B_shape))
C_arg = buffer_region_to_tile_region(C_region, "rw", list(C_shape))
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.gemm_sp_py"),
A_arg,
E_arg,
B_arg,
C_arg,
transpose_A,
transpose_B,
transpose_E,
M,
N,
K,
policy,
clear_accum,
stride_a,
stride_b,
offset_a,
offset_b,
k_pack,
wg_wait,
)
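For reviewers unfamiliar with the new primitive, below is a minimal sm80-style sketch of calling `T.gemm_sp_v2` inside a TileLang kernel. It is not taken from this PR's examples: the block sizes, the metadata shape/dtype (uint16 words, `K // 16` columns for fp16), and annotating only the global `E` buffer are assumptions; see the gemm_sp examples in the repo for the canonical patterns.

```python
import tilelang.language as T
from tilelang.layout import make_cutlass_metadata_layout


def matmul_sp(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A_sparse: T.Tensor((M, K // 2), dtype),       # kept (non-zero) values of A
            E: T.Tensor((M, K // 16), "uint16"),          # assumed 2:4 metadata shape/dtype for fp16
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), dtype)
            E_shared = T.alloc_shared((block_M, block_K // 16), "uint16")
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Metadata must be laid out the way the sparse mma consumes it.
            T.annotate_layout({E: make_cutlass_metadata_layout(E, mma_dtype=dtype, arch="8.0")})

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                T.copy(E[by * block_M, k * block_K // 16], E_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```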
......@@ -13,4 +13,4 @@ from .swizzle import (
make_quarter_bank_swizzled_layout, # noqa: F401
make_linear_layout, # noqa: F401
)
from .gemm_sp import make_metadata_layout # noqa: F401
from .gemm_sp import make_cutlass_metadata_layout # noqa: F401
......@@ -17,7 +17,7 @@ def decompose_col_major(index_1d: int, basis: list[int]) -> list[int]:
return res
def _make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, block_k: int):
def make_cutlass_metadata_layout_sm90(buffer: tvm.tir.Buffer, mma_dtype: str, block_k: int):
"""Make a layout of metadata that is compatible with cutlass sm90 compression kernel. Note that layout atom is the same for smem and gmem.
Args:
......@@ -30,7 +30,7 @@ def _make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, b
block_k = 128
# Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
warnings.warn(f"block_k {block_k} is too large, set to 128 for {mma_dtype}.", stacklevel=2)
if mma_dtype not in ["float16", "bfloat16", "float32", "int8", "float8"]:
if mma_dtype not in ["float16", "bfloat16", "float32", "int8", "float8_e4m3", "float8_e5m2"]:
raise NotImplementedError(f"Unsupported dtype: {mma_dtype}")
if buffer.dtype not in ["uint8", "int8"]:
......@@ -41,7 +41,8 @@ def _make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, b
"bfloat16": 16,
"float32": 32,
"int8": 8,
"float8": 8,
"float8_e4m3": 8,
"float8_e5m2": 8,
}
# ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl#L108-L117
......@@ -75,8 +76,8 @@ def _make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, b
shape_i, shape_k = shape_ik[:3], shape_ik[3:]
stride_i, stride_k = stride_ik[:3], stride_ik[3:]
elif bits_map[mma_dtype] == 8:
shape_i, shape_k = [64], [BlockK]
stride_i, stride_k = [BlockK], [1]
shape_i, shape_k = [64], [block_k // 8]
stride_i, stride_k = [block_k // 8], [1]
else:
raise NotImplementedError(f"Unknown mma type {mma_dtype}")
......@@ -103,54 +104,48 @@ def _make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, b
return T.Layout(shape, transform)
def _make_metadata_layout_sm8x_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str):
def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str):
"""Make a layout of metadata that is compatible with cutlass sm8x compression kernel. Note that layout atom is the same for smem and gmem.
ref: https://github.com/pytorch/pytorch/blob/d0c24b392cbb7b213d22e42c52c6c2d1ac2da1bd/torch/sparse/_semi_structured_conversions.py#L5
Args:
buffer: the metadata buffer; for fp16/bf16 operands on sm8x it must use a 16-bit integer dtype (uint16/int16), and for 8-bit operands a 32-bit integer dtype (uint32/int32)
"""
# ref: https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h#L651
# https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/layout/matrix.h#L405
# https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/gemm/warp/mma_sparse_tensor_op.h#L172
if mma_dtype in ["float16", "bfloat16"] and buffer.dtype not in ["uint16", "int16"]:
raise ValueError(f"metadata should be 16 bit, got {buffer.dtype}")
if mma_dtype in ["float8", "int8", "uint8"] and buffer.dtype not in ["uint32", "int32"]:
if mma_dtype in ["float8_e4m3", "float8_e5m2", "int8", "uint8"
] and buffer.dtype not in ["uint32", "int32"]:
raise ValueError(f"metadata should be 32 bit, got {buffer.dtype}")
kInterleaved = 2
stride = buffer.shape[0] * kInterleaved
m, k = buffer.shape
group = 32 if buffer.dtype.bits == 16 else 16
interweave = 4 if buffer.dtype.bits == 16 else 2
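# The mapping below reorders (i, j) in the logical metadata matrix into the
# thread-interleaved storage order expected by the sm8x sparse mma (the same
# reordering used by PyTorch's _semi_structured_conversions referenced above),
# and returns the (row, col) position inside the stored metadata buffer.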
def ColumnMajorInterleaved(i: int, j: int) -> int:
column_major = j // kInterleaved
column_minor = j % kInterleaved
return column_major * stride + i * kInterleaved + column_minor
i = i // group * group + (i % 8) * interweave + (i % group) // 8
topright = (1 - (i % 2)) & (j % 2)
bottomleft = (i % 2) & (1 - (j % 2))
i += topright - bottomleft
j -= topright - bottomleft
offset = (j // 2) * m * 2 + i * 2 + (j % 2)
return offset // k, offset % k
return T.Layout(buffer.shape, ColumnMajorInterleaved)
def make_metadata_layout(buffer: tvm.tir.Buffer,
mma_dtype: str = "float16",
backend: str = 'cutlass',
arch: str | None = None,
**extra_args):
def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer,
mma_dtype: str = "float16",
arch: str | None = None,
**extra_args):
if arch is None:
arch = nvcc.get_target_compute_version()
compute_version = nvcc.parse_compute_version(arch)
if compute_version >= (9, 0):
if backend == 'cutlass':
return _make_metadata_layout_sm90_cutlass(
buffer=buffer, mma_dtype=mma_dtype, **extra_args)
else:
raise NotImplementedError(f"Arch {arch}, Unsupported backend: {backend}")
return make_cutlass_metadata_layout_sm90(buffer=buffer, mma_dtype=mma_dtype, **extra_args)
elif compute_version >= (8, 0):
if backend == 'cutlass':
return _make_metadata_layout_sm8x_cutlass(buffer=buffer, mma_dtype=mma_dtype)
else:
raise NotImplementedError(f"Arch {arch}, Unsupported backend: {backend}")
return make_cutlass_metadata_layout_sm8x(buffer=buffer, mma_dtype=mma_dtype)
else:
raise NotImplementedError(f"Unsupported architecture: {arch}")
......@@ -10,6 +10,7 @@ from tilelang.utils.tensor import (
get_tensor_supply,
TensorSupplyType,
torch_assert_close,
is_float8_dtype,
)
from tilelang.engine.param import KernelParam
from tilelang.jit.adapter import BaseKernelAdapter
......@@ -125,17 +126,9 @@ class Profiler:
if lhs is not None and rhs is not None:
# in case of numsplit template, the ref output may be None
# which means the value is invalid, so we skip the comparison
def is_float8(tensor: torch.Tensor) -> bool:
return tensor.dtype in {
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
}
torch_assert_close(
lhs if not is_float8(lhs) else lhs.to(torch.float32),
rhs if not is_float8(rhs) else rhs.to(torch.float32),
lhs if not is_float8_dtype(lhs.dtype) else lhs.to(torch.float32),
rhs if not is_float8_dtype(rhs.dtype) else rhs.to(torch.float32),
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio,
......
from .gemm import GemmPy # noqa: F401
from .gemm_sp import GemmSPPy # noqa: F401
......@@ -3,6 +3,7 @@ from tilelang import tvm as tvm
from tvm import tir
from tvm.target import Target
from tvm.ir.base import Node
from tvm.ir import Range
from tvm.runtime import Scriptable
import tvm_ffi
from tilelang.ir import GemmWarpPolicy as GemmWarpPolicy
......@@ -16,13 +17,14 @@ from tilelang.utils.target import target_is_volta
@tvm_ffi.register_global_func("tl.gemm_py.infer_layout")
def gemm_py_infer_layout(gemm_py, target, thread_bounds):
def gemm_py_infer_layout(gemm_py: GemmMMA, target: Target, thread_bounds: Range):
thread_nums = thread_bounds.extent
return gemm_py.infer_layout(target, thread_nums)
@tvm_ffi.register_global_func("tl.gemm_py.lower")
def gemm_py_lower(gemm_py, layout_map, target, thread_bounds, thread_var):
def gemm_py_lower(gemm_py: GemmMMA, layout_map, target: Target, thread_bounds: Range,
thread_var: tir.Var):
thread_nums = thread_bounds.extent
stmt = gemm_py.lower(layout_map, target, thread_nums, thread_var)
return stmt
......
from tilelang import tvm as tvm
from tvm import tir
from tilelang.utils.target import (
target_is_cuda,)
from tvm.target import Target
from tvm.ir.base import Node
from tvm.ir import Range
from tvm.runtime import Scriptable
import tvm_ffi
from tilelang.ir import GemmWarpPolicy
from .gemm_sp_mma import GemmSPMMA
@tvm_ffi.register_global_func("tl.gemm_sp_py.infer_layout")
def gemm_sp_py_infer_layout(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range):
thread_nums = thread_bounds.extent
return gemm_sp_py.infer_layout(target, thread_nums)
@tvm_ffi.register_global_func("tl.gemm_sp_py.lower")
def gemm_sp_py_lower(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range,
thread_var: tir.Var):
thread_nums = thread_bounds.extent
stmt = gemm_sp_py.lower(target, thread_nums, thread_var)
return stmt
@tvm_ffi.register_object("tl.GemmSPPy")
class GemmSPPy(Node, Scriptable):
A: tir.Buffer
E: tir.Buffer
B: tir.Buffer
C: tir.Buffer
APtr: tir.PrimExpr
EPtr: tir.PrimExpr
BPtr: tir.PrimExpr
CPtr: tir.PrimExpr
M: int
N: int
K: int
trans_A: bool
trans_B: bool
stride_A: int
stride_B: int
offset_A: int
offset_B: int
clear_accum: bool
k_pack: int
wg_wait: int
policy: GemmWarpPolicy
def infer_layout(self, target: Target, thread_nums: int):
if target_is_cuda(target):
# TODO(lei): Support more cuda architectures, now mma only
return GemmSPMMA(self).infer_layout(target, thread_nums)
else:
raise ValueError(f"Unsupported target: {target}")
def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
if target_is_cuda(target):
# TODO(lei): Support more cuda architectures, now mma only
# Now only implement ssr layout
return GemmSPMMA(self).lower(target, thread_nums, thread_var)
else:
raise ValueError(f"Unsupported target: {target}")
from dataclasses import dataclass
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang.utils.language import is_shared, is_fragment
from tilelang.ir import GemmWarpPolicy
from tvm.ir.base import Node
@dataclass
class GemmSPBase:
gemm_sp_node: Node
def infer_layout(self, target: Target, thread_nums: int):
raise NotImplementedError("infer_layout is not implemented")
def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
raise NotImplementedError("lower is not implemented")
def is_gemm_ss(self) -> bool:
return is_shared(self.A) and is_shared(self.B)
def is_gemm_sr(self) -> bool:
return is_shared(self.A) and is_fragment(self.B)
def is_gemm_rs(self) -> bool:
return is_fragment(self.A) and is_shared(self.B)
def is_gemm_rr(self) -> bool:
return is_fragment(self.A) and is_fragment(self.B)
@property
def M(self) -> int:
return self.gemm_sp_node.M
@property
def N(self) -> int:
return self.gemm_sp_node.N
@property
def K(self) -> int:
return self.gemm_sp_node.K
@property
def trans_A(self) -> bool:
return self.gemm_sp_node.trans_A
@property
def trans_B(self) -> bool:
return self.gemm_sp_node.trans_B
@property
def trans_E(self) -> bool:
return self.gemm_sp_node.trans_E
@property
def e_dtype(self) -> str:
return self.E.dtype
@property
def in_dtype(self) -> str:
assert self.A.dtype == self.B.dtype, "A and B must have the same dtype"
return self.A.dtype
@property
def accum_dtype(self) -> str:
return self.C.dtype
@property
def A(self) -> tir.Buffer:
return self.gemm_sp_node.A
@property
def E(self) -> tir.Buffer:
return self.gemm_sp_node.E
@property
def B(self) -> tir.Buffer:
return self.gemm_sp_node.B
@property
def C(self) -> tir.Buffer:
return self.gemm_sp_node.C
@property
def ARegion(self) -> tir.PrimExpr:
return self.gemm_sp_node.ARegion
@property
def ERegion(self) -> tir.PrimExpr:
return self.gemm_sp_node.ERegion
@property
def BRegion(self) -> tir.PrimExpr:
return self.gemm_sp_node.BRegion
@property
def CRegion(self) -> tir.PrimExpr:
return self.gemm_sp_node.CRegion
@property
def stride_A(self) -> int:
return self.gemm_sp_node.stride_A
@property
def stride_B(self) -> int:
return self.gemm_sp_node.stride_B
@property
def offset_A(self) -> int:
return self.gemm_sp_node.offset_A
@property
def offset_B(self) -> int:
return self.gemm_sp_node.offset_B
@property
def clear_accum(self) -> bool:
return self.gemm_sp_node.clear_accum
@property
def k_pack(self) -> int:
return self.gemm_sp_node.k_pack
@property
def wg_wait(self) -> int:
return self.gemm_sp_node.wg_wait
@property
def policy(self) -> GemmWarpPolicy:
return self.gemm_sp_node.policy
from .gemm_sp_base import GemmSPBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter
from tilelang.utils.language import is_shared, is_fragment
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang import language as T
from tilelang.transform.simplify import _Simplify
class GemmSPMMA(GemmSPBase):
def infer_layout(self, target: Target, thread_nums: int):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = SparseTensorCoreIntrinEmitter(
a_dtype=self.in_dtype,
e_dtype=self.e_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
e_transposed=self.trans_E,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
warp_k=self.K,
)
if self.is_gemm_ss():
return {
self.A: make_swizzled_layout(self.A),
self.B: make_swizzled_layout(self.B),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
elif self.is_gemm_sr():
return {
self.A: make_swizzled_layout(self.A),
self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
elif self.is_gemm_rs():
return {
self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"),
self.B: make_swizzled_layout(self.B),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
elif self.is_gemm_rr():
return {
self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"),
self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = SparseTensorCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
e_dtype=self.e_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
e_transposed=self.trans_E,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
warp_k=self.K,
thread_var=thread_var,
)
in_dtype = self.in_dtype
warp_rows = mma_emitter.warp_rows
warp_cols = mma_emitter.warp_cols
local_size_a = mma_emitter.local_size_a
local_size_e = mma_emitter.local_size_e
local_size_b = mma_emitter.local_size_b
micro_size_k = mma_emitter.micro_size_k
A_shared = self.A
E_shared = self.E
B_shared = self.B
C_local = self.C
assert micro_size_k <= self.K, f"K dimension {self.K} should be >= micro size k {micro_size_k}"
if self.is_gemm_ss():
@T.prim_func
def _gemm_ssr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
for ki in T.serial(0, (self.K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)
# Load E into fragment
mma_emitter.ldmatrix_e(
E_local,
E_shared,
ki,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_ssr, inline_let=True)
elif self.is_gemm_sr():
B_local = self.B
@T.prim_func
def _gemm_srr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype)
for ki in T.serial(0, (self.K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)
# Load E into fragment
mma_emitter.ldmatrix_e(
E_local,
E_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_srr, inline_let=True)
elif self.is_gemm_rs():
A_local = self.A
@T.prim_func
def _gemm_rsr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
for ki in T.serial(0, (self.K // micro_size_k)):
# Load E into fragment
mma_emitter.ldmatrix_e(
E_local,
E_shared,
ki,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
elif self.is_gemm_rr():
A_local = self.A
B_local = self.B
@T.prim_func
def _gemm_rrr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype)
for ki in T.serial(0, (self.K // micro_size_k)):
# Load E into fragment
mma_emitter.ldmatrix_e(
E_local,
E_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rrr, inline_let=True)
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def is_gemm_ss(self) -> bool:
return is_shared(self.A) and is_shared(self.B)
def is_gemm_sr(self) -> bool:
return is_shared(self.A) and is_fragment(self.B)
def is_gemm_rs(self) -> bool:
return is_fragment(self.A) and is_shared(self.B)
def is_gemm_rr(self) -> bool:
return is_fragment(self.A) and is_fragment(self.B)
......@@ -3,6 +3,7 @@ import os
import torch
import warnings
from tilelang.contrib import nvcc
from tilelang.utils.tensor import is_float8_dtype, fp8_remove_negative_zeros_
from torch.utils.cpp_extension import load, _import_module_from_library
from tilelang import env
......@@ -88,7 +89,18 @@ def compress(A: torch.Tensor,
if compute_version >= (9, 0):
return compress_sm90(A, transposed=transposed, **kwargs)
elif compute_version >= (8, 0):
return compress_sm80(A, transposed=transposed)
if transposed:
A = A.t().contiguous()
origin_dtype = A.dtype
if is_float8_dtype(origin_dtype):
fp8_remove_negative_zeros_(A)
A = A.view(torch.int8)
A_sp, E = compress_sm80(A, transposed=False)
if is_float8_dtype(origin_dtype):
A_sp = A_sp.view(origin_dtype)
if transposed:
A_sp = A_sp.t().contiguous()
return A_sp, E
else:
raise ValueError(f"Unsupported CUDA compute version: {compute_version}. "
"Supported versions are sm_80 and sm_90.")
......@@ -105,6 +117,8 @@ def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transp
transposed (bool): If True, returns a transposed tensor of shape (K, M)
"""
elem, group = 2, 4
if dtype == torch.float32:
elem, group = 1, 2
tensor = torch.randn((M, K), dtype=torch.float, device=device).view(M, -1, group)
indice = tensor.topk(elem, dim=-1).indices
tensor.scatter_(-1, indice, 0)
......@@ -114,6 +128,36 @@ def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transp
return tensor.to(dtype) # dtype like float8 might not have randn kernel
def randint_semi_sparse(M: int,
K: int,
low: int,
high: int,
dtype=torch.int32,
device='cuda',
transposed: bool = False):
"""
Generate a random semi-sparse integer tensor. The generated tensor will have 2:4 sparsity along the K dimension.
Args:
M (int): Number of rows
K (int): Number of columns
low (int): Lower bound (inclusive) of the random integers
high (int): Upper bound (exclusive) of the random integers, as in torch.randint
dtype: Data type of the tensor
device: Device to create the tensor on
transposed (bool): If True, returns a transposed tensor of shape (K, M)
"""
elem, group = 2, 4
if dtype == torch.float32:
elem, group = 1, 2
tensor = torch.randint(low, high, (M, K), dtype=dtype, device=device).view(M, -1, group)
indice = tensor.topk(elem, dim=-1).indices
tensor.scatter_(-1, indice, 0)
tensor = tensor.view(M, K)
if transposed:
tensor = tensor.t().contiguous()
return tensor
def arange_semi_sparse(M: int,
K: int,
dtype=torch.float16,
......@@ -129,6 +173,8 @@ def arange_semi_sparse(M: int,
transposed (bool): If True, returns a transposed tensor of shape (K, M)
"""
elem, group = 2, 4
if dtype == torch.float32:
elem, group = 1, 2
tensor = torch.arange(M * K, dtype=dtype, device=device).view(M, -1, group)
indice = tensor.topk(elem, dim=-1).indices
tensor.scatter_(-1, indice, 0)
......
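A quick sanity check of the semi-sparse generators, including the new `randint_semi_sparse` helper; the module path is assumed.

```python
import torch
from tilelang.utils.sparse import randn_semi_sparse, randint_semi_sparse  # assumed module path

# fp16: at least 2 of every 4 elements along K are zero.
t = randn_semi_sparse(128, 64, dtype=torch.float16, device="cuda")
assert bool(((t.view(128, -1, 4) == 0).sum(dim=-1) >= 2).all())

# int8 variant produced by randint_semi_sparse follows the same 2:4 pattern.
q = randint_semi_sparse(128, 64, low=-4, high=4, dtype=torch.int8, device="cuda")
assert bool(((q.view(128, -1, 4) == 0).sum(dim=-1) >= 2).all())
```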
......@@ -5,6 +5,22 @@ from tvm import tir
import numpy as np
def is_float8_dtype(dtype: torch.dtype) -> bool:
return dtype in {
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
}
def fp8_remove_negative_zeros_(tensor: torch.Tensor):
assert is_float8_dtype(tensor.dtype), "Input tensor must be of float8 dtype"
bits = tensor.view(torch.uint8)
zeros_mask = (tensor == 0)
bits[zeros_mask] = 0x00
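A small demonstration of the new fp8 helpers, assuming a recent PyTorch build with float8 support on the GPU.

```python
import torch
from tilelang.utils.tensor import is_float8_dtype, fp8_remove_negative_zeros_

# -0.0 in fp8 has bit pattern 0x80; normalizing it to +0.0 keeps the 2:4
# compressor's zero detection (and the int8 reinterpretation in compress) consistent.
x = torch.tensor([-0.0, 1.0, 0.0, -2.0], device="cuda").to(torch.float8_e4m3fn)
assert is_float8_dtype(x.dtype)
fp8_remove_negative_zeros_(x)
assert int(x.view(torch.uint8)[0]) == 0x00  # -0.0 is now stored as +0.0
```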
class TensorSupplyType(Enum):
Integer = 1
Uniform = 2
......