Commit be44758c authored by botbw, committed by LeiWang1999

[Experimental][Language] add `T.GEMM_SP` for sm90 sparse tensor core (#526)



* [experimental] add a draft gemm_sp

* [3rdparty] bump cutlass to v3.9.3

* [lint] run format.sh

* [chore] rebase

* [chore] use abs path

* [gemm_sp] add metadata layout

* [ci] add more example

* [lint] run format.sh

* [chore] polish

* [chore] move gemm_sp to experimental

* [chore] polish

* [lint] run format.sh

* [Enhancement] Improve bulk copy handling and update GEMM sparse tensor test

* Added a warning log for unsupported non-swizzled global layouts in the bulk copy operation, ensuring fallback to normal copy.
* Refactored the GEMM sparse tensor test by removing unnecessary imports and simplifying the kernel compilation process.
* Updated the test to directly call the `run_gemm_sp` function, enhancing clarity and functionality.

* Implement Test

* [Enhancement] Update GEMM SP and SM89 templates for improved functionality

* Refactored GEMM SP computation to enhance warp partitioning logic, ensuring compatibility with Hopper architecture.
* Updated layout inference to support new WGMMA conditions and improved error messaging for unsupported targets.
* Modified SM89 templates to utilize new MMA atom structures, enhancing performance and compatibility with fp8 types.
* Added conditional inclusion for GEMM SP header based on CUDA architecture version.

* lint fix

* [gemm_sp] support more layout and data types

* Enhancement: sync T.gemm_sp's layout inference with T.gemm

* Enhancement: support more block_k in compress util

* [Enhancement] enable block_k=64

* [Lint] run format.sh

* [Enhancement] compressor support more dtype

* Enhancement: enable block_K=32

* [Lint] format.sh

* [Fixbug] fix shape

* Refactor: sync gemm

* [Enhancement] enable transpose

* [Enhancement] enable fp8_e4m3

* [Enhancement] enable int8

* [Lint] run format.sh

* [Benchmark] add gemm_sp benchmark

* [Example] fix 256 threads hang

* [CI] fix ci

* [Chore] resolve gemini feedback

* [Benchmark] increase search space

* [Lint] format

* [CI] skip sparse tensor core related tests as only sm90 is supported

* [CI] pass local run

* Update gemm_sm89.h

* lint fix

* lint fix

* [Enhancement] Add support for sparse GEMM and initialize CUDA architecture flags

- Introduced a new boolean flag `enable_sparse_gemm_` to control the inclusion of sparse GEMM functionality in CUDA code generation.
- Updated the `Finish` method to conditionally include the sparse GEMM header based on the new flag.
- Implemented logic in `VisitStmt_` to enable sparse GEMM when the corresponding external call is detected.
- Added a function to initialize the `TORCH_CUDA_ARCH_LIST` environment variable based on the target compute version, improving compatibility with PyTorch extension builds (see the sketch after this list).
- Refactored the initialization function into the appropriate module and ensured it is called in the sparse utilities module.
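A hedged sketch of the arch-flag initialization described above (illustrative only; the committed helper is `_initialize_torch_cuda_arch_flags` in `tilelang.env`, and its exact logic may differ):

import os

def _initialize_torch_cuda_arch_flags(compute_version: str = "9.0") -> None:
    # torch.utils.cpp_extension reads TORCH_CUDA_ARCH_LIST when picking -gencode
    # flags, so only set it when the user has not already provided a value.
    if "TORCH_CUDA_ARCH_LIST" not in os.environ:
        os.environ["TORCH_CUDA_ARCH_LIST"] = compute_version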

* Update test_compress_utils.py

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent d7aebf4d
@@ -42,6 +42,7 @@ from .allocate import (
)
from .copy import copy, c2d_im2col  # noqa: F401
from .gemm import GemmWarpPolicy, gemm  # noqa: F401
from .experimental.gemm_sp import gemm_sp  # noqa: F401
from .fill import fill, clear  # noqa: F401
from .reduce import (
    reduce,  # noqa: F401
......
"""The language interface for tl programs."""
from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T
from tvm import tir
from typing import Union
def gemm_sp(
A_sparse: Union[tir.Buffer, tir.Var],
E: Union[tir.Buffer, tir.Var],
B: Union[tir.Buffer, tir.Var],
C: Union[tir.Buffer, tir.Var],
transpose_A: bool = False,
transpose_B: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
clear_accum: bool = False,
k_pack: int = 1,
wg_wait: int = 0,
):
"""Perform a Sparse General Matrix Multiplication (GEMM-sp) operation.
This function computes C = A @ B where A and B can optionally be transposed.
The operation supports various warp policies and accumulation modes.
Args:
A_sparse (Union[tir.Buffer, tir.Var]): First input matrix dense values
E (Union[tir.Buffer, tir.Var]): First input matrix sparse metadata
B (Union[tir.Buffer, tir.Var]): Second input matrix
C (Union[tir.Buffer, tir.Var]): Output matrix for results
transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
wg_wait (int, optional): Warp group wait count. Defaults to 0.
Returns:
tir.Call: A handle to the GEMM operation
Raises:
        AssertionError: If 2 * K_A != K_B, i.e. the compressed K dimension of A_sparse does not match the K dimension of B
"""
def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
"""Convert let-bound variables to their corresponding buffers.
Args:
arg (Union[tir.Buffer, tir.Var]): Input argument to legalize
Returns:
Union[tir.Buffer, tir.Var]: The legalized argument
"""
if isinstance(arg, tir.Var) and T.has_let_value(arg):
return T.get_let_value(arg).buffer
return arg
    A_sparse = legalize_arguments(A_sparse)
    E = legalize_arguments(E)
    B = legalize_arguments(B)
    C = legalize_arguments(C)
M = C.shape[0]
N = C.shape[1]
K_A = A_sparse.shape[0] if transpose_A else A_sparse.shape[1]
K_B = B.shape[1] if transpose_B else B.shape[0]
assert K_A * 2 == K_B, f"T.gemm_sp K shape check failed: K_A = {K_A}, K_B = {K_B}"
Aptr = A_sparse.access_ptr("r")
Bptr = B.access_ptr("r")
Cptr = C.access_ptr("rw")
Eptr = E.access_ptr("r")
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.gemm_sp"),
Aptr,
Eptr,
Bptr,
Cptr,
transpose_A,
transpose_B,
M,
N,
K_B,
policy,
clear_accum,
k_pack,
wg_wait,
)
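A hedged usage sketch of the new primitive (not part of this diff): a Hopper-style tile kernel that stages the compressed values, the metadata, and B into shared memory and calls T.gemm_sp per K tile. The surrounding primitives (T.Kernel, T.alloc_shared, T.alloc_fragment, T.Pipelined, T.copy, T.annotate_layout, T.Tensor) are assumed from the usual TileLang kernel style; buffer shapes, in particular the E metadata shape, are placeholders rather than the exact shapes the compressor produces.

import tilelang.language as T
from tilelang.layout import make_metadata_layout


def matmul_sp(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A_sparse: T.Tensor((M, K // 2), dtype),  # kept values of the 2:4-sparse A
            E: T.Tensor((M, K // 8), "uint8"),  # metadata; placeholder shape, use the compressor's output in practice
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), dtype)
            E_shared = T.alloc_shared((block_M, block_K // 8), "uint8")
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # The metadata tile needs the cutlass sm90 layout added by this commit.
            T.annotate_layout({
                E_shared: make_metadata_layout(E_shared, mma_dtype=dtype, block_k=block_K),
            })
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A_sparse[by * block_M, ko * block_K // 2], A_shared)
                T.copy(E[by * block_M, ko * block_K // 8], E_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm_sp(A_shared, E_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main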
@@ -4,3 +4,4 @@
from .layout import Layout  # noqa: F401
from .fragment import Fragment  # noqa: F401
from .swizzle import make_swizzled_layout  # noqa: F401
from .gemm_sp import make_metadata_layout  # noqa: F401
\ No newline at end of file
"""Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation
import tvm
import tilelang.language as T
import warnings
from typing import List
from math import prod
def decompose_col_major(index_1d: int, basis: List[int]) -> List[int]:
res = []
for x in basis:
res.append(index_1d % x)
index_1d //= x
return res
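# Illustrative example (not part of the original source): decompose_col_major
# peels an index into mixed-radix digits, least-significant basis entry first,
# e.g. decompose_col_major(13, [8, 2, 4]) -> [5, 1, 0] since 13 = 5 + 1*8 + 0*16.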
def __make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, block_k: int):
    if block_k > 128:
        # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
        warnings.warn(f"block_k {block_k} is too large, clamping to 128 for {mma_dtype}.", stacklevel=2)
        block_k = 128
if mma_dtype not in ["float16", "bfloat16", "float32", "int8", "float8"]:
raise NotImplementedError(f"Unsupported dtype: {mma_dtype}")
if buffer.dtype not in ["uint8", "int8"]:
raise ValueError(f"metadata should be 8 bit, got {buffer.dtype}")
bits_map = {
"float16": 16,
"bfloat16": 16,
"float32": 32,
"int8": 8,
"float8": 8,
}
# ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl#L108-L117
# get atom layout according to mma dtype
BlockK = 512 // bits_map[mma_dtype]
if block_k % BlockK != 0:
raise ValueError(f"Tile K is too small, which should be at least {BlockK} for {mma_dtype}")
NumK = block_k // BlockK # block_k is MinTileShapeK
def gen_stride(shape_ik, order):
stride_ik = [None for _ in range(len(shape_ik))]
order = [(i, o) for i, o in enumerate(order)]
order.sort(key=lambda x: x[1])
accu_shape = 1
for i, (o, _) in enumerate(order):
if i == 0:
stride_ik[o] = 1
else:
stride_ik[o] = accu_shape
accu_shape *= shape_ik[o]
return stride_ik
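    # Illustrative example (not part of the original source): for float16 with
    # block_k = 64 (so NumK = 2), shape_ik = [8, 2, 4, 2, 2, 2]; the order
    # [3, 1, 5, 0, 4, 2] makes dim 3 the fastest-varying, then dims 1, 5, 0, 4, 2,
    # so gen_stride returns [8, 2, 128, 1, 64, 8].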
    if bits_map[mma_dtype] == 32:  # the `bits // 8` terms below convert bit widths into uint8 element counts
shape_ik = [8, 2, 4, 8 // 8, 2, NumK]
stride_ik = gen_stride(shape_ik, [3, 1, 5, 0, 4, 2])
shape_i, shape_k = shape_ik[:3], shape_ik[3:]
stride_i, stride_k = stride_ik[:3], stride_ik[3:]
elif bits_map[mma_dtype] == 16:
shape_ik = [8, 2, 4, 16 // 8, 2, NumK]
stride_ik = gen_stride(shape_ik, [3, 1, 5, 0, 4, 2])
shape_i, shape_k = shape_ik[:3], shape_ik[3:]
stride_i, stride_k = stride_ik[:3], stride_ik[3:]
elif bits_map[mma_dtype] == 8:
shape_i, shape_k = [64], [BlockK]
stride_i, stride_k = [BlockK], [1]
else:
raise NotImplementedError(f"Unknown mma type {mma_dtype}")
shape = buffer.shape
    # repeat the atom layout (column-major) until it covers the whole metadata buffer
rep_i = (shape[0] + 63) // 64
rep_k = (shape[1] + prod(shape_k) - 1) // prod(shape_k)
rep_i_stride = prod(shape_i + shape_k)
shape_i.append(rep_i)
stride_i.append(rep_i_stride)
    rep_k_stride = prod(shape_i + shape_k)
    shape_k.append(rep_k)
    stride_k.append(rep_k_stride)
    def transform(i: int, k: int) -> int:
        nonlocal shape_i, shape_k, stride_i, stride_k
        i_decomposed = decompose_col_major(i, shape_i)
        k_decomposed = decompose_col_major(k, shape_k)
        i_offset = sum(i_decomposed[d] * stride_i[d] for d in range(len(i_decomposed)))
        k_offset = sum(k_decomposed[d] * stride_k[d] for d in range(len(k_decomposed)))
        return i_offset + k_offset
return T.Layout(shape, transform)
def make_metadata_layout(buffer: tvm.tir.Buffer,
mma_dtype: str = "float16",
arch: str = "sm90",
backend: str = 'cutlass',
**extra_args):
if arch == "sm90":
if backend == 'cutlass':
return __make_metadata_layout_sm90_cutlass(buffer, mma_dtype, **extra_args)
        else:
            raise NotImplementedError(f"Unsupported backend for arch {arch}: {backend}")
else:
raise NotImplementedError(f"Unsupported architecture: {arch}")
import os
import torch
import warnings
from torch.utils.cpp_extension import load, _import_module_from_library
from tilelang.env import TILELANG_CACHE_DIR, TILELANG_TEMPLATE_PATH, CUTLASS_INCLUDE_DIR
# Define paths
compress_util = os.path.join(TILELANG_TEMPLATE_PATH, "tl_templates/cuda/compress_sm90.cu")
# Cache directory for compiled extensions
_CACHE_DIR = os.path.join(TILELANG_CACHE_DIR, "sparse_compressor")
os.makedirs(_CACHE_DIR, exist_ok=True)
def _get_cached_lib():
name = 'compress_lib'
cached_path = os.path.join(_CACHE_DIR, f"{name}.so")
if os.path.exists(cached_path):
try:
return _import_module_from_library(name, cached_path)
except Exception:
# If loading fails, recompile
pass
from tilelang.env import _initialize_torch_cuda_arch_flags
# Set TORCH_CUDA_ARCH_LIST
_initialize_torch_cuda_arch_flags()
# Compile if not cached or loading failed
return load(
name=name,
sources=[compress_util],
extra_cuda_cflags=[
'-O2',
'-std=c++17',
'-lineinfo',
f'-I{CUTLASS_INCLUDE_DIR}',
f'-I{CUTLASS_INCLUDE_DIR}/../tools/util/include',
'-arch=sm_90',
],
build_directory=_CACHE_DIR,
)
def compress_sm90(A: torch.Tensor, block_k: int,
transposed: bool) -> tuple[torch.Tensor, torch.Tensor]:
    if block_k > 128:
        # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
        warnings.warn(
            f"block_k {block_k} is too large, clamping to 128 for sm90 compression.",
            stacklevel=2)
        block_k = 128
# Load the library (will use cache if available)
compress_lib = _get_cached_lib()
return compress_lib.compress_sm90(A, block_k, transposed)
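A hedged usage example (the tensor shape and the pruning step are purely illustrative; compress_sm90 expects an operand that already follows a 2:4 structured-sparsity pattern):

import torch

A = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
A[:, 1::2] = 0  # crude 2-out-of-4 pruning along K, for illustration only
A_sparse, E = compress_sm90(A, block_k=64, transposed=False)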