# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the CUTLASS 2:4 structured-sparsity kernels.

Run `pytest tests/kernels/quantization/test_cutlass_2of4_sparse.py`.
"""

import pytest
import torch

from tests.kernels.utils import baseline_scaled_mm, to_fp8, to_int8
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    sparse_cutlass_supported,
)
from vllm.platforms import current_platform

# Devices to parametrize over: a single device when exactly one accelerator
# is visible, otherwise "cuda:0" and "cuda:1".
# NOTE(review): not referenced by the tests in this file, and a device count
# of 0 still yields two entries — confirm whether this is still needed.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]

# Compute capability flattened to major * 10 + minor (e.g. (8, 9) -> 89).
# get_device_capability() may return None when no CUDA device is present
# (TODO confirm for every platform); treat that as 0 rather than failing with
# a TypeError while pytest imports this module.
capability = current_platform.get_device_capability()
capability = 0 if capability is None else capability[0] * 10 + capability[1]


def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` cast to ``torch.bfloat16``."""
    # Tensor.bfloat16() is documented as equivalent to .to(torch.bfloat16).
    return tensor.bfloat16()


def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` cast to ``torch.float16``."""
    # Tensor.half() is documented as equivalent to .to(torch.float16).
    return tensor.half()


def prune_to_2_4(tensor):
    """Apply 2:4 structured sparsity along the last dimension of ``tensor``.

    Each consecutive group of 4 elements along the last dimension keeps its
    2 largest-magnitude entries; the other 2 are zeroed. Any ``-0.0`` created
    by the masking is normalized to ``+0.0``.

    Args:
        tensor: Tensor whose last dimension is a multiple of 4.

    Returns:
        A new tensor with the same shape and dtype as ``tensor``.

    Raises:
        ValueError: if ``tensor`` is 0-d or its last dimension is not a
            multiple of 4. Otherwise ``reshape(-1, 4)`` could still succeed
            but form groups that straddle rows.
    """
    if tensor.dim() == 0 or tensor.shape[-1] % 4 != 0:
        raise ValueError(
            "prune_to_2_4 expects the last dimension to be a multiple of 4, "
            f"got shape {tuple(tensor.shape)}"
        )

    original_shape = tensor.shape
    # Groups of 4 in logical row-major order. Because the last dim is a
    # multiple of 4, every group lies entirely within one row.
    groups = tensor.reshape(-1, 4)

    # Indices of the 2 largest-magnitude entries in each group.
    _, keep = torch.topk(torch.abs(groups), k=2, dim=1)

    # Binary mask with ones at the kept positions.
    mask = torch.zeros_like(groups)
    mask.scatter_(dim=1, index=keep, src=torch.ones_like(keep, dtype=mask.dtype))

    pruned = groups * mask

    # Multiplying a negative value by 0 yields -0.0; normalize it to +0.0.
    # (-0.0 == 0.0 is True, so existing +0.0 entries are rewritten harmlessly.)
    pruned[pruned == -0.0] = 0.0

    return pruned.reshape(original_shape)


def check_compress_decompress_invariance(
    dtype: torch.dtype,
    b: torch.Tensor,
    b_compressed: torch.Tensor,
    b_metadata: torch.Tensor,
):
    """Assert that the compressed form of ``b`` decompresses back to ``b``.

    Multiplies an identity matrix by the compressed operand using
    ``ops.cutlass_scaled_sparse_mm`` with unit scales. The product must
    reproduce the original, uncompressed ``b``.

    Args:
        dtype: Element type of ``b``; the identity matrix uses the same type.
        b: The uncompressed, pruned weight. Its first dimension sizes the
            identity matrix.
        b_compressed: Compressed non-zero values of ``b``.
        b_metadata: Sparsity metadata produced alongside ``b_compressed``.
    """
    # For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the
    # same dtype as its inputs. This line addresses that constraint while
    # arbitrarily using bfloat16 for the int8/fp8 cases.
    out_dtype = torch.float16 if dtype is torch.float16 else torch.bfloat16

    # Allocate on b's own device (rather than a hard-coded "cuda") so the
    # check also works when b lives on a non-default GPU.
    eye = torch.eye(b.shape[0], device=b.device, dtype=dtype)
    eye_scale = torch.ones(1, device=b.device, dtype=torch.float32)
    b_decomp = ops.cutlass_scaled_sparse_mm(
        eye, b_compressed, b_metadata, eye_scale, eye_scale, out_dtype=out_dtype
    )

    torch.testing.assert_close(b.to(dtype=out_dtype), b_decomp)


77
def make_rand_sparse_tensors(
78
    dtype: torch.dtype, m: int, n: int, k: int
79
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
80
81
    a = torch.randn((m, k), device="cuda")
    b = torch.randn((n, k), device="cuda").t()
82
83
84
85
86

    if dtype == torch.int8:
        # ensure A and B aren't all zeros after rounding
        a = a * 5.0
        b = b * 5.0
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101

    b = prune_to_2_4(b.t()).t()

    if dtype == torch.int8:
        a, b = to_int8(a), to_int8(b)
    elif dtype == torch.float8_e4m3fn:
        a, b = to_fp8(a), to_fp8(b)
    elif dtype == torch.float16:
        a, b = to_fp16(a), to_fp16(b)
    elif dtype == torch.bfloat16:
        a, b = to_bf16(a), to_bf16(b)
    else:
        raise ValueError("unsupported dtype")

    b_compressed, e = ops.cutlass_sparse_compress(b.t())
102
    check_compress_decompress_invariance(dtype, b, b_compressed, e)
103
104
105
106
107

    # Compressed B, Metadata, Original A, B
    return b_compressed, e, a, b


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse CUTLASS is not supported on this GPU type.",
)
def test_cutlass_sparse_subset():
    """Run the fp8 sparse GEMM on a row subset of a larger A.

    A is generated with ``rows_total`` rows, but only its first ``rows_used``
    rows are passed to the kernel. The result is compared against the dense
    baseline computed on the same slice.
    """
    rows_total = 1024
    rows_used, n, k = 512, 512, 512
    out_dtype = torch.bfloat16

    b_comp, meta, full_a, b = make_rand_sparse_tensors(
        torch.float8_e4m3fn, rows_total, n, k
    )
    a_slice = full_a[0:rows_used, 0:k]

    # Small random per-tensor scales.
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    result = ops.cutlass_scaled_sparse_mm(
        a_slice, b_comp, meta, scale_a, scale_b, out_dtype=out_dtype
    )
    expected = baseline_scaled_mm(
        a_slice, b, scale_a, scale_b, out_dtype=out_dtype
    )

    torch.testing.assert_close(result, expected, rtol=1e-1, atol=1e0)


# GEMM problem sizes shared by the parametrized tests below. Per the name,
# each tuple is (m, n, k): m = rows of A and of the output, n = output
# columns, k = reduction dimension. The m values include non-power-of-two
# sizes (33, 100), presumably to exercise unaligned row counts.
# NOTE(review): verify that each test's parametrize string unpacks these
# tuples in (m, n, k) order.
MNK_FACTORS = [
    (1, 256, 128),
    (1, 16384, 1024),
    (1, 24576, 512),
    (16, 256, 512),
    (16, 16384, 128),
    (16, 24576, 4096),
    (32, 8192, 4096),
    (32, 16384, 4096),
    (33, 1024, 1024),
    (33, 8192, 128),
    (64, 2048, 512),
    (64, 16384, 1024),
    (100, 8192, 512),
    (128, 32768, 4096),
    (256, 4096, 4096),
    (512, 256, 1024),
    (512, 8192, 4096),
    (512, 16384, 128),
    (512, 24576, 128),
]


# Sparse GEMM for 16-bit float inputs (bf16 / fp16), with and without bias.
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse CUTLASS is not supported on this GPU type.",
)
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_gemm(
    m: int, k: int, n: int, dtype: torch.dtype, use_bias: bool
):
    """Compare the fp16/bf16 sparse GEMM against the dense baseline.

    Scales are fixed at 1, so this isolates the sparse matmul and the
    optional bias from any scaling. The output dtype equals the input dtype.
    """
    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
    scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
    scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)

    bias = torch.rand((n,), device="cuda", dtype=dtype) if use_bias else None

    out = ops.cutlass_scaled_sparse_mm(
        a, b_comp, e, scale_a, scale_b, out_dtype=dtype, bias=bias
    )

    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=dtype, bias=bias)

    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse CUTLASS is not supported on this GPU type.",
)
@pytest.mark.skipif(
    not current_platform.has_device_capability(89),
    reason="FP8 is not supported on this GPU type.",
)
# MNK_FACTORS tuples are (m, n, k); unpack them in that order, matching
# test_cutlass_sparse_gemm and the list's name.
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool):
    """Compare the fp8 (e4m3) sparse GEMM with bf16 output to the dense baseline.

    Uses random per-tensor ``(1, 1)`` scales and, optionally, a bias with
    values in ``[0, 10)``.
    """
    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)

    out_dtype = torch.bfloat16
    bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None

    out = ops.cutlass_scaled_sparse_mm(
        a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
    )
    baseline = baseline_scaled_mm(
        a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
    )

    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse CUTLASS is not supported on this GPU type.",
)
# MNK_FACTORS tuples are (m, n, k); unpack them in that order, matching
# test_cutlass_sparse_gemm and the list's name.
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_int8_gemm(
    m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool, use_bias: bool
):
    """Compare the int8 sparse GEMM with bf16 output to the dense baseline.

    Uses random per-tensor ``(1, 1)`` scales and, optionally, a bias with
    values in ``[0, 10)``.

    NOTE(review): ``per_act_token`` and ``per_out_ch`` are parametrized but
    never read, so every combination runs the same per-tensor-scale test.
    Either build per-token / per-channel scales from them (after confirming
    the sparse kernel supports those scale shapes) or drop them.
    """
    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)

    out_dtype = torch.bfloat16
    bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None

    out = ops.cutlass_scaled_sparse_mm(
        a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
    )

    baseline = baseline_scaled_mm(
        a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
    )

    # Loose tolerances, presumably to absorb bf16 rounding of the large
    # int8 accumulations.
    torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)