Unverified Commit 283a9a00 authored by botbw, committed by GitHub

[Language] support `T.gemm_sp_v2` on sm80 and sm89 (#1056)

* [misc] add a cpp side wrapper for gemm_sp_py

* [misc] typing

* [IR] bind GemmSPWarpPolicy

* [chore] add wrapper code

* [IR] fix GemmSPWarpPolicy

* [codegen] apply ptxas instructions

* [intrinsic] add typical (unused) mma layout

* [template] add uint16 debug func

* [intrinsic] add b matrix layout

* [gemm_sp] enable fp16/bf16 on sm8x

* [layout] refactor fp16/bf16 layout

* [gemm_sp] enable int8

* [chore] update test case dtype

* [gemm_sp] enable fp32

* [layout] refactor layouts

* [intrinsic] enable ldmatrix for mat A

* [layout] enable ldsm for matrix b

* [layout] add ldmatrix for fp32 and fp8

* [chore] refine

* [chore] refactor

* [chore] add fp8 efactor

* [chore] refactor

* [chore] add remove negative zero util

* [example] add a custom compress kernel

* [chore] minor update

* [test] refactor gemm_sp test

* [refactor] make metadata layout func

* [example] add option for using cutlass layout

* [doc] add a gemm_sp doc

* [doc] minor polish

* [chore] remove unused

* [bugfix] fix non replicate b case

* [test] refactor

* [chore] add a check

* [bugfix] fix util bug

* [wip] init a new test case for v2

* [chore] minor refactor

* [chore] minor update

* [bugfix] enable 16bit rs

* [language] enable rs

* [language] enable gemm_sp_sr

* [language] enable gemm_sp_rr

* [test] enable more tests

* [tvm] update ffi binding

* [chore] remove print

* [chore] fix benchmark script

* [lint] precommit lint

* [chore] apply feedback

* [test] use arch 8.0

* [chore] rollback ::ordered_metadata for backward compatibility

* [bugfix] fix captialized

* [example] keep gemm_sp on hopper

* [test] fix no fp8 normal kernel

* [test] reduce matmul size to satisfy accum error

* [test] use cal_diff for assertion

* [bugfix] expand float8 type

* [lib] add make_int4 for short type

* [language] add transpose E

* [bugfix] fix wrong var

* [format] format

* [chore] refactor binding

* [chore] fix wrong passing var
parent b10ef75f
......@@ -9,7 +9,7 @@ import tilelang.language as T
from tilelang.autotuner import autotune
from tilelang import jit
from tilelang.contrib import nvcc
from tilelang.layout import make_metadata_layout
from tilelang.layout import make_cutlass_metadata_layout
# Configure logger
logger = logging.getLogger(__name__)
......@@ -86,7 +86,7 @@ def get_configs(M, N, K):
return configs
def matmul_sp(M, N, K, accum_dtype):
def matmul_sp(M, N, K, in_dtype, accum_dtype):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
......@@ -161,14 +161,13 @@ def matmul_sp(M, N, K, accum_dtype):
"""
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
e_factor, e_dtype = ARCH_INFO[arch]
@T.prim_func
def main(
A_sparse: T.Tensor((M, K // 2), dtype),
A_sparse: T.Tensor((M, K // 2), in_dtype),
E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), dtype),
B: T.Tensor((K, N), in_dtype),
C: T.Tensor((M, N), accum_dtype),
):
"""
......@@ -187,9 +186,9 @@ def matmul_sp(M, N, K, accum_dtype):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K // 2), dtype)
A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_K, block_N), dtype)
B_shared = T.alloc_shared((block_K, block_N), in_dtype)
# Allocate shared memory for E sub-block of shape (block_M, block_K // E_factor)
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
# Allocate a local fragment for intermediate accumulation
......@@ -204,11 +203,9 @@ def matmul_sp(M, N, K, accum_dtype):
T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({
E:
make_metadata_layout(
E, mma_dtype="float16", backend="cutlass", block_k=block_K),
make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K),
E_shared:
make_metadata_layout(
E_shared, mma_dtype="float16", backend="cutlass", block_k=block_K),
make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K),
})
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
......@@ -220,7 +217,7 @@ def matmul_sp(M, N, K, accum_dtype):
T.copy(B[k * block_K, bx * block_N], B_shared)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared
T.gemm_sp(
T.gemm_sp_v2(
A_shared,
E_shared,
B_shared,
......@@ -268,7 +265,7 @@ if __name__ == "__main__":
total_flops = 2 * M * N * K
# matmul(...) returns (best_latency, best_config, ref_latency)
best_result = matmul_sp(M, N, K, args.accum_dtype)
best_result = matmul_sp(M, N, K, "float16", args.accum_dtype)
best_latency = best_result.latency
best_config = best_result.config
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
# Sparse Matrix-Matrix Multiplication with Tile Library
<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/botbw">botbw</a>
</div>
:::{warning}
This document and the feature it describes are still **experimental**: the document may be incomplete and the feature needs further optimization.
Suggestions and improvements are highly encouraged—please submit a PR!
:::
:::{tip}
It's suggested to go through `docs/deeplearning_operators/matmul.md` first.
Example code can be found at `examples/gemm_sp`.
:::
## Structured sparsity in the NVIDIA Ampere architecture
Since the Ampere architecture (sm80 and above), sparsity support has been integrated into Tensor Cores. This allows a 2:4 (or 1:2 for 32-bit data types) semi-structured matrix to be compressed into its non-zero values along with associated metadata, which can then be fed into the Tensor Core. This enables up to **2x throughput** compared to the equivalent dense computation.
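Concretely, the 2:4 constraint means that every group of four consecutive values along the K dimension contains at most two non-zeros, so only those two values plus their 2-bit in-group positions need to be stored. A minimal illustration in plain `PyTorch` (not library code):
```python
import torch

row = torch.tensor([0., 7., 0., 3., 1., 5., 0., 0.], dtype=torch.float16)  # 2:4 sparse along K
groups = row.view(-1, 4)                        # split into groups of four
assert ((groups != 0).sum(dim=-1) <= 2).all()   # every group keeps at most two non-zeros
```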
:::{warning}
This tutorial primarily focuses on CUDA, as this feature is not yet supported on ROCm. However, AMD provides a similar capability in the matrix cores of GPUs such as the MI300X.
:::
```{figure} ../_static/img/sparse_mma_storage_example.png
:align: center
Figure: Sparse MMA storage example (from PTX doc)
```
## Compress a dense tensor
To utilize sparse Tensor Cores, a dense tensor must first be **compressed** into its non-zero values along with the corresponding metadata.
Both `PyTorch` and `vLLM` use `CUTLASS` as their computation backend (see references [here](https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu#L47) and [here](https://github.com/vllm-project/vllm/blob/a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh#L116)), leveraging `CUTLASS`’s built-in compressor (or reimplementing it in `PyTorch`).
A set of **CUTLASS-compatible** compressors is provided in `tilelang.utils.sparse`, where a dense tensor—along with other required arguments (e.g., block_K for sm90, transpose options)—can be passed in to perform the compression.
```python
from tilelang.utils.sparse import compress
A_sparse, E = compress(A, transposed=trans_A, block_k=block_K)
```
Here, `A_sparse` contains all the non-zero elements of `A`, while `E` stores the corresponding metadata (indexing information) required to reconstruct the original sparse pattern.
> NOTE: When using the CUTLASS compressor, there is no naive positional correspondence between the positions in `A_sparse`/`A` and `E` (i.e. the 4-element group at `[n, k]` does not match the 4-bit metadata at `[n, k]` if you view the metadata as an int4 tensor).
The metadata is reordered internally to optimize memory access patterns (e.g., for ldsm instructions and vectorized loads).
For more information, see **A note on `gemm_sp` and `gemm_sp_v2`**.
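The shapes follow directly from the 2:4 scheme; below is a quick sanity check (a sketch assuming an sm80/sm89 GPU, where `float16` metadata is packed into `int16` words covering 16 elements each):
```python
import torch
from tilelang.utils.sparse import compress, randn_semi_sparse

M, K, block_K = 256, 256, 64
A = randn_semi_sparse(M, K, dtype=torch.float16, device="cuda")  # dense tensor with a 2:4 pattern
A_sparse, E = compress(A, transposed=False, block_k=block_K)

print(A_sparse.shape)    # (M, K // 2): the non-zero values only
print(E.shape, E.dtype)  # (M, K // 16), torch.int16 on sm80/sm89 (K // 8, uint8 on sm90)
```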
## `T.gemm_sp` with CUTLASS's compressor
:::{warning}
It is strongly recommended to use `T.gemm_sp_v2` due to its greater flexibility and faster compilation.
:::
A 2:4 sparse GEMM kernel is similar to its dense counterpart, except that it also requires handling the associated metadata.
See the comments in the kernel code below for the required modifications.
```python
from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter
from tilelang.layout import make_cutlass_metadata_layout


def matmul_sp_sm80(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
trans_A,
trans_B,
):
is_8_bit = "8" in in_dtype
metadata_dtype = 'int32' if is_8_bit else 'int16'
E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] # Calculate shape for given datatypes
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (K, N) if not trans_B else (N, K)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)
import tilelang.language as T
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) # Allocate smem for metadata
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({ # Annotate reordered cutlass metadata layout
E:
make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared:
make_cutlass_metadata_layout(
E_shared, mma_dtype=in_dtype, arch="8.0"),
})
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
if trans_A:
T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
else:
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) # Call gemm_sp with non-zero values and metadata
T.copy(C_frag, C[by * block_M, bx * block_N])
return main
```
Under the hood, `gemm_sp` invokes templates adapted from `CUTLASS`, and a compatible metadata layout must be specified using `T.annotate_layout`.
## `T.gemm_sp_v2` with a custom compressor
To migrate, simply replace occurrences of `gemm_sp` with `gemm_sp_v2`.
Unlike `gemm_sp`, `gemm_sp_v2` can operate without `T.annotate_layout`, and it also supports user-defined layouts and compressors.
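For instance, the pipelined loop of the `matmul_sp_sm80` kernel above becomes the following sketch (non-transposed case, reusing the same buffers; keeping the layout annotation is optional):
```python
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
    T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
    T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
    T.copy(B[k * block_K, bx * block_N], B_shared)
    # drop-in replacement for T.gemm_sp
    T.gemm_sp_v2(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
```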
The metadata is stored in a `(u)int8`/`(u)int16`/`(u)int32` tensor, where **each 4-bit chunk represents two 2-bit indices** of non-zero elements within four consecutive elements. Here, we start with an `int16` example, which is the **default dtype** for `bf16` and `fp16` on Ampere GPUs.
Suppose we have the following row vector:
```python
t = tensor([[0, 7, 0, 3], [1, 5, 0, 0], [0, 0, 2, 4], [9, 0, 9, 0]], dtype=torch.float16).flatten()
```
The non-zero elements and their corresponding indices are:
```python
t_sp = tensor([[7, 3], [1, 5], [2, 4], [9, 9]], dtype=torch.float16).flatten()
indices = tensor([[1, 3], [0, 1], [2, 3], [0, 2]], dtype=torch.float16).flatten()
```
The corresponding `int16` metadata word is:
```python
# metadata_bits = tensor([0b1101, 0b0100, 0b1110, 0b1000])  # one 4-bit group per four input elements
# Packed little-endian (group 0 in the lowest nibble): 0b1000_1110_0100_1101
# Note: a Python binary literal is always non-negative (0b1000111001001101 == 36429),
# so the signed two's-complement int16 value is written out directly:
metadata_int16 = tensor(-29107, dtype=torch.int16)
```
You can decode an int16 metadata tensor using the following utility:
```python
def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
assert meta.dtype is torch.int16
groups_per_meta = 16 // 4
out = []
for g in range(groups_per_meta):
group_bits = (meta >> (g * 4)) & 0xF
idx0 = group_bits & 0x3
idx1 = (group_bits >> 2) & 0x3
out.append(torch.stack([idx0, idx1], dim=-1))
return torch.concat(out, dim=-1).view(meta.shape[0], -1)
```
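Applied to the metadata word from the worked example above (wrapped into a 2-D tensor), the decoder recovers the original index pairs:
```python
meta = torch.tensor([[-29107]], dtype=torch.int16)   # the packed word from the example above
print(decode_metadata(meta))
# tensor([[1, 3, 0, 1, 2, 3, 0, 2]], dtype=torch.int16)
```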
The compressor can be implemented at either the `PyTorch`/`NumPy` level or the kernel level.
For example, `PyTorch` provides an Ampere compressor [here](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L47-L179). Note that this implementation applies a [permutation](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L173-L175) to match CUTLASS's metadata layout. If you do not annotate a metadata layout when using `gemm_sp_v2`, your compressor should replicate the behavior of the PyTorch example, but skip the `_calculate_meta_reordering_scatter_offsets` reordering.
If you want to use a custom metadata layout in your kernel, one approach is to define the layout in `TileLang` and then apply the same layout to both your compressor kernel and the matmul_sp kernel.
```python
import tilelang
import tilelang.language as T
from tilelang.layout import make_cutlass_metadata_layout

ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}  # (e_factor, e_dtype) per arch


@tilelang.jit(out_idx=[1, 2], pass_configs={
tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True,
})
def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
e_factor, e_dtype = ARCH_INFO["8.0"]
e_K = K // e_factor
elem, group = 2, 4
assert M % block_M == 0, "M must be divisible by block_M"
assert K % block_K == 0, "K must be divisible by block_K"
assert K % e_factor == 0, "K must be divisible by e_factor"
assert block_K % e_factor == 0, "block_K must be divisible by e_factor"
@T.prim_func
def kernel(
A: T.Tensor((M, K), dtype),
A_sp: T.Tensor((M, K // 2), dtype),
E: T.Tensor((M, e_K), e_dtype),
):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype)
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
if use_cutlass_layout: # NOTE: Make sure compressor metadata layout
T.annotate_layout({ # is same with your computation kernel
E:
make_cutlass_metadata_layout(
E, mma_dtype="float16", arch="8.0", block_k=block_K),
E_shared:
make_cutlass_metadata_layout(
E_shared,
mma_dtype="float16",
arch="8.0",
block_k=block_K),
})
T.clear(A_sp_shared)
T.clear(E_shared)
non_zero_cnt = T.alloc_local((1, ), dtype="uint8")
non_zero_elt_log_idx = T.alloc_local((elem, ), dtype="uint8")
T.copy(A[bx * block_M, by * block_K], A_shared)
for tm in T.Parallel(block_M):
for g_i in range(0, block_K // group):
a_k = g_i * group
T.clear(non_zero_cnt)
T.clear(non_zero_elt_log_idx)
for i in range(group):
val = A_shared[tm, a_k + i]
if val != 0.0:
non_zero_elt_log_idx[non_zero_cnt[0]] = i
A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val
non_zero_cnt[0] += 1
if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3:
non_zero_elt_log_idx[0] = 0
non_zero_elt_log_idx[1] = 3
A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2]
A_sp_shared[tm, a_k // 2] = 0.0
elif non_zero_cnt[0] == 1:
A_sp_shared[tm, a_k // 2 + 1] = 0
non_zero_elt_log_idx[1] = 3
for i in T.serial(elem):
val = non_zero_elt_log_idx[i]
E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i)
T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2])
T.copy(E_shared, E[bx * block_M, by * block_K // e_factor])
return kernel
```
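The compressed tensors can then be produced in the same way as with the builtin compressor. A usage sketch mirroring `examples/gemm_sp/example_custom_compress.py` (problem and block sizes are placeholders):
```python
import torch
from tilelang.utils.sparse import randn_semi_sparse

M, K = 4096, 4096                                             # placeholder problem size
A = randn_semi_sparse(M, K, device="cuda", dtype=torch.half)  # dense tensor with a 2:4 pattern
compressor = compress_kernel(M, K, 32, 32, "float16", use_cutlass_layout=False)
A_sparse, E = compressor(A)                                   # out_idx=[1, 2] returns (A_sp, E)
```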
## A note on `gemm_sp` and `gemm_sp_v2`
Initially, `T.gemm_sp` followed the same design as `T.gemm`, lowering to a `CUTLASS` template. This inherently requires metadata to be reordered offline following a predetermined layout.
However, fixing a specific layout introduces several potential issues:
1. Painful debugging experience: Debugging a failed kernel becomes difficult due to the reordered indexing, including permutations and swizzling.
2. Limited flexibility: For example, concatenating two compressed tensors, such as `A_sparse_0` and `A_sparse_1`, into a new `A_sparse` makes sense. However, concatenating their metadata `E_0` and `E_1` may not be valid unless the layout allows it mathematically.
3. Alignment requirements: `CUTLASS` enforces strict alignment checks, and many hyperparameter configurations can lead to compilation errors. (For reference, the sm8x path is built on `CUTLASS 2`.)
`T.gemm_sp_v2` was designed to address these limitations, following the approach of `T.gemm_v2`. It lowers directly to PTX, removing the need for a fixed metadata layout.
\ No newline at end of file
......@@ -33,6 +33,7 @@ tutorials/auto_tuning
deeplearning_operators/elementwise
deeplearning_operators/gemv
deeplearning_operators/matmul
deeplearning_operators/matmul_sparse
deeplearning_operators/deepseek_mla
:::
......
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import argparse
import tilelang
import tilelang.language as T
from tilelang.layout import make_cutlass_metadata_layout
from tilelang.utils.sparse import randn_semi_sparse
from tilelang.utils.tensor import torch_assert_close
from triton.testing import do_bench
import torch
torch.manual_seed(42)
DEFAULT_CONFIG = { # take best config from autotune script
"4090": {
'float': {
'block_M': 128,
'block_N': 64,
'block_K': 64,
'num_stages': 1,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
},
'float16': {
'block_M': 256,
'block_N': 128,
'block_K': 64,
'num_stages': 2,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
},
"h20": {
'float': {
'block_M': 128,
'block_N': 64,
'block_K': 128,
'num_stages': 3,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
},
'float16': {
'block_M': 128,
'block_N': 64,
'block_K': 128,
'num_stages': 3,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
}
}
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16_custom_compress(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages,
thread_num, policy, enable_rasterization, use_cutlass_layout):
e_factor, e_dtype = (16, "int16")
@T.prim_func
def gemm_sp_fp16_custom_compress(
A_sparse: T.Tensor((M, K // 2), 'float16'),
E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), 'float16'),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K // 2), 'float16')
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
B_shared = T.alloc_shared((block_K, block_N), 'float16')
C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
if use_cutlass_layout:
T.annotate_layout({
E:
make_cutlass_metadata_layout(
E, mma_dtype="float16", arch="8.0", block_k=block_K),
E_shared:
make_cutlass_metadata_layout(
E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
})
T.clear(C_local)
T.disable_warp_group_reg_alloc()
T.use_swizzle(panel_size=10, enable=enable_rasterization)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return gemm_sp_fp16_custom_compress
def torch_compress(dense):
"""
A naive compression function where each 4-bit metadata group corresponds to 4 consecutive elements of the original matrix in row-major layout.
"""
if dense.dim() != 2:
raise RuntimeError(
f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor")
m, k = dense.shape
meta_dtype = torch.int8
if dense.dtype == torch.int8:
meta_dtype = torch.int32
elif dense.dtype in [torch.half, torch.bfloat16, torch.float]:
meta_dtype = torch.int16
else:
raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
if quadbits_per_meta_elem not in (4, 8):
raise RuntimeError("Invalid number of elements per meta element calculated")
if meta_dtype == torch.int32:
if m % 16 != 0:
raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 16")
else:
if m % 32 != 0:
raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 32")
if k % (4 * quadbits_per_meta_elem) != 0:
raise RuntimeError(
f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"
)
if dense.dtype != torch.float:
ksparse = 4
dense_4 = dense.view(-1, k // ksparse, ksparse)
m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1)
else:
ksparse = 2
dense_2 = dense.view(-1, k // ksparse, ksparse)
m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1)
meta_ncols = k // (ksparse * quadbits_per_meta_elem)
# Encoding quadruples of True/False values as follows:
# [True, True, False, False] -> 0b0100
# [True, False, True, False] -> 0b1000
# [False, True, True, False] -> 0b1001
# [True, False, False, True ] -> 0b1100
# [False, True, False, True ] -> 0b1101
# [False, False, True, True ] -> 0b1110
# Thus, lower two bits in the encoding are index of the True value
# at the lowest index in the quadruple, and the higher two bits in
# the encoding are index of the other True value in the quadruple.
# In case there are fewer than two True values, then False value or
# values at some index or indices are considered True for the
# encoding. In case there are more than two True values, then the
# excess True value(s) at some indices are considered False for
# the encoding. The exact encodings used for these cases are as
# follows:
# [False, False, False, False] -> 0b1110
# [False, False, False, True ] -> 0b1110
# [False, False, True, False] -> 0b1110
# [False, True, False, False] -> 0b1001
# [False, True, True, True ] -> 0b1101
# [True, False, False, False] -> 0b1000
# [True, False, True, True ] -> 0b1100
# [True, True, False, True ] -> 0b0100
# [True, True, True, False] -> 0b0100
# [True, True, True, True ] -> 0b0100
# These particular encodings are chosen, with the help of Espresso
# logic minimizer software, for the purpose of minimization of
# corresponding Boolean functions, that translate non-zero flags
# into encoding bits. Note also possible choices for the first
# and last of these encodings were limited only to (0b0100,
# 0b1110), in order to produce valid encodings for 1:2 sparsity
# case.
expr0 = m0 & m1
expr1 = ~m0 & m1
expr2 = ~m0 & ~m1
bit0 = expr1
bit1 = expr2
bit2 = expr0 | expr2 | m3
bit3 = expr1 | ~m1
idxs0 = bit0 | (bit1.to(torch.int64) << 1)
idxs1 = bit2 | (bit3.to(torch.int64) << 1)
if dense.dtype != torch.float:
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined]
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
else:
sparse = dense_2.gather(-1,
idxs0.unsqueeze(-1) // 2).view(
m, k // 2) # type: ignore[possibly-undefined]
meta_4 = idxs0 | (idxs1 << 2)
meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
if quadbits_per_meta_elem == 4:
meta = (
meta_n[:, :, 0]
| (meta_n[:, :, 1] << 4)
| (meta_n[:, :, 2] << 8)
| (meta_n[:, :, 3] << 12))
elif quadbits_per_meta_elem == 8:
meta = (
meta_n[:, :, 0]
| (meta_n[:, :, 1] << 4)
| (meta_n[:, :, 2] << 8)
| (meta_n[:, :, 3] << 12)
| (meta_n[:, :, 4] << 16)
| (meta_n[:, :, 5] << 20)
| (meta_n[:, :, 6] << 24)
| (meta_n[:, :, 7] << 28))
return (sparse, meta)
def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
assert meta.dtype is torch.int16
groups_per_meta = 16 // 4 # 4 groups per int16 word
out = []
for g in range(groups_per_meta):
group_bits = (meta >> (g * 4)) & 0xF
idx0 = group_bits & 0x3
idx1 = (group_bits >> 2) & 0x3
out.append(torch.stack([idx0, idx1], dim=-1))
return torch.concat(out, dim=-1).view(meta.shape[0], -1)
@tilelang.jit(
out_idx=[1, 2], pass_configs={
tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True,
})
def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
e_factor, e_dtype = ARCH_INFO["8.0"]
e_K = K // e_factor
elem, group = 2, 4
assert M % block_M == 0, "M must be divisible by block_M"
assert K % block_K == 0, "K must be divisible by block_K"
assert K % e_factor == 0, "K must be divisible by e_factor"
assert block_K % e_factor == 0, "block_K must be divisible by e_factor"
@T.prim_func
def kernel(
A: T.Tensor((M, K), dtype),
A_sp: T.Tensor((M, K // 2), dtype),
E: T.Tensor((M, e_K), e_dtype),
):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype)
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
if use_cutlass_layout:
T.annotate_layout({
E:
make_cutlass_metadata_layout(
E, mma_dtype="float16", arch="8.0", block_k=block_K),
E_shared:
make_cutlass_metadata_layout(
E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
})
T.clear(A_sp_shared)
T.clear(E_shared)
# TODO: alloc_var seems buggy here
non_zero_cnt = T.alloc_local((1,), dtype="uint8")
non_zero_elt_log_idx = T.alloc_local((elem,), dtype="uint8")
T.copy(A[bx * block_M, by * block_K], A_shared)
for tm in T.Parallel(block_M):
for g_i in range(0, block_K // group):
a_k = g_i * group
T.clear(non_zero_cnt)
T.clear(non_zero_elt_log_idx)
for i in range(group):
val = A_shared[tm, a_k + i]
if val != 0.0:
non_zero_elt_log_idx[non_zero_cnt[0]] = i
A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val
non_zero_cnt[0] += 1
# TODO: use T.device_assert(non_zero_cnt <= 2) after rebasing main
if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3:
non_zero_elt_log_idx[0] = 0
non_zero_elt_log_idx[1] = 3
A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2]
A_sp_shared[tm, a_k // 2] = 0.0
elif non_zero_cnt[0] == 1:
A_sp_shared[tm, a_k // 2 + 1] = 0
non_zero_elt_log_idx[1] = 3
for i in T.serial(elem):
val = non_zero_elt_log_idx[i]
E_shared[tm, a_k // e_factor] |= T.shift_left(
val, 4 * (g_i % (e_factor // group)) + 2 * i)
T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2])
T.copy(E_shared, E[bx * block_M, by * block_K // e_factor])
return kernel
def main():
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument(
"--use_cutlass_layout", action='store_true', help="Use cutlass layout for E tensor")
parser.add_argument(
"--use_torch_compressor", action='store_true', help="Use torch sparse for reference")
parser.add_argument(
"--accum_dtype",
type=str,
default="float",
choices=["float", "float16"],
help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090"], default="4090")
args = parser.parse_args()
kernel = matmul_sp_fp16_custom_compress(
args.m,
args.n,
args.k,
args.accum_dtype,
**DEFAULT_CONFIG[args.cfg][args.accum_dtype],
use_cutlass_layout=args.use_cutlass_layout)
a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half)
b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)
if args.use_torch_compressor:
assert not args.use_cutlass_layout, "torch sparse must be used with naive layout"
a_sparse, e = torch_compress(a)
else:
a_sparse, e = compress_kernel(
args.m, args.k, 32, 32, "float16", use_cutlass_layout=args.use_cutlass_layout)(
a)
c = kernel(a_sparse, e, b)
ref_c = a @ b
assert not c.isnan().any(), "Result contains NaNs, please report an issue"
torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3)
print(
f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}"
)
latency = do_bench(lambda: kernel(a_sparse, e, b))
ref_latency = do_bench(lambda: a @ b)
total_flops = 2 * args.m * args.n * args.k
tflops = total_flops / latency / 1e9
ref_tflops = total_flops / ref_latency / 1e9
print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s")
print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s")
if __name__ == "__main__":
main()
......@@ -5,7 +5,7 @@ import argparse
import tilelang
import tilelang.language as T
from tilelang.layout import make_metadata_layout
from tilelang.layout import make_cutlass_metadata_layout
from tilelang.utils.sparse import compress, randn_semi_sparse
from tilelang.contrib import nvcc
from triton.testing import do_bench
......@@ -14,9 +14,7 @@ import torch
arch = nvcc.get_target_compute_version()
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
default_config = { # take best config from autotune script
DEFAULT_CONFIG = { # take best config from autotune script
"4090": {
'float': {
'block_M': 128,
......@@ -59,6 +57,8 @@ default_config = { # take best config from autotune script
}
}
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy,
......@@ -84,15 +84,11 @@ def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages,
T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({
E:
make_metadata_layout(
E, mma_dtype="float16", backend="cutlass", block_k=block_K, arch=arch),
make_cutlass_metadata_layout(
E, mma_dtype="float16", block_k=block_K, arch=arch),
E_shared:
make_metadata_layout(
E_shared,
mma_dtype="float16",
backend="cutlass",
block_k=block_K,
arch=arch),
make_cutlass_metadata_layout(
E_shared, mma_dtype="float16", block_k=block_K, arch=arch),
})
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
......@@ -117,10 +113,10 @@ def main():
default="float",
choices=["float", "float16"],
help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090", "h20"], required=True)
parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090")
args = parser.parse_args()
kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype,
**default_config[args.cfg][args.accum_dtype])
**DEFAULT_CONFIG[args.cfg][args.accum_dtype])
a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half)
b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)
......@@ -128,7 +124,7 @@ def main():
a_sparse, e = compress(
a,
transposed=False,
block_k=default_config[args.cfg][args.accum_dtype]['block_K'],
block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]['block_K'],
arch=arch)
c = kernel(a_sparse, e, b)
......
import tilelang.testing
import example_custom_compress
import example_gemm_sp
def test_example_custom_compress():
example_custom_compress.main()
def test_example_gemm_sp():
example_gemm_sp.main()
if __name__ == "__main__":
tilelang.testing.main()
import torch
import tilelang
from tilelang.utils.sparse import compress_sm90
from tilelang.layout import make_metadata_layout
from tilelang.layout import make_cutlass_metadata_layout
import tilelang.testing
......@@ -40,15 +40,11 @@ def matmul_sp(
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
E:
make_metadata_layout(
E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K),
make_cutlass_metadata_layout(
E, mma_dtype="float16", arch="9.0", block_k=block_K),
E_shared:
make_metadata_layout(
E_shared,
mma_dtype="float16",
arch="9.0",
backend="cutlass",
block_k=block_K),
make_cutlass_metadata_layout(
E_shared, mma_dtype="float16", arch="9.0", block_k=block_K),
})
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
......
......@@ -307,7 +307,20 @@ TIR_REGISTER_TL_TILE_OP(GemmSP, gemm_sp)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK() { GemmSPNode::RegisterReflection(); }
TVM_REGISTER_OP("tl.GemmSPWarpPolicy")
.set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmSPWarpPolicy");
TVM_FFI_STATIC_INIT_BLOCK() {
GemmSPNode::RegisterReflection();
GemmSPWarpPolicyNode::RegisterReflection();
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(
"tl.GemmSPWarpPolicyComputeWarpPartition",
[](GemmSPWarpPolicy policy, int M, int N, int block_size, Target target,
bool use_wgmma, int bits) {
policy->computeWarpPartition(M, N, block_size, target, use_wgmma, bits);
return;
});
}
} // namespace tl
} // namespace tvm
......@@ -23,6 +23,14 @@ public:
int bits) const;
TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmSPWarpPolicy", GemmSPWarpPolicyNode,
GemmWarpPolicyNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmSPWarpPolicyNode>()
.def_ro("policy_type", &GemmSPWarpPolicyNode::policy_type)
.def_ro("m_warp", &GemmSPWarpPolicyNode::m_warp)
.def_ro("n_warp", &GemmSPWarpPolicyNode::n_warp);
}
};
class GemmSPWarpPolicy : public ObjectRef {
......
/*!
* \file tl/op/gemm_sp_py.cc
* \brief Implementation of Sparse General Matrix Multiplication (GEMM_SP)
* operators
*/
#include "gemm_sp_py.h"
#include "utils.h"
#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>
#include "../target/utils.h"
#include "tvm/ffi/string.h"
namespace tvm {
namespace tl {
using namespace tir;
/**
 * @brief Construct a GemmSPPy operator from serialized TL arguments.
 *
 * This constructor deserializes operator parameters from `args`, populating an
 * internal GemmSPPyNode with:
 * - buffer regions (and their Buffers) for A, E, B, C,
 * - transpose flags for A, B, and E,
 * - matrix dimensions M, N, K,
 * - warp allocation policy and clear_accum flag,
 * - strides and memory offsets for A and B,
 * - optional kPack (must be 1 or 2) and optional wg_wait.
 *
 * The populated GemmSPPyNode is stored into the wrapper's internal `data_`.
 *
 * @param args Positional serialized arguments produced by the TL frontend;
 *        the expected layout is:
 *        [A, E, B, C, trans_A (Bool), trans_B (Bool), trans_E (Bool),
 *         M (Int), N (Int), K (Int), policy (Int), clear_accum,
 *         stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
 *         (optional) kPack (Int), (optional) wg_wait (Int)]
 *
 * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
 *       fails with an ICHECK (runtime assertion). No other validation is
 *       performed here.
 */
GemmSPPy::GemmSPPy(Array<PrimExpr> args) {
ObjectPtr<GemmSPPyNode> node = tvm::ffi::make_object<GemmSPPyNode>();
node->aRegion_ = NormalizeToBufferRegion(args[0]);
node->eRegion_ = NormalizeToBufferRegion(args[1]);
node->bRegion_ = NormalizeToBufferRegion(args[2]);
node->cRegion_ = NormalizeToBufferRegion(args[3]);
node->A = node->aRegion_->buffer;
node->E = node->eRegion_->buffer;
node->B = node->bRegion_->buffer;
node->C = node->cRegion_->buffer;
node->trans_A = args[4].as<Bool>().value();
node->trans_B = args[5].as<Bool>().value();
node->trans_E = args[6].as<Bool>().value();
node->M = args[7].as<IntImm>().value()->value;
node->N = args[8].as<IntImm>().value()->value;
node->K = args[9].as<IntImm>().value()->value;
node->policy = GemmWarpPolicy(args[10].as<IntImm>().value()->value);
node->clear_accum = args[11].as<PrimExpr>().value();
node->stride_A = args[12].as<IntImm>().value()->value;
node->stride_B = args[13].as<IntImm>().value()->value;
node->offset_A = args[14].as<IntImm>().value()->value;
node->offset_B = args[15].as<IntImm>().value()->value;
if (args.size() > 16) {
node->kPack = args[16].as<IntImm>().value()->value;
if (node->kPack != 1 && node->kPack != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 17) {
node->wg_wait = args[17].as<IntImm>().value()->value;
}
data_ = std::move(node);
}
/**
* @brief Create a copy of this GemmSPPyNode as a TileOperator.
*
* Constructs a new GemmSPPyNode by copying the current node state and returns
* it wrapped in a GemmSPPy TileOperator.
*
* @return TileOperator A GemmSPPy operator that owns a copy of this node.
*/
TileOperator GemmSPPyNode::Clone() const {
auto op = tvm::ffi::make_object<GemmSPPyNode>(*this);
return GemmSPPy(op);
}
GemmInst GemmSPPyNode::GetGemmInst(int block_size, Target target) const {
int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
(num_warps % 4 == 0) && CheckWGMMA();
if (allow_wgmma) {
return GemmInst::kWGMMA;
} else if (TargetIsCDNA(target)) {
return GemmInst::kMFMA;
} else if (TargetIsCuda(target)) {
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
}
}
/**
* @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
*
 * Currently always returns false: WGMMA is not yet supported for gemm_sp_py.
 * The commented-out body is kept for reference; it evaluates device-memory
 * placement, data-type combinations, transpose flags, and K divisibility
 * constraints required for the Hopper WGMMA code path.
 *
 * The full check would return true only when:
* - B resides in shared memory ("shared" or "shared.dyn"); and
* - (C, A, B) dtypes match one of the supported combinations below and K
* satisfies the required alignment; and
* - for combinations that require specific orientations, A is not transposed
* and B is transposed.
*
* Supported combinations and constraints:
* - C=float16:
* - A=float16, B=float16: K % 16 == 0
* - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K %
* 32 == 0
* - C=float32:
* - A=float16, B=float16: K % 16 == 0
* - A=bfloat16, B=bfloat16: K % 16 == 0
* - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
* - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
* - C=int32:
* - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
* and K % 32 == 0
*
* @return true if WGMMA is supported for the current buffers, dtypes, and
* transpose/shape constraints; false otherwise.
*/
bool GemmSPPyNode::CheckWGMMA() const {
return false; // not supported yet
// if (B.scope() != "shared.dyn" && B.scope() != "shared") {
// return false;
// }
// if (C->dtype == DataType::Float(16)) {
// if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
// return K % 16 == 0;
// else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
// return (!trans_A) && trans_B && K % 32 == 0;
// else
// return false;
// } else if (C->dtype == DataType::Float(32)) {
// if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
// return K % 16 == 0;
// else if (A->dtype == DataType::BFloat(16) &&
// B->dtype == DataType::BFloat(16))
// return K % 16 == 0;
// else if (A->dtype == DataType::Float(32) && B->dtype ==
// DataType::Float(32))
// return (!trans_A) && trans_B && K % 8 == 0;
// else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
// return (!trans_A) && trans_B && K % 32 == 0;
// else
// return false;
// } else if (C->dtype == DataType::Int(32)) {
// if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
// return (!trans_A) && trans_B && K % 32 == 0;
// else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
// return (!trans_A) && trans_B && K % 32 == 0;
// else
// return false;
// } else {
// return false;
// }
}
/**
* @brief Parse and return the numeric GPU architecture from a Target's "arch"
* attribute.
*
* Examines the target's "arch" string and, if it matches the pattern
* "sm_<num>", returns <num> as an int. If the attribute is present but does not
* match that pattern, returns 0.
*
* Preconditions: the target must have an "arch" attribute (this is checked via
* ICHECK).
*
* @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
* the arch string does not match "sm_<num>".
*/
static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
ICHECK(s.has_value());
std::string arch = s.value();
if (arch.rfind("sm_", 0) == 0) {
arch_int = std::stoi(arch.substr(3));
} else {
arch_int = 0;
}
return arch_int;
}
Stmt GemmSPPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto block_size = *as_const_int(T.thread_bounds->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
auto [warp_m, warp_n] =
policy->computeWarpPartition(M, N, block_size, T.target, gemm_inst);
if (const auto f = ffi::Function::GetGlobal("tl.gemm_sp_py.lower")) {
auto prim_func =
Downcast<PrimFunc>((*f)(tvm::ffi::GetRef<GemmSPPy>(this), T.target,
T.thread_bounds, T.thread_var));
ICHECK(prim_func->attrs.defined());
auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol");
ICHECK(global_symbol.has_value());
if (prim_func->body.as<BlockRealizeNode>()) {
BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
auto block = block_realize->block;
{
BlockNode *n = block.CopyOnWrite();
n->name_hint = global_symbol.value();
}
return BlockRealize(block_realize->iter_values, block_realize->predicate,
block);
}
// wrap the body in a fresh BlockRealize node
return BlockRealize(
/*iter_values=*/Array<PrimExpr>(),
/*predicate=*/const_true(),
/*block=*/
Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
/*name_hint=*/global_symbol.value(), prim_func->body));
} else {
LOG(FATAL) << "No lower function found for gemm_sp_py";
}
}
LayoutMap GemmSPPyNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (completed_)
return {};
LayoutMap results;
if (const auto f = ffi::Function::GetGlobal("tl.gemm_sp_py.infer_layout")) {
results = Downcast<LayoutMap>(
(*f)(tvm::ffi::GetRef<GemmSPPy>(this), T.target, T.thread_bounds));
} else {
LOG(FATAL) << "No infer layout function found for gemm_sp_py";
}
completed_ = true;
return results;
}
TIR_REGISTER_TL_OP(GemmSPPy, gemm_sp_py)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK() { GemmSPPyNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
/*!
* \file tl/op/gemm_sp_py.h
* \brief Define gemm_sp_py operator.
*
*/
// TODO: @botbw: remove redundant code with gemm_py.h
#ifndef TVM_TL_OP_GEMM_SP_PY_H_
#define TVM_TL_OP_GEMM_SP_PY_H_
#include "gemm_sp.h"
#include "operator.h"
namespace tvm {
namespace tl {
using namespace tir;
class GemmSPPyNode : public TileOperatorNode {
public:
bool CheckWGMMA() const;
tir::Buffer A, E, B, C;
// buffer regions backing A, E, B, C
BufferRegion aRegion_, eRegion_, bRegion_, cRegion_;
bool trans_A, trans_B, trans_E;
int M, N, K;
int stride_A, stride_B;
int offset_A, offset_B;
PrimExpr clear_accum = const_false();
// For k_pack, see bitblas/tl/mfma_macro_generator.py::k_pack;
// it is only enabled for CDNA MFMA instructions
int kPack = 1;
int wg_wait = 0;
// use GemmWarpPolicy here, as the atom sizes are flexible in v2
mutable GemmWarpPolicy policy;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSPPy", GemmSPPyNode,
TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmSPPyNode>()
.def_ro("A", &GemmSPPyNode::A)
.def_ro("E", &GemmSPPyNode::E)
.def_ro("B", &GemmSPPyNode::B)
.def_ro("C", &GemmSPPyNode::C)
.def_ro("aRegion", &GemmSPPyNode::aRegion_)
.def_ro("eRegion", &GemmSPPyNode::eRegion_)
.def_ro("bRegion", &GemmSPPyNode::bRegion_)
.def_ro("cRegion", &GemmSPPyNode::cRegion_)
.def_ro("trans_A", &GemmSPPyNode::trans_A)
.def_ro("trans_B", &GemmSPPyNode::trans_B)
.def_ro("trans_E", &GemmSPPyNode::trans_E)
.def_ro("M", &GemmSPPyNode::M)
.def_ro("N", &GemmSPPyNode::N)
.def_ro("K", &GemmSPPyNode::K)
.def_ro("stride_A", &GemmSPPyNode::stride_A)
.def_ro("stride_B", &GemmSPPyNode::stride_B)
.def_ro("offset_A", &GemmSPPyNode::offset_A)
.def_ro("offset_B", &GemmSPPyNode::offset_B)
.def_ro("clear_accum", &GemmSPPyNode::clear_accum)
.def_ro("kPack", &GemmSPPyNode::kPack)
.def_ro("wg_wait", &GemmSPPyNode::wg_wait)
.def_ro("policy", &GemmSPPyNode::policy);
}
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
TileOperator Clone() const;
private:
// Target GEMM instruction
GemmInst GetGemmInst(int block_size, Target target) const;
mutable bool completed_ = false;
};
class GemmSPPy : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSPPy, TileOperator,
GemmSPPyNode);
TVM_DLL GemmSPPy(Array<PrimExpr> args);
static const Op &Get();
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_GEMM_SP_PY_H_
\ No newline at end of file
......@@ -127,6 +127,16 @@ TL_DEVICE int4_t make_int4(signed char x0, signed char x1, signed char x2,
return result;
}
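// Pack eight short values (as four short2 pairs) into an int4_t.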
TL_DEVICE int4_t make_int4(short x0, short x1, short y0, short y1, short z0,
short z1, short w0, short w1) {
int4_t result;
*((short2 *)&result.x) = make_short2(x0, x1);
*((short2 *)&result.y) = make_short2(y0, y1);
*((short2 *)&result.z) = make_short2(z0, z1);
*((short2 *)&result.w) = make_short2(w0, w1);
return result;
}
// Pack eight int values.
TL_DEVICE longlong4 make_longlong4(int x0, int x1, int y0, int y1, int z0,
int z1, int w0, int w1) {
......
......@@ -108,6 +108,16 @@ __device__ void debug_print_buffer_value(const char *msg, const char *buf_name,
PrintTraits<T>::print_buffer(msg, buf_name, index, var);
}
template <>
__device__ void debug_print_buffer_value<uint16_t>(const char *msg,
const char *buf_name,
int index, uint16_t var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=uint16_t value=%u\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, (uint32_t)var);
}
TL_DEVICE void device_assert(bool cond) { assert(cond); }
TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) {
......
......@@ -2,28 +2,46 @@ import torch
import tilelang
import tilelang.testing
from tilelang.utils.sparse import compress, randn_semi_sparse
from tilelang.layout import make_metadata_layout
torch.set_printoptions(threshold=float('inf'), edgeitems=float('inf'), linewidth=10000)
torch.manual_seed(42)
STR_TO_TYPE = {
'float32': torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float8_e4m3": torch.float8_e4m3fn,
"int8": torch.int8,
"int32": torch.int32,
}
SPARSITY_MAP = {
# 'float32': (1, 2), # not supported for now
torch.float16: (2, 4),
torch.bfloat16: (2, 4),
torch.float8_e4m3fn: (2, 4),
torch.int8: (2, 4),
}
from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse
from tilelang.layout import make_cutlass_metadata_layout
from tilelang.utils.tensor import torch_assert_close, map_torch_type
from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter
torch.backends.cuda.matmul.allow_tf32 = False
# torch.manual_seed(42) # only enable when debugging
def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
is_8bit = "8" in in_dtype
is_unsigned = "uint" in in_dtype
is_int = "int" in in_dtype
if is_int:
if is_8bit:
low, high = (0, 4) if is_unsigned else (-2, 2)
else:
low, high = (0, 128) if is_unsigned else (-64, 64)
A = randint_semi_sparse(
M,
K,
low=low,
high=high,
dtype=map_torch_type(in_dtype),
device='cuda',
transposed=trans_A)
B = torch.randint(
size=(N, K) if trans_B else (K, N),
low=low,
high=high,
dtype=map_torch_type(in_dtype),
device='cuda')
else:
A = randn_semi_sparse(
M, K, dtype=torch.float32, device='cuda',
transposed=trans_A).to(map_torch_type(in_dtype))
B = torch.randn(
(N, K) if trans_B else (K, N), device='cuda',
dtype=torch.float32).to(map_torch_type(in_dtype))
return A, B
def matmul_sp_sm90(
......@@ -60,21 +78,17 @@ def matmul_sp_sm90(
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), 'uint8')
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
E:
make_metadata_layout(
E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K),
make_cutlass_metadata_layout(
E, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
E_shared:
make_metadata_layout(
E_shared,
mma_dtype="float16",
arch="9.0",
backend="cutlass",
block_k=block_K),
make_cutlass_metadata_layout(
E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
})
T.disable_warp_group_reg_alloc()
T.clear(C_local)
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
if trans_A:
......@@ -85,8 +99,8 @@ def matmul_sp_sm90(
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_sp(A_shared, E_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
T.copy(C_frag, C[by * block_M, bx * block_N])
return main
......@@ -107,7 +121,8 @@ def matmul_sp_sm80(
trans_B,
):
is_8_bit = "8" in in_dtype
E_factor = 32 if is_8_bit else 16
metadata_dtype = 'int32' if is_8_bit else 'int16'
E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype]
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (K, N) if not trans_B else (N, K)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
......@@ -118,22 +133,18 @@ def matmul_sp_sm80(
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), 'int32' if is_8_bit else 'int16'),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor),
'int32' if is_8_bit else 'int16')
E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
E:
make_metadata_layout(E, mma_dtype="float16", backend="cutlass", arch="8.0"),
E_shared:
make_metadata_layout(
E_shared, mma_dtype="float16", backend="cutlass", arch="8.0"),
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
})
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
......@@ -181,19 +192,14 @@ def run_gemm_sp(
kernel,
out_idx=[-1],
)
A = randn_semi_sparse(M, K, dtype=STR_TO_TYPE[in_dtype], device='cuda', transposed=trans_A)
if trans_B:
B = torch.randn((N, K), device='cuda', dtype=torch.float32)
else:
B = torch.randn((K, N), device='cuda', dtype=torch.float32)
if "float8" in in_dtype or "int8" in in_dtype:
A = normalize(A.float())
B = normalize(B.float())
A = A.to(STR_TO_TYPE[in_dtype])
B = B.to(STR_TO_TYPE[in_dtype])
A, B = generate_dense_input(
M=M,
N=N,
K=K,
trans_A=trans_A,
trans_B=trans_B,
in_dtype=in_dtype,
)
A_sparse, E = compress(A, transposed=trans_A, block_k=block_K)
C_sp = kernel(A_sparse, E, B)
......@@ -206,14 +212,22 @@ def run_gemm_sp(
if "float8" in in_dtype or "int8" in in_dtype:
A = A.to(torch.float32)
B = B.to(torch.float32)
return torch.matmul(A, B).to(STR_TO_TYPE[out_dtype])
return torch.matmul(A, B)
C = _matmul(A, B)
if 'float8' in in_dtype:
diff = calc_diff(C_sp, C)
assert diff < 1e-3, f"{diff=}"
else:
torch.testing.assert_close(C_sp, C, atol=1e-3, rtol=1e-3)
torch_assert_close(
C_sp.to(torch.float32),
C.to(torch.float32),
rtol=1e-3,
atol=1e-3,
base_name="tilelang_sp",
ref_name="ref_dense",
)
print("pass")
......
from tilelang import tvm as tvm
from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse
from tilelang.utils.tensor import torch_assert_close, map_torch_type
from tilelang.layout import make_cutlass_metadata_layout
from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter
import tilelang.testing
import torch
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
metadata_dtype,
E_factor,
num_stages,
threads,
):
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
})
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
if trans_A:
T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
else:
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_sp_v2(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
T.copy(C_frag, C[by * block_M, bx * block_N])
return main
def run_gemm_ss(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
metadata_dtype,
SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor
num_stages,
num_threads,
)
kernel = tilelang.compile(
program,
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
C_sp = kernel(A_sparse, E, B)
def _matmul(A, B):
if trans_A:
A = A.T
if trans_B:
B = B.T
A = A.to(torch.float32)
B = B.to(torch.float32)
return torch.matmul(A, B)
C = _matmul(A, B)
torch_assert_close(
C_sp.to(map_torch_type(out_dtype)).to(torch.float32),
C.to(map_torch_type(out_dtype)).to(torch.float32),
rtol=1e-3,
atol=1e-3,
base_name="tilelang_sp",
ref_name="ref_dense",
)
print("pass")
def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
is_8bit = "8" in in_dtype
is_unsigned = "uint" in in_dtype
is_int = "int" in in_dtype
if is_int:
if is_8bit:
low, high = (0, 4) if is_unsigned else (-2, 2)
else:
low, high = (0, 128) if is_unsigned else (-64, 64)
A = randint_semi_sparse(
M,
K,
low=low,
high=high,
dtype=map_torch_type(in_dtype),
device='cuda',
transposed=trans_A)
B = torch.randint(
size=(N, K) if trans_B else (K, N),
low=low,
high=high,
dtype=map_torch_type(in_dtype),
device='cuda')
else:
A = randn_semi_sparse(
M, K, dtype=map_torch_type(in_dtype), device='cuda', transposed=trans_A)
B = torch.randn(
(N, K) if trans_B else (K, N), device='cuda',
dtype=torch.float32).to(map_torch_type(in_dtype))
return A, B
def test_gemm_ss():
# More test cases can be found in kernel/test_tilelang_kernel_gemm.py
# GEMM tests for float16
# TODO: support transposed A compressor
run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float", 128, 128, 32, 2)
run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float", 128, 128, 32, 2)
run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float", 128, 128, 32, 2)
run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float", 128, 128, 32, 2)
# n8 test
run_gemm_ss(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128)
# int8 test
run_gemm_ss(128, 128, 128, False, True, "int8", "int32", "int32", 128, 128, 64, 2)
run_gemm_ss(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_ss(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_ss(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
# float8 tests
run_gemm_ss(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64,
2)
run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
# tfloat32 test
# run_gemm_ss(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_ss(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_ss(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_ss(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
def matmul_rs(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
metadata_dtype,
E_factor,
num_stages,
threads,
):
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_frag_shape = A_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
})
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
if trans_A:
T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
else:
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(A_shared, A_frag)
T.gemm_sp_v2(A_frag, E_shared, B_shared, C_frag, trans_A, trans_B)
T.copy(C_frag, C[by * block_M, bx * block_N])
return main
def run_gemm_rs(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
program = matmul_rs(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
metadata_dtype,
SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor
num_stages,
num_threads,
)
kernel = tilelang.compile(
program,
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
C_sp = kernel(A_sparse, E, B)
def _matmul(A, B):
if trans_A:
A = A.T
if trans_B:
B = B.T
A = A.to(torch.float32)
B = B.to(torch.float32)
return torch.matmul(A, B)
C = _matmul(A, B)
torch_assert_close(
C_sp.to(map_torch_type(out_dtype)).to(torch.float32),
C.to(map_torch_type(out_dtype)).to(torch.float32),
rtol=1e-3,
atol=1e-3,
base_name="tilelang_sp",
ref_name="ref_dense",
)
print("pass")
def test_gemm_rs():
# GEMM tests for float16
run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_rs(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_rs(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2)
# n8 tests
run_gemm_rs(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128)
# int8 tests
run_gemm_rs(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_rs(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_rs(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_rs(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
# float8 tests
run_gemm_rs(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
# float32 tests
# run_gemm_rs(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_rs(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_rs(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_rs(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
def matmul_sr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
metadata_dtype,
E_factor,
num_stages,
threads,
):
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
B_frag_shape = B_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
})
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
if trans_A:
T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
else:
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(B_shared, B_frag)
T.gemm_sp_v2(A_shared, E_shared, B_frag, C_frag, trans_A, trans_B)
T.copy(C_frag, C[by * block_M, bx * block_N])
return main
def run_gemm_sr(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
program = matmul_sr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
metadata_dtype,
SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor
num_stages,
num_threads,
)
kernel = tilelang.compile(
program,
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
C_sp = kernel(A_sparse, E, B)
def _matmul(A, B):
if trans_A:
A = A.T
if trans_B:
B = B.T
A = A.to(torch.float32)
B = B.to(torch.float32)
return torch.matmul(A, B)
C = _matmul(A, B)
torch_assert_close(
C_sp.to(map_torch_type(out_dtype)).to(torch.float32),
C.to(map_torch_type(out_dtype)).to(torch.float32),
rtol=1e-3,
atol=1e-3,
base_name="tilelang_sp",
ref_name="ref_dense",
)
print("pass")
def test_gemm_sr():
# GEMM tests for float16
run_gemm_sr(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_sr(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_sr(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_sr(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2)
# n8 tests
run_gemm_sr(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128)
# int8 tests
run_gemm_sr(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 128, 2)
run_gemm_sr(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 128, 2)
run_gemm_sr(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_sr(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
# float8 tests
run_gemm_sr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
# float32 tests
# run_gemm_sr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_sr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_sr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_sr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
def matmul_rr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
metadata_dtype,
E_factor,
num_stages,
threads,
):
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_frag_shape = A_shared_shape
B_frag_shape = B_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
})
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
if trans_A:
T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
else:
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(A_shared, A_frag)
T.copy(B_shared, B_frag)
T.gemm_sp_v2(A_frag, E_shared, B_frag, C_frag, trans_A, trans_B)
T.copy(C_frag, C[by * block_M, bx * block_N])
return main
def run_gemm_rr(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
program = matmul_rr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
metadata_dtype,
SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor
num_stages,
num_threads,
)
kernel = tilelang.compile(
program,
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
C_sp = kernel(A_sparse, E, B)
def _matmul(A, B):
if trans_A:
A = A.T
if trans_B:
B = B.T
A = A.to(torch.float32)
B = B.to(torch.float32)
return torch.matmul(A, B)
C = _matmul(A, B)
torch_assert_close(
C_sp.to(map_torch_type(out_dtype)).to(torch.float32),
C.to(map_torch_type(out_dtype)).to(torch.float32),
rtol=1e-3,
atol=1e-3,
base_name="tilelang_sp",
ref_name="ref_dense",
)
print("pass")
def test_gemm_rr():
# GEMM tests for float16
run_gemm_rr(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_rr(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_rr(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_rr(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2)
# n8 tests
run_gemm_rr(128, 8, 128, False, True, "float16", "float16", "float", 128, 8, 32, 2)
run_gemm_rr(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 64, 2)
# int8 tests
run_gemm_rr(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_rr(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_rr(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_rr(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
# float8 tests
run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
# float32 tests
# run_gemm_rr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_rr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_rr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_rr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
if __name__ == "__main__":
tilelang.testing.main()
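
# To reproduce a single configuration outside the full test suite, any of the
# runners above can be called directly, e.g.:
#   run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float",
#               128, 128, 32, 2)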
......@@ -151,12 +151,43 @@ def mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id):
return row, col
def mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id):
"""
groupID = %laneid >> 2
threadID_in_group = %laneid % 4
row = groupID for ai where 0 <= i < 2 || 4 <= i < 6
groupID + 8 Otherwise
col = (threadID_in_group * 2) + (i & 0x1) for ai where i < 4
(threadID_in_group * 2) + (i & 0x1) + 8 for ai where i >= 4
"""
row = (thread_id // 4) + 8 * (local_id % 4 // 2)
col = (thread_id % 4) * 2 + (local_id % 2) + 8 * (local_id // 4)
return row, col
def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id):
row = 8 * (local_id // 8) + (thread_id // 4)
col = 16 * (local_id % 8 // 4) + (thread_id % 4) * 4 + (local_id % 4)
return row, col
def mma_load_b_32x8_to_shared_16x16_layout(thread_id, local_id):
"""
groupID = %laneid >> 2
threadID_in_group = %laneid % 4
row = (threadID_in_group * 2) + (i & 0x1) for bi where i < 2
(threadID_in_group * 2) + (i & 0x1) + 8 for bi where i >= 2
col = groupID
"""
col = (thread_id % 4) * 2 + ((local_id % 4) % 2) + ((local_id % 4) // 2) * 8
row = (thread_id // 4) + 8 * (local_id // 4)
return row, col
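

# Illustrative sanity check (not part of the original change): both 32x8 ->
# 16x16 mappings above should be bijections from (thread_id, local_id) over
# 32 lanes x 8 local ids onto the 16x16 shared tile.
def _check_mma_load_32x8_to_shared_16x16_layouts():
    full_tile = {(r, c) for r in range(16) for c in range(16)}
    for layout in (mma_load_a_32x8_to_shared_16x16_layout,
                   mma_load_b_32x8_to_shared_16x16_layout):
        covered = {layout(t, l) for t in range(32) for l in range(8)}
        assert covered == full_tile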
def shared_16x16_to_mma_32x8_smoothlayout(i, j):
return (i * 2 + j // 8, j % 8)
......
......@@ -22,8 +22,10 @@ from tilelang.intrinsics.mma_layout import (
shared_16x32_to_mma_32x16_layout_sr_b,
mma_load_a_32x4_to_shared_16x8_layout,
mma_load_b_32x4_to_shared_16x8_layout,
mma_load_b_32x8_to_shared_16x16_layout,
mma_load_a_32x16_to_shared_16x32_layout,
mma_load_b_32x16_to_shared_16x32_layout,
mma_load_a_32x8_to_shared_16x16_layout,
)
lift = convert
......@@ -291,6 +293,8 @@ class TensorCoreIntrinEmitter:
if not ldmatrix_available:
if DataType(a_dtype).bits == 8:
mma_load_layout = mma_load_a_32x16_to_shared_16x32_layout
elif DataType(a_dtype).bits == 16:
mma_load_layout = mma_load_a_32x8_to_shared_16x16_layout
elif DataType(a_dtype).bits == 32:
mma_load_layout = mma_load_a_32x4_to_shared_16x8_layout
else:
......@@ -417,6 +421,8 @@ class TensorCoreIntrinEmitter:
if not ldmatrix_available:
if DataType(b_dtype).bits == 8:
mma_load_layout = mma_load_b_32x16_to_shared_16x32_layout
elif DataType(b_dtype).bits == 16:
mma_load_layout = mma_load_b_32x8_to_shared_16x16_layout
elif DataType(b_dtype).bits == 32:
mma_load_layout = mma_load_b_32x4_to_shared_16x8_layout
else:
......
from tvm import DataType
from typing import Literal
from tilelang.intrinsics.mma_layout import (
mma_load_a_32x4_to_shared_16x8_layout,
mma_load_a_32x16_to_shared_16x32_layout,
mma_load_a_32x8_to_shared_16x16_layout,
shared_16x8_to_mma_32x4_layout_sr_a,
shared_16x16_to_mma_32x8_layout_sr_a,
shared_16x32_to_mma_32x16_layout_sr_a,
)
def shared_16x16_to_mma_sp_layout_sr_a(i, j):
return shared_16x8_to_mma_32x4_layout_sr_a(i, j)
def shared_16x16_to_mma_sp_layout_sr_b(i, j):
thread_id = 4 * (i % 8) + (j % 4)
return thread_id, 4 * (i // 8) + (j // 4)
def shared_16x32_to_mma_sp_layout_sr_a(i, j):
return shared_16x16_to_mma_32x8_layout_sr_a(i, j)
def shared_16x32_to_mma_sp_layout_sr_b(i, j):
thread_id = 4 * (i % 8) + (j % 8) // 2
return thread_id, 8 * (i // 8) + (j // 8) * 2 + (j % 2)
def shared_16x64_to_mma_sp_layout_sr_a(i, j):
return shared_16x32_to_mma_32x16_layout_sr_a(i, j)
def shared_16x64_to_mma_sp_layout_sr_b(i, j):
thread_id = 4 * (i % 8) + (j % 16) // 4
return thread_id, 16 * (i // 8) + (j // 16) * 4 + j % 4
def mma_sp_load_a_32x4_to_shared_16x16_layout(thread_id, local_id):
return mma_load_a_32x4_to_shared_16x8_layout(thread_id, local_id)
def mma_sp_load_a_32x8_to_shared_16x32_layout(thread_id, local_id):
return mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id)
def mma_sp_load_a_32x16_to_shared_16x64_layout(thread_id, local_id):
return mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id)
def mma_sp_load_b_32x8_to_shared_16x16_layout(thread_id, local_id):
col = 4 * (local_id % 4) + (thread_id % 4)
row = 8 * (local_id // 4) + (thread_id // 4)
return row, col
def mma_sp_load_b_32x16_to_shared_16x32_layout(thread_id, local_id):
col = (thread_id % 4) * 2 + (local_id % 2) + ((local_id % 8) // 2) * 8
row = (thread_id // 4) + 8 * (local_id // 8)
return row, col
def mma_sp_load_b_32x32_to_shared_16x64_layout(thread_id, local_id):
col = (thread_id % 4) * 4 + (local_id % 4) + 16 * ((local_id % 16) // 4)
row = (thread_id // 4) + 8 * (local_id // 16)
return row, col
def get_logical_id_32bit(thread_id: int) -> int:
return (thread_id // 4) * 2 + (thread_id % 4) % 2
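# For the first 8 lanes this yields logical ids 0, 1, 0, 1, 2, 3, 2, 3: within
# each group of 4 consecutive threads, lanes (0, 2) and (1, 3) share a logical
# id, matching the 2x metadata replication used for 16/32-bit operand types.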
def metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id: int,
local_id: int) -> tuple[int, int]:
logical_id = get_logical_id_32bit(thread_id)
row = logical_id // 4 + local_id * 8
col = logical_id % 4
return row, col
def metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id: int,
local_id: int) -> tuple[int, int]:
logical_id = get_logical_id_32bit(thread_id)
row = logical_id // 2 + local_id * 8
col = logical_id % 2
return row, col
def metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(thread_id: int,
local_id: int) -> tuple[int, int]:
return metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(
thread_id, local_id) # same mapping for 16bit and 32bit
def metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(thread_id: int,
local_id: int) -> tuple[int, int]:
return metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(
thread_id, local_id) # same mapping for 16bit and 32bit
def get_logical_id_8bit(thread_id: int) -> int:
return thread_id
def metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(thread_id: int,
local_id: int) -> tuple[int, int]:
logical_id = get_logical_id_8bit(thread_id)
row = logical_id // 2 + local_id * 8
col = (logical_id % 4) // 2 * 4 + local_id
return row, col
def metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(thread_id: int,
local_id: int) -> tuple[int, int]:
logical_id = get_logical_id_8bit(thread_id)
row = logical_id // 2 + local_id * 8
col = (logical_id % 4) // 2 * 2 + local_id
return row, col
def metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(thread_id: int,
local_id: int) -> tuple[int, int]:
# local_id is always 0
logical_id = get_logical_id_8bit(thread_id)
row = logical_id // 4 + (logical_id % 2) * 8
col = (logical_id % 4) // 2
return row, col
def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
row = (local_id // 4) * 8 + thread_id % 8
col = (thread_id // 8) * 4 + local_id % 4
return row, col
def ldmatrix_32x16_to_shared_32x16_layout(thread_id, local_id):
row = thread_id
col = local_id % 8 + 8 * (local_id // 8)
return row, col
def ldmatrix_trans_32x16_to_shared_16x32_layout(thread_id, local_id):
row = 8 * (local_id // 8) + thread_id % 8
col = (thread_id // 8) * 8 + local_id % 8
return row, col
def ldmatrix_trans_32x32_to_shared_shared_16x64_layout(thread_id, local_id):
row = (local_id // 16) * 8 + thread_id % 8
col = (thread_id // 8) * 16 + local_id % 16
return row, col
def get_ldmatrix_offset_b(
matrix: Literal["B"],
row_idx,
col_idx,
stride,
dtype: Literal["float16", "int8"] = "float16",
transposed: bool = False,
):
assert matrix == "B", "matrix should be B"
dtype_bits = DataType(dtype).bits
if dtype_bits == 32:
if transposed:
transform_func = ldmatrix_trans_32x8_to_shared_16x16_layout
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
else:
raise ValueError("ldmatrix only supports B transposed for 32-bit dtype")
elif dtype_bits == 16:
transform_func = ldmatrix_32x16_to_shared_32x16_layout
transform_func_trans = ldmatrix_trans_32x16_to_shared_16x32_layout
if transposed:
new_row_idx, new_col_idx = transform_func_trans(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
else:
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
elif dtype_bits == 8:
if transposed:
transform_func = ldmatrix_trans_32x32_to_shared_shared_16x64_layout
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
else:
raise ValueError("ldmatrix only supports B transposed for 8-bit dtype")
else:
raise ValueError(f"Unsupported dtype {dtype}")
from __future__ import annotations
import tilelang.language as T
from typing import Literal, Callable
from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var
from tvm.runtime import convert
from .utils import (
mma_store_index_map,
get_ldmatrix_offset,
)
from tilelang.utils import is_fragment
from tilelang.intrinsics.mma_sp_layout import (
shared_16x16_to_mma_sp_layout_sr_a,
shared_16x16_to_mma_sp_layout_sr_b,
shared_16x32_to_mma_sp_layout_sr_a,
shared_16x32_to_mma_sp_layout_sr_b,
shared_16x64_to_mma_sp_layout_sr_a,
shared_16x64_to_mma_sp_layout_sr_b,
mma_sp_load_a_32x4_to_shared_16x16_layout,
mma_sp_load_a_32x8_to_shared_16x32_layout,
mma_sp_load_a_32x16_to_shared_16x64_layout,
mma_sp_load_b_32x8_to_shared_16x16_layout,
mma_sp_load_b_32x16_to_shared_16x32_layout,
mma_sp_load_b_32x32_to_shared_16x64_layout,
metadata_8bit_load_32x4_to_shared_16x4_layout_32bit,
metadata_16bit_load_32x2_to_shared_16x2_layout_32bit,
metadata_8bit_load_32x4_to_shared_16x4_layout_16bit,
metadata_16bit_load_32x2_to_shared_16x2_layout_16bit,
metadata_8bit_load_32x4_to_shared_16x4_layout_8bit,
metadata_16bit_load_32x2_to_shared_16x4_layout_8bit,
metadata_32bit_load_32x1_to_shared_16x2_layout_8bit,
get_ldmatrix_offset_b,
)
lift = convert
class SparseTensorCoreIntrinEmitter:
"""
To eliminate Python syntax within TIR Macro.
"""
M_DIM = 16
    SPARSE_FACTOR = 2  # 1:2 for tfloat32, 2:4 for 16-bit and 8-bit data types
    SPARSE_SELECTOR = 0  # the lower threads of each group always provide the metadata
    # lowercase because n_dim can be dynamic:
    # the smallest instruction is m16n8k16, so n_dim can also be 8
    n_dim = 16
WARP_SIZE = 32
dtype_abbrv = {
"float16": "fp16",
"bfloat16": "bf16",
"float32": "fp32",
"int8": "int8",
"int32": "int32",
"float8_e4m3": "e4m3",
"float8_e5m2": "e5m2",
}
E_FACTOR_MAP = { # e_kdim = mma_kdim // e_factor
"float": {
"int16": 8,
"uint16": 8,
},
"float32": {
"int16": 8,
"uint16": 8,
},
"float16": {
"int8": 8,
"uint8": 8,
"int16": 16,
"uint16": 16,
"int32": 32,
"uint32": 32,
},
"bfloat16": {
"int8": 8,
"uint8": 8,
"int16": 16,
"uint16": 16,
"int32": 32,
"uint32": 32,
},
"int8": {
"int8": 8,
"uint8": 8,
"int16": 16,
"uint16": 16,
"int32": 32,
"uint32": 32,
},
"uint8": {
"int8": 8,
"uint8": 8,
"int16": 16,
"uint16": 16,
"int32": 32,
"uint32": 32,
},
"float8_e4m3": {
"int8": 8,
"uint8": 8,
"int16": 16,
"uint16": 16,
"int32": 32,
"uint32": 32,
},
"float8_e5m2": {
"int8": 8,
"uint8": 8,
"int16": 16,
"uint16": 16,
"int32": 32,
"uint32": 32,
},
}
    E_REPLICATE_FACTOR = {  # metadata is replicated within every group of 4 consecutive threads
        "float32": 2,
        "float16": 2,  # 2 of every 4 consecutive threads provide metadata
        "bfloat16": 2,
        "int8": 1,  # all 4 consecutive threads provide metadata
        "uint8": 1,
        "float8_e4m3": 1,
        "float8_e5m2": 1,
    }
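    # Worked example: for float16 operands with int16 metadata, E_FACTOR_MAP
    # gives e_factor = 16, so an m16n8k32 sparse MMA (logical k_dim = 32) uses
    # k_dim // e_factor = 2 metadata columns per row while the compressed A tile
    # keeps k_dim // SPARSE_FACTOR = 16 value columns; with
    # E_REPLICATE_FACTOR["float16"] = 2, two of every four consecutive lanes
    # carry that metadata.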
# Represent the thread binding in the form of (tx, warp_n, warp_m)
is_m_first = False
def __init__(
self,
a_dtype: str = "float16",
e_dtype: str = "uint8",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_transposed: bool = False,
b_transposed: bool = False,
e_transposed: bool = False,
block_row_warps: int = 2,
block_col_warps: int = 2,
warp_row_tiles: int = 8,
warp_col_tiles: int = 8,
warp_k: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
is_m_first: bool = False,
thread_var: Var | None = None,
):
self.a_dtype = a_dtype
self.e_dtype = e_dtype
self.b_dtype = b_dtype
self.accum_dtype = accum_dtype
self.a_transposed = a_transposed
self.b_transposed = b_transposed
self.e_transposed = e_transposed
# Hint Information
self.block_row_warps = block_row_warps
self.block_col_warps = block_col_warps
self.warp_row_tiles = warp_row_tiles
self.warp_col_tiles = warp_col_tiles
self.warp_k = warp_k
self.e_factor = self.E_FACTOR_MAP[self.a_dtype][self.e_dtype]
self._initialize_k_dim(a_dtype)
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
self._initialize_micro_size(self.M_DIM, self.k_dim)
self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim, self.WARP_SIZE)
self._initialize_mma_sp_prefix(self.k_dim)
self._initialize_is_m_first(is_m_first)
self.reduce_k = reduce_k
self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
self.num_elems_per_byte = num_elems_per_byte
self.thread_var = thread_var
if self.warp_rows == 0 or self.warp_cols == 0:
raise ValueError(
f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}"
)
def _initialize_k_dim(self, a_dtype="float16"):
if isinstance(a_dtype, str):
a_dtype = DataType(a_dtype)
# NOTE: k_dim here represents the logical shape of the MMA operation.
# When referring to the physical data movement, it should be divided by sparse_factor.
self.k_dim = 256 // a_dtype.bits * self.SPARSE_FACTOR
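        # e.g. float16/bfloat16: 256 // 16 * 2 = 32 (m16n8k32); int8/fp8:
        # 256 // 8 * 2 = 64 (m16n8k64); tf32: 256 // 32 * 2 = 16 (m16n8k16)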
def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
self.local_size_a = (m_dim * k_dim) // warp_size // self.SPARSE_FACTOR
self.local_size_e = (
m_dim * k_dim) // self.e_factor // warp_size * self.E_REPLICATE_FACTOR[self.a_dtype]
self.local_size_b = (n_dim * k_dim) // warp_size
self.local_size_out = (m_dim * n_dim) // warp_size
def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
self.a_dtype_abbrv = self.dtype_abbrv[a_dtype]
self.b_dtype_abbrv = self.dtype_abbrv[b_dtype]
self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype]
def _initialize_mma_sp_prefix(self, k_dim: int = 16):
if k_dim == 16:
# typically used for tfloat32
self.mma_prefix = "m16n8k16"
elif k_dim == 32:
# typically used for float16/bfloat16
self.mma_prefix = "m16n8k32"
elif k_dim == 64:
# typically used for int8/fp8
self.mma_prefix = "m16n8k64"
else:
raise ValueError("Unsupported k_dim")
def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
warp_row_tiles = self.warp_row_tiles
warp_col_tiles = self.warp_col_tiles
        assert warp_row_tiles >= 16, f"warp_row_tiles must be at least 16, got {warp_row_tiles}"
        assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}"
        assert warp_col_tiles >= 8, f"warp_col_tiles must be at least 8, got {warp_col_tiles}"
        assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}"
self.warp_rows = warp_row_tiles // m_dim
if warp_col_tiles % 16 == 0:
self.n_dim = 16
self.micro_size_y = 16
self.warp_cols = warp_col_tiles // 16
else:
# must be divisible by 8
self.n_dim = 8
self.micro_size_y = 8
self.warp_cols = warp_col_tiles // 8
self.micro_size_x = m_dim
# NOTE: k_dim here represents the logical shape of the MMA operation.
self.micro_size_k = k_dim
def _initialize_is_m_first(self, is_m_first: bool | None = False):
if is_m_first is not None:
self.is_m_first = is_m_first
def get_thread_binding(self):
if self.thread_var is None:
current_frame = T.KernelLaunchFrame.Current()
assert current_frame is not None, "Must be called in a T.Kernel Frame"
return current_frame.get_thread_binding()
else:
return self.thread_var
def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32")
if not inverse:
return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c])
return inverse_index_map
def extract_thread_binding(
self,
thread_id: PrimExpr,
is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
"""
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)]
"""
WARP_SIZE = self.WARP_SIZE
block_row_warps = self.block_row_warps
block_col_warps = self.block_col_warps
# if is_m_first is None, then use the default value
if is_m_first is None:
is_m_first = self.is_m_first
if is_m_first:
lane_id, warp_n, warp_m = (
thread_id % WARP_SIZE,
(thread_id // WARP_SIZE) % block_col_warps,
(thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps,
)
return lane_id, warp_n, warp_m
else:
lane_id, warp_m, warp_n = (
thread_id % WARP_SIZE,
(thread_id // WARP_SIZE) % block_row_warps,
(thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
)
return lane_id, warp_n, warp_m
def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer, ki: PrimExpr, rk: PrimExpr = 0):
warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows
warp_k = self.warp_k
micro_size_x = self.micro_size_x
micro_size_k = self.micro_size_k
local_size_a = self.local_size_a
a_dtype = self.a_dtype
a_transposed = self.a_transposed
        # ldmatrix cannot be used when A is transposed and the element type is not
        # 16-bit (e.g. the int8 / fp8 transposed case); fall back to per-element loads.
        ldmatrix_available = not (DataType(a_dtype).bits != 16 and a_transposed)
def mma_load_layout(i, j):
return i, j
if not ldmatrix_available:
if DataType(a_dtype).bits == 8:
mma_load_layout = mma_sp_load_a_32x16_to_shared_16x64_layout
elif DataType(a_dtype).bits == 16:
mma_load_layout = mma_sp_load_a_32x8_to_shared_16x32_layout
elif DataType(a_dtype).bits == 32:
mma_load_layout = mma_sp_load_a_32x4_to_shared_16x16_layout
else:
raise ValueError(f"Unsupported dtype: {a_dtype}")
thread_binding = self.get_thread_binding()
@T.macro
def _warp_ldmatrix_a(
A_local_buf,
A_shared_buf,
ki,
thread_binding,
rk=0,
):
stride = A_shared_buf.shape[-1]
tx, _, warp_m = self.extract_thread_binding(thread_binding)
trans = self.a_transposed
for i in T.serial(warp_rows):
# Assign A_shared_buf_elem
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (
rk * warp_k + ki * micro_size_k) // self.SPARSE_FACTOR
A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk]
if ldmatrix_available:
T.ptx_ldmatrix(
a_dtype,
T.bool(trans),
4,
".b16",
A_local_buf.data,
i * local_size_a,
T.address_of(A_shared_buf_elem),
get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed),
)
else:
for j in T.serial(local_size_a):
mi, mk = mma_load_layout(tx, j)
A_local_buf[i * local_size_a +
j] = A_shared_buf[wk + mk, wi +
mi] if a_transposed else A_shared_buf[wi + mi,
wk + mk]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
def ldmatrix_e(self, E_local_buf: Buffer, E_shared_buf: Buffer, ki: PrimExpr, rk: PrimExpr = 0):
warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows
warp_k = self.warp_k
micro_size_x = self.micro_size_x
micro_size_k = self.micro_size_k
local_size_e = self.local_size_e
a_dtype = self.a_dtype
e_dtype = self.e_dtype
trans = self.e_transposed
        # ldmatrix is currently not used for metadata loads
        # (cf. cutlass: include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h)
        ldmatrix_available = False  # TODO: use ldmatrix when possible
def mma_load_layout(i, j):
return i, j
if not ldmatrix_available:
if DataType(e_dtype).bits == 8:
if DataType(a_dtype).bits == 8:
mma_load_layout = metadata_8bit_load_32x4_to_shared_16x4_layout_8bit
elif DataType(a_dtype).bits == 16:
mma_load_layout = metadata_8bit_load_32x4_to_shared_16x4_layout_16bit
elif DataType(a_dtype).bits == 32:
mma_load_layout = metadata_8bit_load_32x4_to_shared_16x4_layout_32bit
else:
raise ValueError(f"Unsupported a_dtype for e_dtype 8bit: {a_dtype}")
elif DataType(e_dtype).bits == 16:
if DataType(a_dtype).bits == 8:
mma_load_layout = metadata_16bit_load_32x2_to_shared_16x4_layout_8bit
elif DataType(a_dtype).bits == 16:
mma_load_layout = metadata_16bit_load_32x2_to_shared_16x2_layout_16bit
elif DataType(a_dtype).bits == 32:
mma_load_layout = metadata_16bit_load_32x2_to_shared_16x2_layout_32bit
else:
raise ValueError(f"Unsupported a_dtype for e_dtype 16bit: {a_dtype}")
elif DataType(e_dtype).bits == 32:
if DataType(a_dtype).bits == 8:
mma_load_layout = metadata_32bit_load_32x1_to_shared_16x2_layout_8bit
else:
raise ValueError(f"Unsupported a_dtype for e_dtype 32bit: {a_dtype}")
else:
raise ValueError(f"Unsupported dtype: {e_dtype}")
thread_binding = self.get_thread_binding()
@T.macro
def _warp_ldmatrix_e(
E_local_buf,
E_shared_buf,
ki,
thread_binding,
rk=0,
):
tx, _, warp_m = self.extract_thread_binding(thread_binding)
for i in T.serial(warp_rows):
# Assign E_shared_buf_elem
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, (
rk * warp_k + ki * micro_size_k) // self.e_factor
for j in T.serial(local_size_e):
mi, mk = mma_load_layout(tx, j)
E_local_buf[i * local_size_e +
j] = E_shared_buf[wk + mk,
wi + mi] if trans else E_shared_buf[wi + mi,
wk + mk]
return _warp_ldmatrix_e(E_local_buf, E_shared_buf, ki, thread_binding, rk)
def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer, ki: PrimExpr, rk: PrimExpr = 0):
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
warp_k = self.warp_k
micro_size_y = self.micro_size_y
micro_size_k = self.micro_size_k
local_size_b = self.local_size_b
b_dtype = self.b_dtype
b_transposed = self.b_transposed
thread_binding = self.get_thread_binding()
replicate_b = (self.n_dim == 16)
        # ldmatrix cannot be used when B is not transposed and the element type is
        # not 16-bit; fall back to per-element loads in that case.
        ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
def mma_load_layout(i, j):
return i, j
if not ldmatrix_available:
if DataType(b_dtype).bits == 8:
mma_load_layout = mma_sp_load_b_32x32_to_shared_16x64_layout
elif DataType(b_dtype).bits == 16:
mma_load_layout = mma_sp_load_b_32x16_to_shared_16x32_layout
elif DataType(b_dtype).bits == 32:
mma_load_layout = mma_sp_load_b_32x8_to_shared_16x16_layout
else:
raise ValueError(f"Unsupported dtype: {b_dtype}")
@T.macro
def _warp_ldmatrix_b(
B_local_buf,
B_shared_buf,
ki,
thread_binding,
rk=0,
):
stride = B_shared_buf.shape[-1]
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
trans = not b_transposed
for i in T.serial(warp_cols):
# Assign B_shared_elem
wi, wk = (
warp_n * warp_col_tiles + i * micro_size_y,
rk * warp_k + ki * micro_size_k,
)
if ldmatrix_available:
B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk,
wi]
if replicate_b:
T.ptx_ldmatrix(
b_dtype,
T.bool(trans),
4,
".b16",
B_local_buf.data,
i * local_size_b,
T.address_of(B_shared_buf_elem),
get_ldmatrix_offset_b("B", tx, 0, stride, b_dtype, b_transposed),
)
T.ptx_ldmatrix(
b_dtype,
T.bool(trans),
4,
".b16",
B_local_buf.data,
i * local_size_b + lift(local_size_b) // 2,
T.address_of(B_shared_buf_elem),
get_ldmatrix_offset_b("B", tx,
lift(local_size_b) // 2, stride, b_dtype,
b_transposed),
)
else:
T.ptx_ldmatrix(
b_dtype,
T.bool(trans),
4,
".b16",
B_local_buf.data,
i * local_size_b,
T.address_of(B_shared_buf_elem),
get_ldmatrix_offset_b("B", tx, 0, stride, b_dtype, b_transposed),
)
else:
# load 16x32 data from shared buffer to local buffer
# must be transposed.
for j in T.serial(local_size_b):
mi, mk = mma_load_layout(tx, j)
B_local_buf[i * local_size_b +
j] = B_shared_buf[wi + mi, wk +
mk] if b_transposed else B_shared_buf[wk + mk,
wi + mi]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
def mma_sp(self,
A_local_buf: Buffer,
E_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
k_inner: PrimExpr = 0):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
local_size_e = self.local_size_e
local_size_b = self.local_size_b
local_size_out = self.local_size_out
a_dtype_abbrv = self.a_dtype_abbrv
b_dtype_abbrv = self.b_dtype_abbrv
accum_dtype = self.accum_dtype
accum_dtype_abbrv = self.accum_dtype_abbrv
mma_prefix = self.mma_prefix
replicate_b = (self.n_dim == 16)
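        # When n_dim == 16, each warp tile packs two m16n8kX MMAs: the second
        # ptx_mma_sp issued below reuses the same A and E fragments with the
        # upper halves of the B fragment and the accumulator.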
a_is_fragment = is_fragment(A_local_buf)
e_is_fragment = is_fragment(E_local_buf)
b_is_fragment = is_fragment(B_local_buf)
assert not e_is_fragment, f"currently E_local_buf must be a local allocation, found {E_local_buf.scope()}"
a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
e_local_stride: PrimExpr = k_inner * warp_rows * local_size_e if e_is_fragment else 0
b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0
@T.macro
def _warp_mma_sp(A_local_buf, E_local_buf, B_local_buf, C_local_buf):
for i, j in T.grid(warp_rows, warp_cols):
T.ptx_mma_sp(
accum_dtype,
mma_prefix,
"row",
"col",
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
a_local_stride + i * local_size_a,
B_local_buf.data,
b_local_stride + j * local_size_b,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out,
E_local_buf.data, # metadata
e_local_stride + i * local_size_e, # metadata offset
self.SPARSE_SELECTOR, # sparse_selector
T.bool(False), # saturate
)
if replicate_b:
T.ptx_mma_sp(
accum_dtype,
mma_prefix,
"row",
"col",
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
a_local_stride + i * local_size_a,
B_local_buf.data,
b_local_stride + j * local_size_b + lift(local_size_b) // 2,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out +
lift(local_size_out) // 2,
E_local_buf.data, # metadata
e_local_stride + i * local_size_e, # metadata offset
self.SPARSE_SELECTOR, # sparse_selector
T.bool(False), # saturate
)
return _warp_mma_sp(A_local_buf, E_local_buf, B_local_buf, C_local_buf)
def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None):
block_row_warps = self.block_row_warps
block_col_warps = self.block_col_warps
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_out = self.local_size_out
is_global = pid_m is not None and pid_n is not None
BLOCK_M = block_row_warps * warp_rows
BLOCK_N = block_col_warps * warp_cols
M_DIM, n_dim = self.M_DIM, self.n_dim
C_buf_dims = len(C_buf.shape)
assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D"
thread_binding = self.get_thread_binding()
        # STS (store to shared / global)
        # The MMA store is emulated with explicit per-element stores instead of a
        # TVM intrinsic, because the intrinsic assumes threadIdx.x always equals
        # the warp size.
@T.macro
def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding):
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
for i, j in T.grid(warp_rows, warp_cols):
for local_id_o in T.serial(local_size_out // 2):
for local_id_i in T.vectorized(2):
local_id = local_id_o * 2 + local_id_i
row, col = T.meta_var(mma_store_index_map(tx, local_id))
if C_buf_dims == 2:
C_buf[(warp_m * warp_rows + i) * M_DIM + row,
(warp_n * warp_cols + j) * n_dim +
col] = C_local_buf[i * (warp_cols * local_size_out) +
j * local_size_out + local_id]
else:
C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
col] = C_local_buf[i * (warp_cols * local_size_out) +
j * local_size_out + local_id]
@T.macro
def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
for i, j in T.grid(warp_rows, warp_cols):
for local_id_o in T.serial(local_size_out // 2):
for local_id_i in T.vectorized(2):
local_id = local_id_o * 2 + local_id_i
row, col = T.meta_var(mma_store_index_map(tx, local_id))
C_buf[
(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
(pid_n * BLOCK_N + warp_n * warp_cols + j) * n_dim + col,
] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
local_id]
return (_warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding))
def make_mma_load_layout(self,
local_buf: Buffer,
matrix: Literal["A", "B"] = "A") -> T.Fragment:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
map fragment indices to threads and local indices.
Parameters
----------
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
assert matrix in ["A", "B"], "matrix should be either A or B"
matrix_is_a: bool = matrix == "A"
matrix_is_b: bool = matrix == "B"
dtype = self.a_dtype if matrix_is_a else self.b_dtype
dtype_bits = DataType(dtype).bits
transposed = self.a_transposed if matrix_is_a else self.b_transposed
# s represents spatial axis
# r represents reduction axis
# sr represents the two dims are spatial + reduction
# rs represents the two dims are reduction + spatial
# sr also can represent a non-transposed basic layout
# then rs also can represent a transposed basic layout
transform_func_sr_a: Callable = None
transform_func_sr_b: Callable = None
if dtype_bits == 32:
transform_func_sr_a = shared_16x16_to_mma_sp_layout_sr_a
transform_func_sr_b = shared_16x16_to_mma_sp_layout_sr_b
elif dtype_bits == 16:
transform_func_sr_a = shared_16x32_to_mma_sp_layout_sr_a
transform_func_sr_b = shared_16x32_to_mma_sp_layout_sr_b
elif dtype_bits == 8:
transform_func_sr_a = shared_16x64_to_mma_sp_layout_sr_a
transform_func_sr_b = shared_16x64_to_mma_sp_layout_sr_b
else:
raise ValueError(f"Unsupported dtype {dtype}")
is_sr_conditions = [False]
is_sr_conditions.append(matrix_is_a and not transposed)
is_sr_conditions.append(matrix_is_b and transposed)
is_sr_axis_order = any(is_sr_conditions)
# the layout of mma.sync is row.col.
# so the b matrix expected a transposed basic layout
transform_func: Callable = None
if matrix_is_a:
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
elif matrix_is_b:
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
j, i)
else:
raise ValueError(f"Unsupported matrix {matrix}")
assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
if matrix_is_a:
micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
else:
micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y
block_row_warps, block_col_warps = (
self.block_row_warps,
self.block_col_warps,
)
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
def forward_thread(i: int, j: int) -> int:
"""
            Given the row index `i` and column index `j` in the fragment,
            map them to the lane (thread) that owns that element according to
            `inverse_mma_load_layout`.
"""
lane_id, _ = inverse_mma_load_layout.map_indices([i, j])
return lane_id
def forward_index(i: int, j: int) -> int:
"""
            Given the row index `i` and column index `j` in the fragment,
            map them to the local index within the owning lane according to
            `inverse_mma_load_layout`.
"""
_, local_id = inverse_mma_load_layout.map_indices([i, j])
return local_id
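        # For matrix A the base fragment spans micro_size_r // 2 elements along
        # the reduction axis because the operand is stored compressed (2:4
        # sparsity keeps half of the K values); B stays dense and uses the full
        # micro_size_r.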
base_fragment = T.Fragment(
[micro_size_s, micro_size_r // 2 if matrix_is_a else micro_size_r] if is_sr_axis_order
else [micro_size_r // 2 if matrix_is_a else micro_size_r, micro_size_s],
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
warp_rows, warp_cols = self.warp_rows, self.warp_cols
chunk = self.warp_k
warp_s = warp_rows if matrix_is_a else warp_cols
warp_r = chunk // micro_size_r
block_s = block_row_warps if matrix_is_a else block_col_warps
replicate = block_col_warps if matrix_is_a else block_row_warps
if is_sr_axis_order:
warp_fragment = base_fragment.repeat([warp_s, warp_r],
repeat_on_thread=False,
lower_dim_first=False)
if matrix_is_a:
block_fragment = warp_fragment.repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
else:
warp_fragment = base_fragment.repeat([warp_r, warp_s],
repeat_on_thread=False,
lower_dim_first=True)
if matrix_is_a:
block_fragment = warp_fragment.repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
return block_fragment
def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
map fragment indices to threads and local indices.
Parameters
----------
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
shape = local_buf.shape
inverse_mma_store_layout = self.get_store_index_map(inverse=True)
assert is_fragment(local_buf), "local_buf must be a fragment"
micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
local_size_out = self.local_size_out
block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
warp_rows, warp_cols = self.warp_rows, self.warp_cols
warp_size = self.WARP_SIZE
is_m_first = self.is_m_first
def forward_thread(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a thread index according to `inverse_mma_store_layout`.
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
            # the upper bounds of block_i and block_j are block_row_warps and block_col_warps
block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y
lane_id, _ = inverse_mma_store_layout.map_indices([mma_i, mma_j])
if is_m_first:
                thread_id = block_i * (block_col_warps * warp_size) + block_j * warp_size + lane_id
else:
thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id
return thread_id
def forward_index(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a local index in a single thread according
to `inverse_mma_store_layout`.
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y
_, local_id = inverse_mma_store_layout.map_indices([mma_i, mma_j])
return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id
return T.Fragment(
shape,
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
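

def _example_sparse_emitter() -> SparseTensorCoreIntrinEmitter:
    # Hedged usage sketch (not part of the original change): one plausible
    # configuration for fp16 2:4 sparsity with int16 metadata. Tile sizes are
    # illustrative, not tuned; see the gemm_sp_v2 tests for end-to-end usage.
    return SparseTensorCoreIntrinEmitter(
        a_dtype="float16",
        e_dtype="int16",
        b_dtype="float16",
        accum_dtype="float32",
        a_transposed=False,
        b_transposed=True,
        block_row_warps=2,
        block_col_warps=2,
        warp_row_tiles=64,
        warp_col_tiles=64,
        warp_k=32,
    )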