test_compressed_tensors.py 14.8 KB
Newer Older
1
"""Test model set-up and weight loading for llmcompressor-quantized models.
2
3
4

Run `pytest tests/quantization/test_compressed_tensors.py`.
"""
5
from typing import Optional
6

7
import pytest
8
import torch
9
import os
10

11
from compressed_tensors.quantization import QuantizationType
12

13
from tests.models.utils import check_logprobs_close
14
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
15
16
17
18
    CompressedTensors24, CompressedTensorsLinearMethod,
    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
    CompressedTensorsWNA16)
19
20
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    sparse_cutlass_supported)
zhuwenwen's avatar
zhuwenwen committed
21
from vllm.platforms import current_platform
22
from ..utils import models_path_prefix
23
24


25
26
@pytest.mark.parametrize(
    "model_args",
    [
        (os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"),
         "tensor", QuantizationType.INT, 2560, True),
        (os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor"),
         "channel", QuantizationType.INT, 2560, True),
        (os.path.join(models_path_prefix, "nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama"),
         "tensor", QuantizationType.INT, 2560, False),
    ])
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
    """Verify layer set-up for statically quantized W8A8 int8 checkpoints.

    Each ``model_args`` entry is
    ``(model_path, strategy, quant_type, shape_0, is_symmetric)``.
    The first decoder layer's linear projections are inspected for the
    expected quantization method, scheme, dtypes and zero points, and a
    short greedy generation must return a truthy result.
    """
    # The quantization type is part of the parametrization but is not
    # consulted by any of the checks below.
    model_path, strategy, _quant_type, shape_0, is_symmetric = model_args

    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            first_layer = model.model.layers[0]
            attn = first_layer.self_attn
            mlp = first_layer.mlp

            qkv_proj = attn.qkv_proj
            projections = (
                qkv_proj,
                attn.o_proj,
                mlp.gate_up_proj,
                mlp.down_proj,
            )

            def zp_valid(zp: Optional[torch.Tensor]):
                # Asymmetric schemes must carry an int32 zero point;
                # symmetric schemes must not carry one at all.
                if not is_symmetric:
                    return zp is not None and zp.dtype is torch.int32
                return zp is None

            for proj in projections:
                assert zp_valid(proj.input_zero_point)

            for proj in projections:
                assert isinstance(proj.quant_method,
                                  CompressedTensorsLinearMethod)

            scheme = qkv_proj.scheme
            assert isinstance(scheme, CompressedTensorsW8A8Int8)
            assert scheme.strategy == strategy
            assert scheme.is_static_input_scheme

            # Only qkv, o and gate_up weights are checked (down_proj is
            # intentionally left out, as in the original assertions).
            for proj in projections[:3]:
                assert proj.weight.dtype is torch.int8

            if scheme.strategy == "tensor":
                # Make sure it is a channelwise buffer
                # after running process_weights_after_loading.
                scale_shape = qkv_proj.weight_scale.shape
                assert len(scale_shape) == 2
                assert scale_shape[0] == shape_0
                assert scale_shape[1] == 1

            assert qkv_proj.weight_scale.dtype is torch.float32
            assert qkv_proj.input_scale.dtype is torch.float32

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output

89

