Unverified Commit 283a9a00 authored by botbw's avatar botbw Committed by GitHub
Browse files

[Language] support `T.gemm_sp_v2` on sm80 and sm89 (#1056)

* [misc] add a cpp side wrapper for gemm_sp_py

* [misc] typing

* [IR] bind GemmSPWarpPolicy

* [chore] add wrapper code

* [IR] fix GemmSPWarpPolicy

* [codegen] apply ptxas instructions

* [intrinsic] add typical (unused) mma layout

* [template] add uint16 debug func

* [intrinsic] add b matrix layout

* [gemm_sp] enable fp16/bf16 on sm8x

* [layout] refactor fp16/bf16 layout

* [gemm_sp] enable int8

* [chore] update test case dtype

* [gemm_sp] enable fp32

* [layout] refactor layouts

* [intrinsic] enable ldmatrix for mat A

* [layout] enable ldsm for matrix b

* [layout] add ldmatrix for fp32 and fp8

* [chore] refine

* [chore] refactor

* [chore] add fp8 efactor

* [chore] refactor

* [chore] add remove negative zero util

* [example] add a custom compress kernel

* [chore] minor update

* [test] refactor gemm_sp test

* [refactor] make metadata layout func

* [example] add option for using cutlass layout

* [doc] add a gemm_sp doc

* [doc] minor polish

* [chore] remove unused

* [bugfix] fix non replicate b case

* [test] refactor

* [chore] add a check

* [bugfix] fix util bug

* [wip] init a new test case for v2

* [chore] minor refactor

* [chore] minor update

* [bugfix] enable 16bit rs

* [language] enable rs

* [language] enable gemm_sp_sr

* [language] enable gemm_sp_rr

* [test] enable more tests

* [tvm] update ffi binding

* [chore] remove print

* [chore] fix benchmark script

* [lint] precommit lint

* [chore] apply feedback

* [test] use arch 8.0

* [chore] rollback ::ordered_metadata for backward compatibility

* [bugfix] fix captialized

* [example] keep gemm_sp on hopper

* [test] fix no fp8 normal kernel

* [test] reduce matmul size to satisfy accum error

* [test] use cal_diff for assertion

* [bugfix] expand float8 type

* [lib] add make_int4 for short type

* [language] add transpose E

* [bugfix] fix wrong var

* [format] format

* [chore] refactor binding

* [chore] fix wrong passing var
parent b10ef75f
...@@ -9,7 +9,7 @@ import tilelang.language as T ...@@ -9,7 +9,7 @@ import tilelang.language as T
from tilelang.autotuner import autotune from tilelang.autotuner import autotune
from tilelang import jit from tilelang import jit
from tilelang.contrib import nvcc from tilelang.contrib import nvcc
from tilelang.layout import make_metadata_layout from tilelang.layout import make_cutlass_metadata_layout
# Configure logger # Configure logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -86,7 +86,7 @@ def get_configs(M, N, K): ...@@ -86,7 +86,7 @@ def get_configs(M, N, K):
return configs return configs
def matmul_sp(M, N, K, accum_dtype): def matmul_sp(M, N, K, in_dtype, accum_dtype):
""" """
Create an autotuned matrix multiplication kernel for matrices of shape: Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K) - A: (M, K)
...@@ -161,14 +161,13 @@ def matmul_sp(M, N, K, accum_dtype): ...@@ -161,14 +161,13 @@ def matmul_sp(M, N, K, accum_dtype):
""" """
# Use half-precision for input data to reduce memory bandwidth, # Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy # accumulate in float for better numerical accuracy
dtype = "float16"
e_factor, e_dtype = ARCH_INFO[arch] e_factor, e_dtype = ARCH_INFO[arch]
@T.prim_func @T.prim_func
def main( def main(
A_sparse: T.Tensor((M, K // 2), dtype), A_sparse: T.Tensor((M, K // 2), in_dtype),
E: T.Tensor((M, K // e_factor), e_dtype), E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), dtype), B: T.Tensor((K, N), in_dtype),
C: T.Tensor((M, N), accum_dtype), C: T.Tensor((M, N), accum_dtype),
): ):
""" """
...@@ -187,9 +186,9 @@ def matmul_sp(M, N, K, accum_dtype): ...@@ -187,9 +186,9 @@ def matmul_sp(M, N, K, accum_dtype):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K) # Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K // 2), dtype) A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K) # Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_K, block_N), dtype) B_shared = T.alloc_shared((block_K, block_N), in_dtype)
# Allocate shared memory for E sub-block of shape (block_M, block_K // E_factor) # Allocate shared memory for E sub-block of shape (block_M, block_K // E_factor)
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
# Allocate a local fragment for intermediate accumulation # Allocate a local fragment for intermediate accumulation
...@@ -204,11 +203,9 @@ def matmul_sp(M, N, K, accum_dtype): ...@@ -204,11 +203,9 @@ def matmul_sp(M, N, K, accum_dtype):
T.use_swizzle(panel_size=10, enable=enable_rasterization) T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({ T.annotate_layout({
E: E:
make_metadata_layout( make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K),
E, mma_dtype="float16", backend="cutlass", block_k=block_K),
E_shared: E_shared:
make_metadata_layout( make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K),
E_shared, mma_dtype="float16", backend="cutlass", block_k=block_K),
}) })
# Loop over sub-blocks in K dimension, pipelined by num_stages # Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
...@@ -220,7 +217,7 @@ def matmul_sp(M, N, K, accum_dtype): ...@@ -220,7 +217,7 @@ def matmul_sp(M, N, K, accum_dtype):
T.copy(B[k * block_K, bx * block_N], B_shared) T.copy(B[k * block_K, bx * block_N], B_shared)
# Perform a partial matrix multiplication: # Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared # C_local += A_shared @ B_shared
T.gemm_sp( T.gemm_sp_v2(
A_shared, A_shared,
E_shared, E_shared,
B_shared, B_shared,
...@@ -268,7 +265,7 @@ if __name__ == "__main__": ...@@ -268,7 +265,7 @@ if __name__ == "__main__":
total_flops = 2 * M * N * K total_flops = 2 * M * N * K
# matmul(...) returns (best_latency, best_config, ref_latency) # matmul(...) returns (best_latency, best_config, ref_latency)
best_result = matmul_sp(M, N, K, args.accum_dtype) best_result = matmul_sp(M, N, K, "float16", args.accum_dtype)
best_latency = best_result.latency best_latency = best_result.latency
best_config = best_result.config best_config = best_result.config
A = torch.randn(M, K, dtype=torch.float16, device="cuda") A = torch.randn(M, K, dtype=torch.float16, device="cuda")
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
# Sparse Matrix-Matrix Multiplication with Tile Library
<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/botbw">botbw</a>
</div>
:::{warning}
This document is still **experimental** and may be incomplete.
This feature is still **experimental** and need further optimization.
Suggestions and improvements are highly encouraged—please submit a PR!
:::
:::{tip}
It's suggested to go through `docs/deeplearning_operators/matmul.md` first.
Example code can be found at `examples/gemm_sp`.
:::
## Structured sparsity in the NVIDIA Ampere architecture
Since the Ampere architecture (sm80 and above), sparsity support has been integrated into Tensor Cores. This allows a 2:4 (or 1:2 for 32-bit data types) semi-structured matrix to be compressed into its non-zero values along with associated metadata, which can then be fed into the Tensor Core. This enables up to **2x throughput** compared to the equivalent dense computation.
:::{warning}
This tutorial primarily focuses on CUDA, as this feature is not yet supported on ROCm. However, AMD provides a similar capability in the matrix cores of GPUs such as the MI300X.
:::
```{figure} ../_static/img/sparse_mma_storage_example.png
:align: center
Figure: Sparse MMA storage example (from PTX doc)
```
## Compress a dense tensor
To utilize sparse Tensor Cores, a dense tensor must first be **compressed** into its non-zero values along with the corresponding metadata.
Both `PyTorch` and `vLLM` use `CUTLASS` as their computation backend (see references [here](https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu#L47) and [here](https://github.com/vllm-project/vllm/blob/a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh#L116)), leveraging `CUTLASS`’s built-in compressor (or reimplementing it in `PyTorch`).
A set of **CUTLASS-compatible** compressors is provided in `tilelang.utils.sparse`, where a dense tensor—along with other required arguments (e.g., block_K for sm90, transpose options)—can be passed in to perform the compression.
```python
from tilelang.utils.sparse import compress
A_sparse, E = compress(A, transposed=trans_A, block_k=block_K)
```
Here, `A_sparse` contains all the non-zero elements of `A`, while `E` stores the corresponding metadata (indexing information) required to reconstruct the original sparse pattern.
> NOTE: When using CUTLASS compressor, there is no naive position correspondence between the positions in `A_sparse`/`A` and `E`. (i.e. the 4-element group at [n, k] doesn't match the 4-bit metadata at [n, k] if you consider metadata as int4 tensor)
The metadata is reordered internally to optimize memory access patterns (e.g., for ldsm instructions and vectorized loads).
For more information, see **A note on `gemm_sp` and `gemm_sp_v2`**.
## `T.gemm_sp` with CUTLASS's compressor
:::{warning}
It is strongly recommended to use T.gemm_sp_v2 due to its greater flexibility and faster compilation time.
:::
A 2:4 sparse GEMM kernel is similar to its dense counterpart, except that it also requires handling the associated metadata.
Check comments in below kernel code for required modification.
```python
def matmul_sp_sm80(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
trans_A,
trans_B,
):
is_8_bit = "8" in in_dtype
metadata_dtype = 'int32' if is_8_bit else 'int16'
E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] # Calculate shape for given datatypes
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (K, N) if not trans_B else (N, K)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)
import tilelang.language as T
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) # Allocate smem for metadata
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({ # Annotate reordered cutlass metadata layout
E:
make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared:
make_cutlass_metadata_layout(
E_shared, mma_dtype=in_dtype, arch="8.0"),
})
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
if trans_A:
T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
else:
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) # Call gemm_sp with non-zero values and metadata
T.copy(C_frag, C[by * block_M, bx * block_N])
return main
```
Under the hood, `gemm_sp` invokes templates adapted from `CUTLASS`, and a compatible metadata layout must be specified using `T.annotate_layout`.
## `T.gemm_sp_v2` with a custom compressor
To migrate to `gemm_sp_v2`, simply replace occurrences of `gemm_sp`.
Unlike `gemm_sp`, `gemm_sp_v2` can operate without `T.annotate_layout`, and it also supports user-defined layouts and compressors.
The metadata is stored in a `(u)int8`/`(u)int16`/`(u)int32` tensor, where **each 4-bit chunk represents two 2-bit indices** of non-zero elements within four consecutive elements. Here, we start with an `int16` example, which is the **default dtype** for `bf16` and `fp16` on Ampere GPUs.
Suppose we have the following row vector:
```python
t = tensor([[0, 7, 0, 3], [1, 5, 0, 0], [0, 0, 2, 4], [9, 0, 9, 0]], dtype=torch.float16).flatten()
```
The non-zero elements and their corresponding indices are:
```python
t_sp = tensor([[7, 3], [1, 5], [2, 4], [9, 9]], dtype=torch.float16).flatten()
indices = tensor([[1, 3], [0, 1], [2, 3], [0, 2]], dtype=torch.float16).flatten()
```
The corresponding uint16 metadata is:
```python
# metadata_bits = tensor([0b1101, 0b0100, 0b1110, 0b1000])
# Note: storage uses little-endian order: tensor(0b1000111001001101, dtype=torch.int16)
# Note: the above code is not runnable in python as the interpreter won't take the binary
# as 2's complement
metadata_int16 = tensor(-29107)
```
You can decode an int16 metadata tensor using the following utility:
```python
def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
assert meta.dtype is torch.int16
groups_per_meta = 16 // 4
out = []
for g in range(groups_per_meta):
group_bits = (meta >> (g * 4)) & 0xF
idx0 = group_bits & 0x3
idx1 = (group_bits >> 2) & 0x3
out.append(torch.stack([idx0, idx1], dim=-1))
return torch.concat(out, dim=-1).view(meta.shape[0], -1)
```
The compressor can be implement at either `PyTorch`/`NumPy` level or kernel level.
For example, `PyTorch` provides an Ampere compressor [here](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L47-L179). Note that in this implementation, a [permutation](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L173-L175) is applied to match CUTLASS’s metadata layout. If you do not annotate a metadata layout when using `gemm_sp_v2`, your compressor should replicate the same behavior as the PyTorch example—but without using the `_calculate_meta_reordering_scatter_offsets` function.
If you want to use a custom metadata layout in your kernel, one approach is to define the layout in `TileLang` and then apply the same layout to both your compressor kernel and the matmul_sp kernel.
```python
@tilelang.jit(out_idx=[1, 2], pass_configs={
tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True,
})
def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
e_factor, e_dtype = ARCH_INFO["8.0"]
e_K = K // e_factor
elem, group = 2, 4
assert M % block_M == 0, "M must be divisible by block_M"
assert K % block_K == 0, "K must be divisible by block_K"
assert K % e_factor == 0, "K must be divisible by e_factor"
assert block_K % e_factor == 0, "block_K must be divisible by e_factor"
@T.prim_func
def kernel(
A: T.Tensor((M, K), dtype),
A_sp: T.Tensor((M, K // 2), dtype),
E: T.Tensor((M, e_K), e_dtype),
):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype)
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
if use_cutlass_layout: # NOTE: Make sure compressor metadata layout
T.annotate_layout({ # is same with your computation kernel
E:
make_cutlass_metadata_layout(
E, mma_dtype="float16", arch="8.0", block_k=block_K),
E_shared:
make_cutlass_metadata_layout(
E_shared,
mma_dtype="float16",
arch="8.0",
block_k=block_K),
})
T.clear(A_sp_shared)
T.clear(E_shared)
non_zero_cnt = T.alloc_local((1, ), dtype="uint8")
non_zero_elt_log_idx = T.alloc_local((elem, ), dtype="uint8")
T.copy(A[bx * block_M, by * block_K], A_shared)
for tm in T.Parallel(block_M):
for g_i in range(0, block_K // group):
a_k = g_i * group
T.clear(non_zero_cnt)
T.clear(non_zero_elt_log_idx)
for i in range(group):
val = A_shared[tm, a_k + i]
if val != 0.0:
non_zero_elt_log_idx[non_zero_cnt[0]] = i
A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val
non_zero_cnt[0] += 1
if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3:
non_zero_elt_log_idx[0] = 0
non_zero_elt_log_idx[1] = 3
A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2]
A_sp_shared[tm, a_k // 2] = 0.0
elif non_zero_cnt[0] == 1:
A_sp_shared[tm, a_k // 2 + 1] = 0
non_zero_elt_log_idx[1] = 3
for i in T.serial(elem):
val = non_zero_elt_log_idx[i]
E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i)
T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2])
T.copy(E_shared, E[bx * block_M, by * block_K // e_factor])
return kernel
```
## A note on `gemm_sp` and `gemm_sp_v2`
Initially, `T.gemm_sp` followed the same design as `T.gemm`, lowering to a `CUTLASS` template. This inherently requires metadata to be reordered offline following a predetermined layout.
However, fixing a specific layout introduces several potential issues:
1. Painful debugging experience: Debugging a failed kernel becomes difficult due to the reordered indexing, including permutations and swizzling.
2. Limited flexibility: For example, concatenating two compressed tensors, such as `A_sparse_0` and `A_sparse_1`, into a new `A_sparse` makes sense. However, concatenating their metadata `E_0` and `E_1` may not be valid unless the layout allows it mathematically.
3. Alignment requirements: `CUTLASS` enforces strict alignment checks, and many hyperparameter configurations can lead to compilation errors. (For reference, sm8x was implemented in `CUTLASS 2`.)
`T.gemm_sp_v2` was designed to address these limitations, following the approach of `T.gemm_v2`. It lowers directly to PTX, removing the need for a fixed metadata layout.
\ No newline at end of file
...@@ -33,6 +33,7 @@ tutorials/auto_tuning ...@@ -33,6 +33,7 @@ tutorials/auto_tuning
deeplearning_operators/elementwise deeplearning_operators/elementwise
deeplearning_operators/gemv deeplearning_operators/gemv
deeplearning_operators/matmul deeplearning_operators/matmul
deeplearning_operators/matmul_sparse
deeplearning_operators/deepseek_mla deeplearning_operators/deepseek_mla
::: :::
......
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import argparse
import tilelang
import tilelang.language as T
from tilelang.layout import make_cutlass_metadata_layout
from tilelang.utils.sparse import randn_semi_sparse
from tilelang.utils.tensor import torch_assert_close
from triton.testing import do_bench
import torch
torch.manual_seed(42)
DEFAULT_CONFIG = { # take best config from autotune script
"4090": {
'float': {
'block_M': 128,
'block_N': 64,
'block_K': 64,
'num_stages': 1,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
},
'float16': {
'block_M': 256,
'block_N': 128,
'block_K': 64,
'num_stages': 2,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
},
"h20": {
'float': {
'block_M': 128,
'block_N': 64,
'block_K': 128,
'num_stages': 3,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
},
'float16': {
'block_M': 128,
'block_N': 64,
'block_K': 128,
'num_stages': 3,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
}
}
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16_custom_compress(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages,
thread_num, policy, enable_rasterization, use_cutlass_layout):
e_factor, e_dtype = (16, "int16")
@T.prim_func
def gemm_sp_fp16_custom_compress(
A_sparse: T.Tensor((M, K // 2), 'float16'),
E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), 'float16'),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K // 2), 'float16')
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
B_shared = T.alloc_shared((block_K, block_N), 'float16')
C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
if use_cutlass_layout:
T.annotate_layout({
E:
make_cutlass_metadata_layout(
E, mma_dtype="float16", arch="8.0", block_k=block_K),
E_shared:
make_cutlass_metadata_layout(
E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
})
T.clear(C_local)
T.disable_warp_group_reg_alloc()
T.use_swizzle(panel_size=10, enable=enable_rasterization)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return gemm_sp_fp16_custom_compress
def torch_compress(dense):
"""
A naive compression function, where each 4-bit meta matches 4 elements in original matrix in row major layout.
"""
if dense.dim() != 2:
raise RuntimeError(
f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor")
m, k = dense.shape
meta_dtype = torch.int8
if dense.dtype == torch.int8:
meta_dtype = torch.int32
elif dense.dtype in [torch.half, torch.bfloat16, torch.float]:
meta_dtype = torch.int16
else:
raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
if quadbits_per_meta_elem not in (4, 8):
raise RuntimeError("Invalid number of elements per meta element calculated")
if meta_dtype == torch.int32:
if m % 16 != 0:
raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 16")
else:
if m % 32 != 0:
raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 32")
if k % (4 * quadbits_per_meta_elem) != 0:
raise RuntimeError(
f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"
)
if dense.dtype != torch.float:
ksparse = 4
dense_4 = dense.view(-1, k // ksparse, ksparse)
m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1)
else:
ksparse = 2
dense_2 = dense.view(-1, k // ksparse, ksparse)
m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1)
meta_ncols = k // (ksparse * quadbits_per_meta_elem)
# Encoding quadruples of True/False values as follows:
# [True, True, False, False] -> 0b0100
# [True, False, True, False] -> 0b1000
# [False, True, True, False] -> 0b1001
# [True, False, False, True ] -> 0b1100
# [False, True, False, True ] -> 0b1101
# [False, False, True, True ] -> 0b1110
# Thus, lower two bits in the encoding are index of the True value
# at the lowest index in the quadruple, and the higher two bits in
# the encoding are index of the other True value in the quadruple.
# In case there are less than two True values, than False value or
# values at some index or indices are considered True for the
# encoding. In case there are more than two True values, then the
# excess True value(s) at some indices are considered False for
# the encoding. The exact encodings used for these cases are as
# follows:
# [False, False, False, False] -> 0b1110
# [False, False, False, True ] -> 0b1110
# [False, False, True, False] -> 0b1110
# [False, True, False, False] -> 0b1001
# [False, True, True, True ] -> 0b1101
# [True, False, False, False] -> 0b1000
# [True, False, True, True ] -> 0b1100
# [True, True, False, True ] -> 0b0100
# [True, True, True, False] -> 0b0100
# [True, True, True, True ] -> 0b0100
# These particular encodings are chosen, with the help of Espresso
# logic minimizer software, for the purpose of minimization of
# corresponding Boolean functions, that translate non-zero flags
# into encoding bits. Note also possible choices for the first
# and last of these encodings were limited only to (0b0100,
# 0b1110), in order to produce valid encodings for 1:2 sparsity
# case.
expr0 = m0 & m1
expr1 = ~m0 & m1
expr2 = ~m0 & ~m1
bit0 = expr1
bit1 = expr2
bit2 = expr0 | expr2 | m3
bit3 = expr1 | ~m1
idxs0 = bit0 | (bit1.to(torch.int64) << 1)
idxs1 = bit2 | (bit3.to(torch.int64) << 1)
if dense.dtype != torch.float:
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined]
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
else:
sparse = dense_2.gather(-1,
idxs0.unsqueeze(-1) // 2).view(
m, k // 2) # type: ignore[possibly-undefined]
meta_4 = idxs0 | (idxs1 << 2)
meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
if quadbits_per_meta_elem == 4:
meta = (
meta_n[:, :, 0]
| (meta_n[:, :, 1] << 4)
| (meta_n[:, :, 2] << 8)
| (meta_n[:, :, 3] << 12))
elif quadbits_per_meta_elem == 8:
meta = (
meta_n[:, :, 0]
| (meta_n[:, :, 1] << 4)
| (meta_n[:, :, 2] << 8)
| (meta_n[:, :, 3] << 12)
| (meta_n[:, :, 4] << 16)
| (meta_n[:, :, 5] << 20)
| (meta_n[:, :, 6] << 24)
| (meta_n[:, :, 7] << 28))
return (sparse, meta)
def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
assert meta.dtype is torch.int16
groups_per_meta = 16 // 4 # 4 groups per uint16
out = []
for g in range(groups_per_meta):
group_bits = (meta >> (g * 4)) & 0xF
idx0 = group_bits & 0x3
idx1 = (group_bits >> 2) & 0x3
out.append(torch.stack([idx0, idx1], dim=-1))
return torch.concat(out, dim=-1).view(meta.shape[0], -1)
@tilelang.jit(
out_idx=[1, 2], pass_configs={
tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True,
})
def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
e_factor, e_dtype = ARCH_INFO["8.0"]
e_K = K // e_factor
elem, group = 2, 4
assert M % block_M == 0, "M must be divisible by block_M"
assert K % block_K == 0, "K must be divisible by block_K"
assert K % e_factor == 0, "K must be divisible by e_factor"
assert block_K % e_factor == 0, "block_K must be divisible by e_factor"
@T.prim_func
def kernel(
A: T.Tensor((M, K), dtype),
A_sp: T.Tensor((M, K // 2), dtype),
E: T.Tensor((M, e_K), e_dtype),
):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype)
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
if use_cutlass_layout:
T.annotate_layout({
E:
make_cutlass_metadata_layout(
E, mma_dtype="float16", arch="8.0", block_k=block_K),
E_shared:
make_cutlass_metadata_layout(
E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
})
T.clear(A_sp_shared)
T.clear(E_shared)
# TODO: alloc_var seems buggy here
non_zero_cnt = T.alloc_local((1,), dtype="uint8")
non_zero_elt_log_idx = T.alloc_local((elem,), dtype="uint8")
T.copy(A[bx * block_M, by * block_K], A_shared)
for tm in T.Parallel(block_M):
for g_i in range(0, block_K // group):
a_k = g_i * group
T.clear(non_zero_cnt)
T.clear(non_zero_elt_log_idx)
for i in range(group):
val = A_shared[tm, a_k + i]
if val != 0.0:
non_zero_elt_log_idx[non_zero_cnt[0]] = i
A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val
non_zero_cnt[0] += 1
# TODO: use T.device_assert(non_zero_cnt <= 2) after rebasing main
if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3:
non_zero_elt_log_idx[0] = 0
non_zero_elt_log_idx[1] = 3
A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2]
A_sp_shared[tm, a_k // 2] = 0.0
elif non_zero_cnt[0] == 1:
A_sp_shared[tm, a_k // 2 + 1] = 0
non_zero_elt_log_idx[1] = 3
for i in T.serial(elem):
val = non_zero_elt_log_idx[i]
E_shared[tm, a_k // e_factor] |= T.shift_left(
val, 4 * (g_i % (e_factor // group)) + 2 * i)
T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2])
T.copy(E_shared, E[bx * block_M, by * block_K // e_factor])
return kernel
def main():
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument(
"--use_cutlass_layout", action='store_true', help="Use cutlass layout for E tensor")
parser.add_argument(
"--use_torch_compressor", action='store_true', help="Use torch sparse for reference")
parser.add_argument(
"--accum_dtype",
type=str,
default="float",
choices=["float", "float16"],
help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090"], default="4090")
args = parser.parse_args()
kernel = matmul_sp_fp16_custom_compress(
args.m,
args.n,
args.k,
args.accum_dtype,
**DEFAULT_CONFIG[args.cfg][args.accum_dtype],
use_cutlass_layout=args.use_cutlass_layout)
a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half)
b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)
if args.use_torch_compressor:
assert not args.use_cutlass_layout, "torch sparse must be used with naive layout"
a_sparse, e = torch_compress(a)
else:
a_sparse, e = compress_kernel(
args.m, args.k, 32, 32, "float16", use_cutlass_layout=args.use_cutlass_layout)(
a)
c = kernel(a_sparse, e, b)
ref_c = a @ b
assert not c.isnan().any(), "Reference result contains NaNs, please report an issue"
torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3)
print(
f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}"
)
latency = do_bench(lambda: kernel(a_sparse, e, b))
ref_latency = do_bench(lambda: a @ b)
total_flops = 2 * args.m * args.n * args.k
tflops = total_flops / latency / 1e9
ref_tflops = total_flops / ref_latency / 1e9
print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s")
print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s")
if __name__ == "__main__":
main()
...@@ -5,7 +5,7 @@ import argparse ...@@ -5,7 +5,7 @@ import argparse
import tilelang import tilelang
import tilelang.language as T import tilelang.language as T
from tilelang.layout import make_metadata_layout from tilelang.layout import make_cutlass_metadata_layout
from tilelang.utils.sparse import compress, randn_semi_sparse from tilelang.utils.sparse import compress, randn_semi_sparse
from tilelang.contrib import nvcc from tilelang.contrib import nvcc
from triton.testing import do_bench from triton.testing import do_bench
...@@ -14,9 +14,7 @@ import torch ...@@ -14,9 +14,7 @@ import torch
arch = nvcc.get_target_compute_version() arch = nvcc.get_target_compute_version()
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} DEFAULT_CONFIG = { # take best config from autotune script
default_config = { # take best config from autotune script
"4090": { "4090": {
'float': { 'float': {
'block_M': 128, 'block_M': 128,
...@@ -59,6 +57,8 @@ default_config = { # take best config from autotune script ...@@ -59,6 +57,8 @@ default_config = { # take best config from autotune script
} }
} }
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
@tilelang.jit(out_idx=[-1]) @tilelang.jit(out_idx=[-1])
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy,
...@@ -84,15 +84,11 @@ def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, ...@@ -84,15 +84,11 @@ def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages,
T.use_swizzle(panel_size=10, enable=enable_rasterization) T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({ T.annotate_layout({
E: E:
make_metadata_layout( make_cutlass_metadata_layout(
E, mma_dtype="float16", backend="cutlass", block_k=block_K, arch=arch), E, mma_dtype="float16", block_k=block_K, arch=arch),
E_shared: E_shared:
make_metadata_layout( make_cutlass_metadata_layout(
E_shared, E_shared, mma_dtype="float16", block_k=block_K, arch=arch),
mma_dtype="float16",
backend="cutlass",
block_k=block_K,
arch=arch),
}) })
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
...@@ -117,10 +113,10 @@ def main(): ...@@ -117,10 +113,10 @@ def main():
default="float", default="float",
choices=["float", "float16"], choices=["float", "float16"],
help="Accumulation datatype") help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090", "h20"], required=True) parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090")
args = parser.parse_args() args = parser.parse_args()
kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype,
**default_config[args.cfg][args.accum_dtype]) **DEFAULT_CONFIG[args.cfg][args.accum_dtype])
a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half) a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half)
b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half) b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)
...@@ -128,7 +124,7 @@ def main(): ...@@ -128,7 +124,7 @@ def main():
a_sparse, e = compress( a_sparse, e = compress(
a, a,
transposed=False, transposed=False,
block_k=default_config[args.cfg][args.accum_dtype]['block_K'], block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]['block_K'],
arch=arch) arch=arch)
c = kernel(a_sparse, e, b) c = kernel(a_sparse, e, b)
......
import tilelang.testing
import example_custom_compress
import example_gemm_sp
def test_example_custom_compress():
example_custom_compress.main()
def test_example_gemm_sp():
example_gemm_sp.main()
if __name__ == "__main__":
tilelang.testing.main()
import torch import torch
import tilelang import tilelang
from tilelang.utils.sparse import compress_sm90 from tilelang.utils.sparse import compress_sm90
from tilelang.layout import make_metadata_layout from tilelang.layout import make_cutlass_metadata_layout
import tilelang.testing import tilelang.testing
...@@ -40,15 +40,11 @@ def matmul_sp( ...@@ -40,15 +40,11 @@ def matmul_sp(
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({ T.annotate_layout({
E: E:
make_metadata_layout( make_cutlass_metadata_layout(
E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K), E, mma_dtype="float16", arch="9.0", block_k=block_K),
E_shared: E_shared:
make_metadata_layout( make_cutlass_metadata_layout(
E_shared, E_shared, mma_dtype="float16", arch="9.0", block_k=block_K),
mma_dtype="float16",
arch="9.0",
backend="cutlass",
block_k=block_K),
}) })
T.clear(C_local) T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
......
...@@ -307,7 +307,20 @@ TIR_REGISTER_TL_TILE_OP(GemmSP, gemm_sp) ...@@ -307,7 +307,20 @@ TIR_REGISTER_TL_TILE_OP(GemmSP, gemm_sp)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK() { GemmSPNode::RegisterReflection(); } TVM_REGISTER_OP("tl.GemmSPWarpPolicy")
.set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmSPWarpPolicy");
TVM_FFI_STATIC_INIT_BLOCK() {
GemmSPNode::RegisterReflection();
GemmSPWarpPolicyNode::RegisterReflection();
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(
"tl.GemmSPWarpPolicyComputeWarpPartition",
[](GemmSPWarpPolicy policy, int M, int N, int block_size, Target target,
bool use_wgmma, int bits) {
policy->computeWarpPartition(M, N, block_size, target, use_wgmma, bits);
return;
});
}
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -23,6 +23,14 @@ public: ...@@ -23,6 +23,14 @@ public:
int bits) const; int bits) const;
TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmSPWarpPolicy", GemmSPWarpPolicyNode, TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmSPWarpPolicy", GemmSPWarpPolicyNode,
GemmWarpPolicyNode); GemmWarpPolicyNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmSPWarpPolicyNode>()
.def_ro("policy_type", &GemmSPWarpPolicyNode::policy_type)
.def_ro("m_warp", &GemmSPWarpPolicyNode::m_warp)
.def_ro("n_warp", &GemmSPWarpPolicyNode::n_warp);
}
}; };
class GemmSPWarpPolicy : public ObjectRef { class GemmSPWarpPolicy : public ObjectRef {
......
/*!
* \file tl/op/gemm_sp_py.cc
* \brief Implementation of Sparse General Matrix Multiplication (GEMM_SP)
* operators
*/
#include "gemm_sp_py.h"
#include "utils.h"
#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>
#include "../target/utils.h"
#include "tvm/ffi/string.h"
namespace tvm {
namespace tl {
using namespace tir;
/**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
* map.
*
* This constructor deserializes operator parameters from `args` and resolves
* buffer references via `vmap`, populating an internal GemmSPPyNode with:
* - device pointers for A, E, B, C and their corresponding Buffer objects,
* - transpose flags for A and B,
* - matrix dimensions M, N, K,
* - warp allocation policy and clear_accum flag,
* - strides and memory offsets for A and B,
* - optional kPack (must be 1 or 2) and optional wg_wait.
*
* The populated GemmSPPyNode is stored into the wrapper's internal `data_`.
*
* @param args Positional serialized arguments produced by the TL frontend:
* expected layout is:
* [Aptr, Eptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
* M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
* stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
* (optional) kPack (Int), (optional) wg_wait (Int)]
* @param vmap Mapping from access pointer vars to Buffer objects used to
* resolve the Buffer corresponding to each pointer argument.
*
* @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
* fails with an ICHECK (runtime assertion). No other validation is
* performed here.
*/
GemmSPPy::GemmSPPy(Array<PrimExpr> args) {
ObjectPtr<GemmSPPyNode> node = tvm::ffi::make_object<GemmSPPyNode>();
node->aRegion_ = NormalizeToBufferRegion(args[0]);
node->eRegion_ = NormalizeToBufferRegion(args[1]);
node->bRegion_ = NormalizeToBufferRegion(args[2]);
node->cRegion_ = NormalizeToBufferRegion(args[3]);
node->A = node->aRegion_->buffer;
node->E = node->eRegion_->buffer;
node->B = node->bRegion_->buffer;
node->C = node->cRegion_->buffer;
node->trans_A = args[4].as<Bool>().value();
node->trans_B = args[5].as<Bool>().value();
node->trans_E = args[6].as<Bool>().value();
node->M = args[7].as<IntImm>().value()->value;
node->N = args[8].as<IntImm>().value()->value;
node->K = args[9].as<IntImm>().value()->value;
node->policy = GemmWarpPolicy(args[10].as<IntImm>().value()->value);
node->clear_accum = args[11].as<PrimExpr>().value();
node->stride_A = args[12].as<IntImm>().value()->value;
node->stride_B = args[13].as<IntImm>().value()->value;
node->offset_A = args[14].as<IntImm>().value()->value;
node->offset_B = args[15].as<IntImm>().value()->value;
if (args.size() > 16) {
node->kPack = args[16].as<IntImm>().value()->value;
if (node->kPack != 1 && node->kPack != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 17) {
node->wg_wait = args[17].as<IntImm>().value()->value;
}
data_ = std::move(node);
}
/**
* @brief Create a copy of this GemmSPPyNode as a TileOperator.
*
* Constructs a new GemmSPPyNode by copying the current node state and returns
* it wrapped in a GemmSPPy TileOperator.
*
* @return TileOperator A GemmSPPy operator that owns a copy of this node.
*/
TileOperator GemmSPPyNode::Clone() const {
auto op = tvm::ffi::make_object<GemmSPPyNode>(*this);
return GemmSPPy(op);
}
GemmInst GemmSPPyNode::GetGemmInst(int block_size, Target target) const {
int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
(num_warps % 4 == 0) && CheckWGMMA();
if (allow_wgmma) {
return GemmInst::kWGMMA;
} else if (TargetIsCDNA(target)) {
return GemmInst::kMFMA;
} else if (TargetIsCuda(target)) {
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
}
}
/**
* @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
*
* Evaluates device-memory placement, data-type combinations, transpose flags,
* and K divisibility constraints required for the Hopper WGMMA code path.
*
* The check returns true only when:
* - B resides in shared memory ("shared" or "shared.dyn"); and
* - (C, A, B) dtypes match one of the supported combinations below and K
* satisfies the required alignment; and
* - for combinations that require specific orientations, A is not transposed
* and B is transposed.
*
* Supported combinations and constraints:
* - C=float16:
* - A=float16, B=float16: K % 16 == 0
* - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K %
* 32 == 0
* - C=float32:
* - A=float16, B=float16: K % 16 == 0
* - A=bfloat16, B=bfloat16: K % 16 == 0
* - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
* - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
* - C=int32:
* - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
* and K % 32 == 0
*
* @return true if WGMMA is supported for the current buffers, dtypes, and
* transpose/shape constraints; false otherwise.
*/
bool GemmSPPyNode::CheckWGMMA() const {
return false; // not supported yet
// if (B.scope() != "shared.dyn" && B.scope() != "shared") {
// return false;
// }
// if (C->dtype == DataType::Float(16)) {
// if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
// return K % 16 == 0;
// else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
// return (!trans_A) && trans_B && K % 32 == 0;
// else
// return false;
// } else if (C->dtype == DataType::Float(32)) {
// if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
// return K % 16 == 0;
// else if (A->dtype == DataType::BFloat(16) &&
// B->dtype == DataType::BFloat(16))
// return K % 16 == 0;
// else if (A->dtype == DataType::Float(32) && B->dtype ==
// DataType::Float(32))
// return (!trans_A) && trans_B && K % 8 == 0;
// else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
// return (!trans_A) && trans_B && K % 32 == 0;
// else
// return false;
// } else if (C->dtype == DataType::Int(32)) {
// if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
// return (!trans_A) && trans_B && K % 32 == 0;
// else
// return false;
// } else {
// return false;
// }
}
/**
* @brief Parse and return the numeric GPU architecture from a Target's "arch"
* attribute.
*
* Examines the target's "arch" string and, if it matches the pattern
* "sm_<num>", returns <num> as an int. If the attribute is present but does not
* match that pattern, returns 0.
*
* Preconditions: the target must have an "arch" attribute (this is checked via
* ICHECK).
*
* @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
* the arch string does not match "sm_<num>".
*/
static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
ICHECK(s.has_value());
std::string arch = s.value();
if (arch.rfind("sm_", 0) == 0) {
arch_int = std::stoi(arch.substr(3));
} else {
arch_int = 0;
}
return arch_int;
}
Stmt GemmSPPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto block_size = *as_const_int(T.thread_bounds->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
auto [warp_m, warp_n] =
policy->computeWarpPartition(M, N, block_size, T.target, gemm_inst);
if (const auto f = ffi::Function::GetGlobal("tl.gemm_sp_py.lower")) {
auto prim_func =
Downcast<PrimFunc>((*f)(tvm::ffi::GetRef<GemmSPPy>(this), T.target,
T.thread_bounds, T.thread_var));
ICHECK(prim_func->attrs.defined());
auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol");
ICHECK(global_symbol.has_value());
if (prim_func->body.as<BlockRealizeNode>()) {
BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
auto block = block_realize->block;
{
BlockNode *n = block.CopyOnWrite();
n->name_hint = global_symbol.value();
}
return BlockRealize(block_realize->iter_values, block_realize->predicate,
block);
}
// warp with block realize node
return BlockRealize(
/*iter_values=*/Array<PrimExpr>(),
/*predicate=*/const_true(),
/*block=*/
Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
/*name_hint=*/global_symbol.value(), prim_func->body));
} else {
LOG(FATAL) << "No lower function found for gemm_sp_py";
}
}
LayoutMap GemmSPPyNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (completed_)
return {};
LayoutMap results;
if (const auto f = ffi::Function::GetGlobal("tl.gemm_sp_py.infer_layout")) {
results = Downcast<LayoutMap>(
(*f)(tvm::ffi::GetRef<GemmSPPy>(this), T.target, T.thread_bounds));
} else {
LOG(FATAL) << "No infer layout function found for gemm_sp_py";
}
completed_ = true;
return results;
}
TIR_REGISTER_TL_OP(GemmSPPy, gemm_sp_py)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK() { GemmSPPyNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
/*!
* \file tl/op/gemm_sp_py.h
* \brief Define gemm_sp_py operator.
*
*/
// TODO: @botbw: remove redundant code with gemm_py.h
#ifndef TVM_TL_OP_GEMM_SP_PY_H_
#define TVM_TL_OP_GEMM_SP_PY_H_
#include "gemm_sp.h"
#include "operator.h"
namespace tvm {
namespace tl {
using namespace tir;
class GemmSPPyNode : public TileOperatorNode {
public:
bool CheckWGMMA() const;
tir::Buffer A, E, B, C;
// pointer to the A, E, B, C
BufferRegion aRegion_, eRegion_, bRegion_, cRegion_;
bool trans_A, trans_B, trans_E;
int M, N, K;
int stride_A, stride_B;
int offset_A, offset_B;
PrimExpr clear_accum = const_false();
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
int kPack = 1;
int wg_wait = 0;
// use GemmWarp Policy here as the atom size are flexible in v2
mutable GemmWarpPolicy policy;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSPPy", GemmSPPyNode,
TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmSPPyNode>()
.def_ro("A", &GemmSPPyNode::A)
.def_ro("E", &GemmSPPyNode::E)
.def_ro("B", &GemmSPPyNode::B)
.def_ro("C", &GemmSPPyNode::C)
.def_ro("aRegion", &GemmSPPyNode::aRegion_)
.def_ro("eRegion", &GemmSPPyNode::eRegion_)
.def_ro("bRegion", &GemmSPPyNode::bRegion_)
.def_ro("cRegion", &GemmSPPyNode::cRegion_)
.def_ro("trans_A", &GemmSPPyNode::trans_A)
.def_ro("trans_B", &GemmSPPyNode::trans_B)
.def_ro("trans_E", &GemmSPPyNode::trans_E)
.def_ro("M", &GemmSPPyNode::M)
.def_ro("N", &GemmSPPyNode::N)
.def_ro("K", &GemmSPPyNode::K)
.def_ro("stride_A", &GemmSPPyNode::stride_A)
.def_ro("stride_B", &GemmSPPyNode::stride_B)
.def_ro("offset_A", &GemmSPPyNode::offset_A)
.def_ro("offset_B", &GemmSPPyNode::offset_B)
.def_ro("clear_accum", &GemmSPPyNode::clear_accum)
.def_ro("kPack", &GemmSPPyNode::kPack)
.def_ro("wg_wait", &GemmSPPyNode::wg_wait)
.def_ro("policy", &GemmSPPyNode::policy);
}
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
TileOperator Clone() const;
private:
// Target GEMM instruction
GemmInst GetGemmInst(int block_size, Target target) const;
mutable bool completed_ = false;
};
class GemmSPPy : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSPPy, TileOperator,
GemmSPPyNode);
TVM_DLL GemmSPPy(Array<PrimExpr> args);
static const Op &Get();
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_GEMM_SP_PY_H_
\ No newline at end of file
...@@ -127,6 +127,16 @@ TL_DEVICE int4_t make_int4(signed char x0, signed char x1, signed char x2, ...@@ -127,6 +127,16 @@ TL_DEVICE int4_t make_int4(signed char x0, signed char x1, signed char x2,
return result; return result;
} }
TL_DEVICE int4_t make_int4(short x0, short x1, short y0, short y1, short z0,
short z1, short w0, short w1) {
int4_t result;
*((short2 *)&result.x) = make_short2(x0, x1);
*((short2 *)&result.y) = make_short2(y0, y1);
*((short2 *)&result.z) = make_short2(z0, z1);
*((short2 *)&result.w) = make_short2(w0, w1);
return result;
}
// Pack eight int values. // Pack eight int values.
TL_DEVICE longlong4 make_longlong4(int x0, int x1, int y0, int y1, int z0, TL_DEVICE longlong4 make_longlong4(int x0, int x1, int y0, int y1, int z0,
int z1, int w0, int w1) { int z1, int w0, int w1) {
......
...@@ -108,6 +108,16 @@ __device__ void debug_print_buffer_value(const char *msg, const char *buf_name, ...@@ -108,6 +108,16 @@ __device__ void debug_print_buffer_value(const char *msg, const char *buf_name,
PrintTraits<T>::print_buffer(msg, buf_name, index, var); PrintTraits<T>::print_buffer(msg, buf_name, index, var);
} }
template <>
__device__ void debug_print_buffer_value<uint16_t>(const char *msg,
const char *buf_name,
int index, uint16_t var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=uint16_t value=%u\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, (uint32_t)var);
}
TL_DEVICE void device_assert(bool cond) { assert(cond); } TL_DEVICE void device_assert(bool cond) { assert(cond); }
TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) { TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) {
......
...@@ -2,28 +2,46 @@ import torch ...@@ -2,28 +2,46 @@ import torch
import tilelang import tilelang
import tilelang.testing import tilelang.testing
from tilelang.utils.sparse import compress, randn_semi_sparse from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse
from tilelang.layout import make_metadata_layout from tilelang.layout import make_cutlass_metadata_layout
from tilelang.utils.tensor import torch_assert_close, map_torch_type
torch.set_printoptions(threshold=float('inf'), edgeitems=float('inf'), linewidth=10000) from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter
torch.manual_seed(42)
torch.backends.cuda.matmul.allow_tf32 = False
STR_TO_TYPE = { # torch.manual_seed(42) # only enable when debugging
'float32': torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16, def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
"float8_e4m3": torch.float8_e4m3fn, is_8bit = "8" in in_dtype
"int8": torch.int8, is_unsigned = "uint" in in_dtype
"int32": torch.int32, is_int = "int" in in_dtype
} if is_int:
if is_8bit:
SPARSITY_MAP = { low, high = (0, 4) if is_unsigned else (-2, 2)
# 'float32': (1, 2), # not supported for now else:
torch.float16: (2, 4), low, high = (0, 128) if is_unsigned else (-64, 64)
torch.bfloat16: (2, 4), A = randint_semi_sparse(
torch.float8_e4m3fn: (2, 4), M,
torch.int8: (2, 4), K,
} low=low,
high=high,
dtype=map_torch_type(in_dtype),
device='cuda',
transposed=trans_A)
B = torch.randint(
size=(N, K) if trans_B else (K, N),
low=low,
high=high,
dtype=map_torch_type(in_dtype),
device='cuda')
else:
A = randn_semi_sparse(
M, K, dtype=torch.float32, device='cuda',
transposed=trans_A).to(map_torch_type(in_dtype))
B = torch.randn(
(N, K) if trans_B else (K, N), device='cuda',
dtype=torch.float32).to(map_torch_type(in_dtype))
return A, B
def matmul_sp_sm90( def matmul_sp_sm90(
...@@ -60,21 +78,17 @@ def matmul_sp_sm90( ...@@ -60,21 +78,17 @@ def matmul_sp_sm90(
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), 'uint8') E_shared = T.alloc_shared((block_M, block_K // E_factor), 'uint8')
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({ T.annotate_layout({
E: E:
make_metadata_layout( make_cutlass_metadata_layout(
E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K), E, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
E_shared: E_shared:
make_metadata_layout( make_cutlass_metadata_layout(
E_shared, E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
mma_dtype="float16",
arch="9.0",
backend="cutlass",
block_k=block_K),
}) })
T.disable_warp_group_reg_alloc() T.disable_warp_group_reg_alloc()
T.clear(C_local) T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared) T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
if trans_A: if trans_A:
...@@ -85,8 +99,8 @@ def matmul_sp_sm90( ...@@ -85,8 +99,8 @@ def matmul_sp_sm90(
T.copy(B[bx * block_N, k * block_K], B_shared) T.copy(B[bx * block_N, k * block_K], B_shared)
else: else:
T.copy(B[k * block_K, bx * block_N], B_shared) T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_sp(A_shared, E_shared, B_shared, C_local, trans_A, trans_B) T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N]) T.copy(C_frag, C[by * block_M, bx * block_N])
return main return main
...@@ -107,7 +121,8 @@ def matmul_sp_sm80( ...@@ -107,7 +121,8 @@ def matmul_sp_sm80(
trans_B, trans_B,
): ):
is_8_bit = "8" in in_dtype is_8_bit = "8" in in_dtype
E_factor = 32 if is_8_bit else 16 metadata_dtype = 'int32' if is_8_bit else 'int16'
E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype]
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (K, N) if not trans_B else (N, K) B_shape = (K, N) if not trans_B else (N, K)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
...@@ -118,22 +133,18 @@ def matmul_sp_sm80( ...@@ -118,22 +133,18 @@ def matmul_sp_sm80(
@T.prim_func @T.prim_func
def main( def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype), A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), 'int32' if is_8_bit else 'int16'), E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
'int32' if is_8_bit else 'int16')
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({ T.annotate_layout({
E: E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
make_metadata_layout(E, mma_dtype="float16", backend="cutlass", arch="8.0"), E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
E_shared:
make_metadata_layout(
E_shared, mma_dtype="float16", backend="cutlass", arch="8.0"),
}) })
T.clear(C_frag) T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
...@@ -181,19 +192,14 @@ def run_gemm_sp( ...@@ -181,19 +192,14 @@ def run_gemm_sp(
kernel, kernel,
out_idx=[-1], out_idx=[-1],
) )
A = randn_semi_sparse(M, K, dtype=STR_TO_TYPE[in_dtype], device='cuda', transposed=trans_A) A, B = generate_dense_input(
if trans_B: M=M,
B = torch.randn((N, K), device='cuda', dtype=torch.float32) N=N,
else: K=K,
B = torch.randn((K, N), device='cuda', dtype=torch.float32) trans_A=trans_A,
trans_B=trans_B,
if "float8" in in_dtype or "int8" in in_dtype: in_dtype=in_dtype,
A = normalize(A.float()) )
B = normalize(B.float())
A = A.to(STR_TO_TYPE[in_dtype])
B = B.to(STR_TO_TYPE[in_dtype])
A_sparse, E = compress(A, transposed=trans_A, block_k=block_K) A_sparse, E = compress(A, transposed=trans_A, block_k=block_K)
C_sp = kernel(A_sparse, E, B) C_sp = kernel(A_sparse, E, B)
...@@ -206,14 +212,22 @@ def run_gemm_sp( ...@@ -206,14 +212,22 @@ def run_gemm_sp(
if "float8" in in_dtype or "int8" in in_dtype: if "float8" in in_dtype or "int8" in in_dtype:
A = A.to(torch.float32) A = A.to(torch.float32)
B = B.to(torch.float32) B = B.to(torch.float32)
return torch.matmul(A, B).to(STR_TO_TYPE[out_dtype]) return torch.matmul(A, B)
C = _matmul(A, B) C = _matmul(A, B)
if 'float8' in in_dtype: if 'float8' in in_dtype:
diff = calc_diff(C_sp, C) diff = calc_diff(C_sp, C)
assert diff < 1e-3, f"{diff=}" assert diff < 1e-3, f"{diff=}"
else: else:
torch.testing.assert_close(C_sp, C, atol=1e-3, rtol=1e-3) torch_assert_close(
C_sp.to(torch.float32),
C.to(torch.float32),
rtol=1e-3,
atol=1e-3,
base_name="tilelang_sp",
ref_name="ref_dense",
)
print("pass") print("pass")
......
...@@ -151,12 +151,43 @@ def mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id): ...@@ -151,12 +151,43 @@ def mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id):
return row, col return row, col
def mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id):
"""
groupID = %laneid >> 2
threadID_in_group = %laneid % 4
row = groupID for ai where 0 <= i < 2 || 4 <= i < 6
groupID + 8 Otherwise
col = (threadID_in_group * 2) + (i & 0x1) for ai where i < 4
(threadID_in_group * 2) + (i & 0x1) + 8 for ai where i >= 4
"""
row = (thread_id // 4) + 8 * (local_id % 4 // 2)
col = (thread_id % 4) * 2 + (local_id % 2) + 8 * (local_id // 4)
return row, col
def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id): def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id):
row = 8 * (local_id // 8) + (thread_id // 4) row = 8 * (local_id // 8) + (thread_id // 4)
col = 16 * (local_id % 8 // 4) + (thread_id % 4) * 4 + (local_id % 4) col = 16 * (local_id % 8 // 4) + (thread_id % 4) * 4 + (local_id % 4)
return row, col return row, col
def mma_load_b_32x8_to_shared_16x16_layout(thread_id, local_id):
"""
groupID = %laneid >> 2
threadID_in_group = %laneid % 4
row = (threadID_in_group * 2) + (i & 0x1) for bi where i < 2
(threadID_in_group * 2) + (i & 0x1) + 8 for bi where i >= 2
col = groupID
"""
col = (thread_id % 4) * 2 + ((local_id % 4) % 2) + ((local_id % 4) // 2) * 8
row = (thread_id // 4) + 8 * (local_id // 4)
return row, col
def shared_16x16_to_mma_32x8_smoothlayout(i, j): def shared_16x16_to_mma_32x8_smoothlayout(i, j):
return (i * 2 + j // 8, j % 8) return (i * 2 + j // 8, j % 8)
......
...@@ -22,8 +22,10 @@ from tilelang.intrinsics.mma_layout import ( ...@@ -22,8 +22,10 @@ from tilelang.intrinsics.mma_layout import (
shared_16x32_to_mma_32x16_layout_sr_b, shared_16x32_to_mma_32x16_layout_sr_b,
mma_load_a_32x4_to_shared_16x8_layout, mma_load_a_32x4_to_shared_16x8_layout,
mma_load_b_32x4_to_shared_16x8_layout, mma_load_b_32x4_to_shared_16x8_layout,
mma_load_b_32x8_to_shared_16x16_layout,
mma_load_a_32x16_to_shared_16x32_layout, mma_load_a_32x16_to_shared_16x32_layout,
mma_load_b_32x16_to_shared_16x32_layout, mma_load_b_32x16_to_shared_16x32_layout,
mma_load_a_32x8_to_shared_16x16_layout,
) )
lift = convert lift = convert
...@@ -291,6 +293,8 @@ class TensorCoreIntrinEmitter: ...@@ -291,6 +293,8 @@ class TensorCoreIntrinEmitter:
if not ldmatrix_available: if not ldmatrix_available:
if DataType(a_dtype).bits == 8: if DataType(a_dtype).bits == 8:
mma_load_layout = mma_load_a_32x16_to_shared_16x32_layout mma_load_layout = mma_load_a_32x16_to_shared_16x32_layout
elif DataType(a_dtype).bits == 16:
mma_load_layout = mma_load_a_32x8_to_shared_16x16_layout
elif DataType(a_dtype).bits == 32: elif DataType(a_dtype).bits == 32:
mma_load_layout = mma_load_a_32x4_to_shared_16x8_layout mma_load_layout = mma_load_a_32x4_to_shared_16x8_layout
else: else:
...@@ -417,6 +421,8 @@ class TensorCoreIntrinEmitter: ...@@ -417,6 +421,8 @@ class TensorCoreIntrinEmitter:
if not ldmatrix_available: if not ldmatrix_available:
if DataType(b_dtype).bits == 8: if DataType(b_dtype).bits == 8:
mma_load_layout = mma_load_b_32x16_to_shared_16x32_layout mma_load_layout = mma_load_b_32x16_to_shared_16x32_layout
elif DataType(b_dtype).bits == 16:
mma_load_layout = mma_load_b_32x8_to_shared_16x16_layout
elif DataType(b_dtype).bits == 32: elif DataType(b_dtype).bits == 32:
mma_load_layout = mma_load_b_32x4_to_shared_16x8_layout mma_load_layout = mma_load_b_32x4_to_shared_16x8_layout
else: else:
......
from tvm import DataType
from typing import Literal
from tilelang.intrinsics.mma_layout import (
mma_load_a_32x4_to_shared_16x8_layout,
mma_load_a_32x16_to_shared_16x32_layout,
mma_load_a_32x8_to_shared_16x16_layout,
shared_16x8_to_mma_32x4_layout_sr_a,
shared_16x16_to_mma_32x8_layout_sr_a,
shared_16x32_to_mma_32x16_layout_sr_a,
)
def shared_16x16_to_mma_sp_layout_sr_a(i, j):
return shared_16x8_to_mma_32x4_layout_sr_a(i, j)
def shared_16x16_to_mma_sp_layout_sr_b(i, j):
thread_id = 4 * (i % 8) + (j % 4)
return thread_id, 4 * (i // 8) + (j // 4)
def shared_16x32_to_mma_sp_layout_sr_a(i, j):
return shared_16x16_to_mma_32x8_layout_sr_a(i, j)
def shared_16x32_to_mma_sp_layout_sr_b(i, j):
thread_id = 4 * (i % 8) + (j % 8) // 2
return thread_id, 8 * (i // 8) + (j // 8) * 2 + (j % 2)
def shared_16x64_to_mma_sp_layout_sr_a(i, j):
return shared_16x32_to_mma_32x16_layout_sr_a(i, j)
def shared_16x64_to_mma_sp_layout_sr_b(i, j):
thread_id = 4 * (i % 8) + (j % 16) // 4
return thread_id, 16 * (i // 8) + (j // 16) * 4 + j % 4
def mma_sp_load_a_32x4_to_shared_16x16_layout(thread_id, local_id):
return mma_load_a_32x4_to_shared_16x8_layout(thread_id, local_id)
def mma_sp_load_a_32x8_to_shared_16x32_layout(thread_id, local_id):
return mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id)
def mma_sp_load_a_32x16_to_shared_16x64_layout(thread_id, local_id):
return mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id)
def mma_sp_load_b_32x8_to_shared_16x16_layout(thread_id, local_id):
col = 4 * (local_id % 4) + (thread_id % 4)
row = 8 * (local_id // 4) + (thread_id // 4)
return row, col
def mma_sp_load_b_32x16_to_shared_16x32_layout(thread_id, local_id):
col = (thread_id % 4) * 2 + (local_id % 2) + ((local_id % 8) // 2) * 8
row = (thread_id // 4) + 8 * (local_id // 8)
return row, col
def mma_sp_load_b_32x32_to_shared_16x64_layout(thread_id, local_id):
col = (thread_id % 4) * 4 + (local_id % 4) + 16 * ((local_id % 16) // 4)
row = (thread_id // 4) + 8 * (local_id // 16)
return row, col
def get_logical_id_32bit(thread_id: int) -> int:
return (thread_id // 4) * 2 + (thread_id % 4) % 2
def metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id: int,
local_id: int) -> tuple[int, int]:
logical_id = get_logical_id_32bit(thread_id)
row = logical_id // 4 + local_id * 8
col = logical_id % 4
return row, col
def metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id: int,
local_id: int) -> tuple[int, int]:
logical_id = get_logical_id_32bit(thread_id)
row = logical_id // 2 + local_id * 8
col = logical_id % 2
return row, col
def metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(thread_id: int,
local_id: int) -> tuple[int, int]:
return metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(
thread_id, local_id) # same mapping for 16bit and 32bit
def metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(thread_id: int,
local_id: int) -> tuple[int, int]:
return metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(
thread_id, local_id) # same mapping for 16bit and 32bit
def get_logical_id_8bit(thread_id: int) -> int:
return thread_id
def metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(thread_id: int,
local_id: int) -> tuple[int, int]:
logical_id = get_logical_id_8bit(thread_id)
row = logical_id // 2 + local_id * 8
col = (logical_id % 4) // 2 * 4 + local_id
return row, col
def metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(thread_id: int,
local_id: int) -> tuple[int, int]:
logical_id = get_logical_id_8bit(thread_id)
row = logical_id // 2 + local_id * 8
col = (logical_id % 4) // 2 * 2 + local_id
return row, col
def metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(thread_id: int,
local_id: int) -> tuple[int, int]:
# local_id is always 0
logical_id = get_logical_id_8bit(thread_id)
row = logical_id // 4 + (logical_id % 2) * 8
col = (logical_id % 4) // 2
return row, col
def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
row = (local_id // 4) * 8 + thread_id % 8
col = (thread_id // 8) * 4 + local_id % 4
return row, col
def ldmatrix_32x16_to_shared_32x16_layout(thread_id, local_id):
row = thread_id
col = local_id % 8 + 8 * (local_id // 8)
return row, col
def ldmatrix_trans_32x16_to_shared_16x32_layout(thread_id, local_id):
row = 8 * (local_id // 8) + thread_id % 8
col = (thread_id // 8) * 8 + local_id % 8
return row, col
def ldmatrix_trans_32x32_to_shared_shared_16x64_layout(thread_id, local_id):
row = (local_id // 16) * 8 + thread_id % 8
col = (thread_id // 8) * 16 + local_id % 16
return row, col
def get_ldmatrix_offset_b(
matrix: Literal["B"],
row_idx,
col_idx,
stride,
dtype: Literal["float16", "int8"] = "float16",
transposed: bool = False,
):
assert matrix == "B", "matrix should be B"
dtype_bits = DataType(dtype).bits
if dtype_bits == 32:
if transposed:
transform_func = ldmatrix_trans_32x8_to_shared_16x16_layout
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
else:
raise ValueError("ldmatrix only supports B transposed for 32-bit dtype")
elif dtype_bits == 16:
transform_func = ldmatrix_32x16_to_shared_32x16_layout
transform_func_trans = ldmatrix_trans_32x16_to_shared_16x32_layout
if transposed:
new_row_idx, new_col_idx = transform_func_trans(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
else:
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
elif dtype_bits == 8:
if transposed:
transform_func = ldmatrix_trans_32x32_to_shared_shared_16x64_layout
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
else:
raise ValueError("ldmatrix only supports B transposed for 8-bit dtype")
else:
raise ValueError(f"Unsupported dtype {dtype}")
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment