# Sparse Matrix-Matrix Multiplication with Tile Library

<div style="text-align: left;">
    <em>Author:</em> <a href="https://github.com/botbw">botbw</a>
</div>

:::{warning}
   This document is still **experimental** and may be incomplete.  

   This feature is still **experimental** and needs further optimization.

   Suggestions and improvements are highly encouraged—please submit a PR!
:::

:::{tip}
It is recommended to read `docs/deeplearning_operators/matmul.md` first.

Example code can be found at `examples/gemm_sp`.
:::

## Structured sparsity in the NVIDIA Ampere architecture

Since the Ampere architecture (sm80 and above), sparsity support has been integrated into Tensor Cores. This allows a 2:4 (or 1:2 for 32-bit data types) semi-structured matrix to be compressed into its non-zero values along with associated metadata, which can then be fed into the Tensor Core. This enables up to **2x throughput** compared to the equivalent dense computation.
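
To make the 2:4 format concrete, here is a minimal illustration (with hypothetical values) of how one group of four consecutive elements is stored as two values plus two 2-bit indices:

```python
# One group of four consecutive elements along the K dimension:
dense_group = [0.0, 7.0, 0.0, 3.0]

# 2:4 compression keeps at most two non-zeros per group (halving the storage)
# and records their positions as 2-bit indices in the metadata operand:
nonzero_values = [7.0, 3.0]   # fed to the sparse Tensor Core
nonzero_indices = [1, 3]      # packed into the metadata
```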

:::{warning}
   This tutorial primarily focuses on CUDA, as this feature is not yet supported on ROCm. However, AMD provides a similar capability in the matrix cores of GPUs such as the MI300X.
:::

```{figure} ../_static/img/sparse_mma_storage_example.png
:align: center

Figure: Sparse MMA storage example (from PTX doc)
```

## Compress a dense tensor

To utilize sparse Tensor Cores, a dense tensor must first be **compressed** into its non-zero values along with the corresponding metadata.

Both `PyTorch` and `vLLM` use `CUTLASS` as their computation backend (see references [here](https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu#L47) and [here](https://github.com/vllm-project/vllm/blob/a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh#L116)), leveraging `CUTLASS`’s built-in compressor (or reimplementing it in `PyTorch`).

A set of **CUTLASS-compatible** compressors is provided in `tilelang.utils.sparse`. Pass a dense tensor, along with the other required arguments (e.g., `block_k` for sm90 and the transpose options), to perform the compression.

```python 
from tilelang.utils.sparse import compress
A_sparse, E = compress(A, transposed=trans_A, block_k=block_K)
```

Here, `A_sparse` contains all the non-zero elements of `A`, while `E` stores the corresponding metadata (indexing information) required to reconstruct the original sparse pattern.

> NOTE: When using the CUTLASS compressor, there is no naive positional correspondence between `A_sparse`/`A` and `E` (i.e., the 4-element group at `[n, k]` does not correspond to the 4-bit metadata at `[n, k]` if you view the metadata as an int4 tensor).
> The metadata is reordered internally to optimize memory access patterns (e.g., for `ldsm` instructions and vectorized loads).
> For more information, see **A note on `gemm_sp` and `gemm_sp_v2`**.
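
As a quick sanity check of the shapes involved, here is a minimal sketch. The dimensions, the pruning step, and the assumption that `compress` expects an input that is already 2:4-sparse along `K` are illustrative; only the `transposed`/`block_k` arguments come from the call shown above.

```python
import torch

from tilelang.utils.sparse import compress

M, K, block_K = 256, 128, 64
A = torch.randn(M, K, dtype=torch.float16, device="cuda")

# Prune A to a 2:4 pattern (keep the two largest magnitudes in every group of
# four along K) so that it can be compressed losslessly.
groups = A.view(M, K // 4, 4)
_, topk_idx = groups.abs().topk(2, dim=-1)
mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, topk_idx, True)
A = (groups * mask).view(M, K)

A_sparse, E = compress(A, transposed=False, block_k=block_K)
# A_sparse holds the non-zero halves: shape (M, K // 2).
# E holds the packed 2-bit indices: shape (M, K // E_factor), where E_factor is
# 16 for fp16 with int16 metadata (one 16-bit word covers four 4-bit groups).
```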


## `T.gemm_sp` with CUTLASS's compressor

:::{warning}

It is strongly recommended to use `T.gemm_sp_v2` due to its greater flexibility and faster compilation time.

:::

A 2:4 sparse GEMM kernel is similar to its dense counterpart, except that it also requires handling the associated metadata.

See the comments in the kernel below for the required modifications.

```python
def matmul_sp_sm80(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
    trans_A,
    trans_B,
):
    is_8_bit = "8" in in_dtype
    metadata_dtype = 'int32' if is_8_bit else 'int16'
    E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype]  # Number of K elements covered by one metadata element
    A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
    B_shape = (K, N) if not trans_B else (N, K)
    A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
    B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)

    import tilelang.language as T

    @T.prim_func
    def main(
            A_sparse: T.Tensor(A_sparse_shape, in_dtype),
            E: T.Tensor((M, K // E_factor), metadata_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)  # Allocate smem for metadata
            C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.annotate_layout({  # Annotate reordered cutlass metadata layout
                E:
                    make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
                E_shared:
                    make_cutlass_metadata_layout(
                        E_shared, mma_dtype=in_dtype, arch="8.0"),
            })
            T.clear(C_frag)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
                if trans_A:
                    T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
                else:
                    T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)  # Call gemm_sp with non-zero values and metadata
            T.copy(C_frag, C[by * block_M, bx * block_N])

    return main
```

Under the hood, `gemm_sp` invokes templates adapted from `CUTLASS`, and a compatible metadata layout must be specified using `T.annotate_layout`.

## `T.gemm_sp_v2` with a custom compressor

To migrate to `gemm_sp_v2`, simply replace each call to `gemm_sp` with `gemm_sp_v2`.

Unlike `gemm_sp`, `gemm_sp_v2` can operate without `T.annotate_layout`, and it also supports user-defined layouts and compressors.
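
As a minimal sketch of the migration based on the sm80 kernel above (assuming the argument list carries over unchanged): the `T.annotate_layout({...})` call becomes optional and is only needed when the metadata was produced by a CUTLASS-compatible compressor, and the GEMM call site changes as follows.

```python
# Before: lowers to a CUTLASS template; reordered metadata layout required.
T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)

# After: lowers directly to PTX; the metadata layout annotation is optional.
T.gemm_sp_v2(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
```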

The metadata is stored in a `(u)int8`/`(u)int16`/`(u)int32` tensor, where **each 4-bit chunk represents two 2-bit indices** of non-zero elements within four consecutive elements. Here, we start with an `int16` example, which is the **default metadata dtype** for `bf16` and `fp16` on Ampere GPUs.

Suppose we have the following row vector:
```python
t = tensor([[0, 7, 0, 3], [1, 5, 0, 0], [0, 0, 2, 4], [9, 0, 9, 0]], dtype=torch.float16).flatten()
```

The non-zero elements and their corresponding indices are:

```python
t_sp = tensor([[7, 3], [1, 5], [2, 4], [9, 9]], dtype=torch.float16).flatten()
indices = tensor([[1, 3], [0, 1], [2, 3], [0, 2]]).flatten()
```

The corresponding int16 metadata is:
```python
# Per-group 4-bit metadata (low 2 bits = first index, high 2 bits = second index):
# metadata_bits = tensor([0b1101, 0b0100, 0b1110, 0b1000])
# The four groups are packed starting from the least significant bits, giving
# 0b1000_1110_0100_1101 = 36429. This overflows a signed int16, so we write the
# two's-complement value instead:
metadata_int16 = tensor(-29107, dtype=torch.int16)
```

You can decode an int16 metadata tensor using the following utility:
```python
import torch


def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
    """Unpack a 2-D int16 metadata tensor into 2-bit indices (two per 4-bit group)."""
    assert meta.dtype is torch.int16
    groups_per_meta = 16 // 4  # each int16 word holds four 4-bit groups
    out = []
    for g in range(groups_per_meta):
        group_bits = (meta >> (g * 4)) & 0xF  # groups are packed from the LSB up
        idx0 = group_bits & 0x3         # low 2 bits: index of the first non-zero
        idx1 = (group_bits >> 2) & 0x3  # high 2 bits: index of the second non-zero
        out.append(torch.stack([idx0, idx1], dim=-1))
    return torch.concat(out, dim=-1).view(meta.shape[0], -1)
```
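
For example, decoding the single metadata word from the example above recovers the indices:

```python
meta = torch.tensor([[-29107]], dtype=torch.int16)
print(decode_metadata(meta))
# tensor([[1, 3, 0, 1, 2, 3, 0, 2]], dtype=torch.int16)
```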

The compressor can be implemented either at the `PyTorch`/`NumPy` level or at the kernel level.

For example, `PyTorch` provides an Ampere compressor [here](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L47-L179). Note that in this implementation, a [permutation](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L173-L175) is applied to match CUTLASS’s metadata layout. If you do not annotate a metadata layout when using `gemm_sp_v2`, your compressor should replicate the same behavior as the PyTorch example—but without using the `_calculate_meta_reordering_scatter_offsets` function.

If you want to use a custom metadata layout in your kernel, one approach is to define the layout in `TileLang` and then apply the same layout to both your compressor kernel and the `matmul_sp` kernel.

```python

@tilelang.jit(out_idx=[1, 2], pass_configs={
    tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True,
})
def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
    e_factor, e_dtype = ARCH_INFO["8.0"]
    e_K = K // e_factor
    elem, group = 2, 4

    assert M % block_M == 0, "M must be divisible by block_M"
    assert K % block_K == 0, "K must be divisible by block_K"
    assert K % e_factor == 0, "K must be divisible by e_factor"
    assert block_K % e_factor == 0, "block_K must be divisible by e_factor"

    @T.prim_func
    def kernel(
        A: T.Tensor((M, K), dtype),
        A_sp: T.Tensor((M, K // 2), dtype),
        E: T.Tensor((M, e_K), e_dtype),
    ):
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype)
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            if use_cutlass_layout:  # NOTE: the compressor's metadata layout must
                T.annotate_layout({ # match the one used in your computation kernel
                    E:
                        make_cutlass_metadata_layout(
                            E, mma_dtype="float16", arch="8.0", block_k=block_K),
                    E_shared:
                        make_cutlass_metadata_layout(
                            E_shared,
                            mma_dtype="float16",
                            arch="8.0",
                            block_k=block_K),
                })
            T.clear(A_sp_shared)
            T.clear(E_shared)
            non_zero_cnt = T.alloc_local((1, ), dtype="uint8")
            non_zero_elt_log_idx = T.alloc_local((elem, ), dtype="uint8")
            T.copy(A[bx * block_M, by * block_K], A_shared)
            for tm in T.Parallel(block_M):
                for g_i in range(0, block_K // group):
                    a_k = g_i * group
                    T.clear(non_zero_cnt)
                    T.clear(non_zero_elt_log_idx)
                    for i in range(group):
                        val = A_shared[tm, a_k + i]
                        if val != 0.0:
                            non_zero_elt_log_idx[non_zero_cnt[0]] = i
                            A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val
                            non_zero_cnt[0] += 1
                    if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3:
                        non_zero_elt_log_idx[0] = 0
                        non_zero_elt_log_idx[1] = 3
                        A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2]
                        A_sp_shared[tm, a_k // 2] = 0.0
                    elif non_zero_cnt[0] == 1:
                        A_sp_shared[tm, a_k // 2 + 1] = 0
                        non_zero_elt_log_idx[1] = 3
                    for i in T.serial(elem):
                        val = non_zero_elt_log_idx[i]
                        E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i)
            T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2])
            T.copy(E_shared, E[bx * block_M, by * block_K // e_factor])

    return kernel
```

## A note on `gemm_sp` and `gemm_sp_v2`

Initially, `T.gemm_sp` followed the same design as `T.gemm`, lowering to a `CUTLASS` template. This inherently requires metadata to be reordered offline following a predetermined layout.

However, fixing a specific layout introduces several potential issues:

1. Painful debugging experience: Debugging a failed kernel becomes difficult due to the reordered indexing, including permutations and swizzling.

2. Limited flexibility: For example, concatenating two compressed tensors, such as `A_sparse_0` and `A_sparse_1`, into a new `A_sparse` makes sense. However, concatenating their metadata `E_0` and `E_1` may not be valid unless the layout allows it mathematically.

3. Alignment requirements: `CUTLASS` enforces strict alignment checks, and many hyperparameter configurations can lead to compilation errors. (For reference, sm8x was implemented in `CUTLASS 2`.)

`T.gemm_sp_v2` was designed to address these limitations, following the approach of `T.gemm_v2`. It lowers directly to PTX, removing the need for a fixed metadata layout.