90
@pytest.mark.parametrize("model_path", [
    os.path.join(models_path_prefix, "neuralmagic/Llama-3.2-1B-quantized.w8a8"),
    os.path.join(models_path_prefix, "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym"),
    os.path.join(models_path_prefix, "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym"),
    os.path.join(models_path_prefix, "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"),
])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
def test_compressed_tensors_w8a8_logprobs(hf_runner, vllm_runner,
                                          example_prompts, model_path,
                                          max_tokens, num_logprobs):
    """Compare greedy logprobs from the HF runner and the vLLM runner.

    Both runners load ``model_path`` in bfloat16, generate for the same
    prompts, and the two result lists are passed to
    ``check_logprobs_close``.
    """
    dtype = "bfloat16"
    static_asym_path = os.path.join(
        models_path_prefix,
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym")

    # skip language translation prompt for the static per tensor asym model
    # (the last entry of example_prompts is dropped for that checkpoint).
    prompts = (example_prompts[0:-1]
               if model_path == static_asym_path else example_prompts)

    with hf_runner(model_path, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            prompts, max_tokens, num_logprobs)

    with vllm_runner(model_path, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            prompts, max_tokens, num_logprobs)

    check_logprobs_close(outputs_0_lst=hf_outputs,
                         outputs_1_lst=vllm_outputs,
                         name_0="hf",
                         name_1="vllm")


123
def test_compressed_tensors_no_enforce_eager(vllm_runner):
    """Generate with a W8A8 checkpoint without passing ``enforce_eager``.

    The runner is created with its default arguments and a short greedy
    generation must return a truthy result.
    """
    checkpoint = os.path.join(
        models_path_prefix,
        "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change")
    with vllm_runner(checkpoint) as llm:
        generated = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert generated


130
@pytest.mark.parametrize("model_args", [
    (os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"),
     "tensor"),
    (os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym"),
     "tensor"),
    (os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"),
     "channel"),
    (os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym"),
     "channel"),
])
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
    """Verify layer set-up for W8A8 int8 checkpoints with dynamic inputs.

    Each ``model_args`` entry is ``(model_path, strategy)``. The first
    layer's ``qkv_proj`` must use the int8 W8A8 scheme with a non-static
    input scheme and the expected weight strategy.
    """
    model_path, strategy = model_args

    with vllm_runner(model_path, dtype=torch.float16) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)

            scheme = qkv_proj.scheme
            assert isinstance(scheme, CompressedTensorsW8A8Int8)
            assert not scheme.is_static_input_scheme
            assert scheme.strategy == strategy
            assert qkv_proj.weight.dtype is torch.int8

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output

158

zhuwenwen's avatar
zhuwenwen committed
159
# The skip condition must query the platform; calling the platform object
# itself (``current_platform()``) does not express "running on ROCm".
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="WNA16 is not supported on ROCm.")
@pytest.mark.parametrize(
    "wNa16_args",
    [(os.path.join(models_path_prefix,"nm-testing/tinyllama-oneshot-w4a16-channel-v2"), "channel", None, 8),
     (os.path.join(models_path_prefix,"nm-testing/tinyllama-oneshot-w4a16-group128-v2"), "group", 128, 8),
     (os.path.join(models_path_prefix,"nm-testing/tinyllama-oneshot-w8a16-per-channel"), "channel", None, 4)])
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
    """Verify layer set-up for weight-only (WNA16) quantized checkpoints.

    Each ``wNa16_args`` entry is
    ``(model_path, strategy, group, pack_factor)``; ``group`` is ``None``
    when the scheme is expected to report a group size of ``-1``.
    """
    # Named ``model_path`` so it is not shadowed by ``check_model``'s
    # ``model`` parameter.
    model_path, strategy, group, pack_factor = wNa16_args
    with vllm_runner(model_path) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)

            assert qkv_proj.scheme.strategy == strategy
            assert qkv_proj.scheme.group_size == (-1
                                                  if group is None else group)

            assert qkv_proj.weight_packed.dtype is torch.int32
            assert qkv_proj.weight_scale.dtype is torch.float16
            assert qkv_proj.scheme.pack_factor == pack_factor

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output

191

zhuwenwen's avatar
zhuwenwen committed
192
# The skip condition must query the platform; calling the platform object
# itself (``current_platform()``) does not express "running on ROCm".
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="W4A16 MARLIN is not supported on ROCm.")
def test_compressed_tensors_w4a16_marlin24(vllm_runner):
    """Verify layer set-up for a 2:4-sparse W4A16 checkpoint.

    The first layer's ``qkv_proj`` must use the
    ``CompressedTensorsW4A16Sparse24`` scheme with int32 packed weights,
    and a short greedy generation must return a truthy result.
    """
    model_path = os.path.join(models_path_prefix,"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t")
    with vllm_runner(model_path) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
            assert qkv_proj.weight_packed.dtype is torch.int32

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
212
213


zhuwenwen's avatar
zhuwenwen committed
214
# The skip condition must query the platform; calling the platform object
# itself (``current_platform()``) does not express "running on ROCm".
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="FP8 is not supported on ROCm.")
def test_compressed_tensors_fp8(vllm_runner):
    """Verify layer set-up for an FP8 compressed-tensors checkpoint.

    The first layer's ``qkv_proj`` may resolve to either the W8A8 or the
    W8A16 FP8 scheme; the scalar-scale and weight dtype checks apply only
    when the W8A8 scheme was selected.
    """
    model_path = os.path.join(models_path_prefix,"nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test")
    with vllm_runner(model_path) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(
                qkv_proj.scheme,
                (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8))

            assert qkv_proj.input_scale.dtype is torch.float32

            if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
                # Both scales are expected to be 0-dim (scalar) tensors.
                assert len(qkv_proj.input_scale.shape) == 0
                assert qkv_proj.weight.dtype is torch.float8_e4m3fn
                assert qkv_proj.weight_scale.dtype is torch.float32
                assert len(qkv_proj.weight_scale.shape) == 0

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
243
244


zhuwenwen's avatar
zhuwenwen committed
245
# The skip condition must query the platform; calling the platform object
# itself (``current_platform()``) does not express "running on ROCm".
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="FP8 KV cache is not supported on ROCm.")
def test_compressed_tensors_kv_cache(vllm_runner):
    """Generate with a checkpoint that carries a KV-cache scheme.

    The runner is created with ``kv_cache_dtype="fp8"`` and a short greedy
    generation must return a truthy result.
    """
    model_path = os.path.join(models_path_prefix,"nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme")
    with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
        output = llm.generate_greedy("Hello world!", max_tokens=20)
        assert output
252
253


254
def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
    """Assert that ``qkv_proj`` is a quantized 2:4-sparse linear layer.

    This is a shared assertion helper called directly by the 2of4 tests,
    not a test collected by pytest, so it carries no skip marker: a
    ``skipif`` mark on a plain helper has no effect when the helper is
    invoked as a normal function. Skipping is handled by the calling tests.

    Args:
        qkv_proj: The layer under inspection; must expose ``quant_method``
            and ``scheme``.
        weight_strategy: Expected ``scheme.weight_quant.strategy``.
        input_strategy: Expected ``scheme.input_quant.strategy``.
    """
    assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
    assert isinstance(qkv_proj.scheme, CompressedTensors24)

    assert qkv_proj.scheme.weight_quant.strategy == weight_strategy
    assert qkv_proj.scheme.input_quant.strategy == input_strategy
    assert qkv_proj.scheme.quantized

    # The sparsity map must be non-empty and describe "Linear" modules as
    # dense-format 2:4 structured sparsity.
    assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
    sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
    assert sparsity_map.get("Linear").format == "dense"
    assert sparsity_map.get("Linear").sparsity_structure == "2:4"


# NOTE(review): unlike the other tests in this file, these model ids are not
# joined with models_path_prefix -- confirm whether that is intentional.
@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing",
     "channel", "token"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
     "channel", "tensor"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing",
     "tensor", "tensor"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
     "tensor", "token"),
])
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
    """Verify layer set-up for 2:4-sparse checkpoints quantized to FP8.

    Each ``args_2of4`` entry is
    ``(model, weight_strategy, input_strategy)``.
    """
    model_id, weight_strategy, input_strategy = args_2of4

    with vllm_runner(model_id) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
            # Shared 2:4 checks: method/scheme types, strategies, sparsity map.
            _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


299
# NOTE(review): unlike the other tests in this file, these model ids are not
# joined with models_path_prefix -- confirm whether that is intentional.
@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
     "channel", "token"),
    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing",
     "tensor", "tensor"),
    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
     "tensor", "token"),
])
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
    """Verify layer set-up for 2:4-sparse checkpoints quantized to int8.

    Each ``args_2of4`` entry is
    ``(model, weight_strategy, input_strategy)``.
    """
    model_id, weight_strategy, input_strategy = args_2of4

    with vllm_runner(model_id) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.int8
            # Shared 2:4 checks: method/scheme types, strategies, sparsity map.
            _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


327
# NOTE(review): unlike the other tests in this file, this model id is not
# joined with models_path_prefix -- confirm whether that is intentional.
@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="2of4 Sparse is not yet supported on this GPU type.")
@pytest.mark.parametrize(
    "args_2of4",
    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")])
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
    """Verify layer set-up for a 2:4-sparse checkpoint without quantization.

    The first layer's ``qkv_proj`` must use ``CompressedTensors24`` with no
    weight or input quantization configured. The test is currently
    unconditionally skipped via ``pytest.mark.skip``.
    """
    # Each parameter is a bare model id string (the parentheses around it
    # in the parametrize list do not make it a tuple).
    model_id = args_2of4

    with vllm_runner(model_id) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)

            scheme = qkv_proj.scheme
            assert isinstance(scheme, CompressedTensors24)
            assert scheme.weight_quant is None
            assert scheme.input_quant is None
            assert not scheme.quantized

            assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
            assert sparsity_map.get("Linear").format == "dense"
            assert sparsity_map.get("Linear").sparsity_structure == "2:4"

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output