"vscode:/vscode.git/clone" did not exist on "afb390ab027918430272727da54e7c3221d28e3e"
test_compressed_tensors.py 13.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Test model set-up and weight loading for llmcompressor-quantized models.
3
4
5

Run `pytest tests/quantization/test_compressed_tensors.py`.
"""
6
from typing import Optional
7

8
import pytest
9
import torch
10
from compressed_tensors.quantization import QuantizationType
11

12
from tests.models.utils import check_logprobs_close
13
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
14
15
16
17
    CompressedTensors24, CompressedTensorsLinearMethod,
    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
    CompressedTensorsWNA16)
18
19
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    sparse_cutlass_supported)
20
from vllm.platforms import current_platform
21
22


23
24
25
26
27
28
29
30
@pytest.mark.parametrize(
    "model_args",
    [("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
      QuantizationType.INT, 2560, True),
     ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
      QuantizationType.INT, 2560, True),
     ("nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", "tensor",
      QuantizationType.INT, 2560, False)])
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
    """Check the set-up of the first decoder layer for static W8A8 models.

    ``model_args`` unpacks to ``(model_path, strategy, quant_type, shape_0,
    is_symmetric)``. ``quant_type`` is unpacked but not used by the checks.
    After inspecting the layer, a short greedy generation is run and must
    return a truthy result.
    """
    model_path, strategy, quant_type, shape_0, is_symmetric = model_args
    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            first_layer = model.model.layers[0]
            attention, mlp = first_layer.self_attn, first_layer.mlp

            qkv_proj = attention.qkv_proj
            o_proj = attention.o_proj
            gate_up_proj = mlp.gate_up_proj
            down_proj = mlp.down_proj
            all_projs = (qkv_proj, o_proj, gate_up_proj, down_proj)

            def zp_valid(zp: Optional[torch.Tensor]):
                # Asymmetric case: a zero point must exist and be int32.
                if not is_symmetric:
                    return zp is not None and zp.dtype is torch.int32
                # Symmetric case: no zero point is expected at all.
                return zp is None

            for proj in all_projs:
                assert zp_valid(proj.input_zero_point)

            for proj in all_projs:
                assert isinstance(proj.quant_method,
                                  CompressedTensorsLinearMethod)

            # Scheme details are only inspected on qkv_proj.
            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
            assert qkv_proj.scheme.strategy == strategy
            assert qkv_proj.scheme.is_static_input_scheme

            # NOTE: down_proj's weight dtype is not part of this check.
            expected_type = torch.int8
            for proj in (qkv_proj, o_proj, gate_up_proj):
                assert proj.weight.dtype is expected_type

            if qkv_proj.scheme.strategy == "tensor":
                # The per-tensor scale is expected to be a channelwise
                # (shape_0, 1) buffer once process_weights_after_loading
                # has run.
                assert len(qkv_proj.weight_scale.shape) == 2
                assert qkv_proj.weight_scale.shape[0] == shape_0
                assert qkv_proj.weight_scale.shape[1] == 1
            assert qkv_proj.weight_scale.dtype is torch.float32
            assert qkv_proj.input_scale.dtype is torch.float32

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output

87

88
89
90
91
92
93
@pytest.mark.parametrize("model_path", [
    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
    "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym",
    "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
    "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
def test_compressed_tensors_w8a8_logprobs(hf_runner, vllm_runner,
                                          example_prompts, model_path,
                                          max_tokens, num_logprobs):
    """Compare greedy logprobs from the HF runner and the vLLM runner.

    Both runners load ``model_path`` with dtype ``bfloat16`` and generate
    for the same prompts; the two result lists are then handed to
    ``check_logprobs_close``.
    """
    dtype = "bfloat16"

    # The static per-tensor asym model skips the language translation
    # prompt, i.e. the last entry of ``example_prompts`` is dropped.
    asym_static_model = (
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym")
    if model_path == asym_static_model:
        example_prompts = example_prompts[:-1]

    # Reference outputs from the HF runner.
    with hf_runner(model_path, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

    # Outputs under test from the vLLM runner.
    with vllm_runner(model_path, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(outputs_0_lst=hf_outputs,
                         outputs_1_lst=vllm_outputs,
                         name_0="hf",
                         name_1="vllm")


121
def test_compressed_tensors_no_enforce_eager(vllm_runner):
    """Run a short greedy generation without passing ``enforce_eager``.

    The model is loaded with the runner's defaults and the generation
    result must be truthy.
    """
    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
    with vllm_runner(model_path) as llm:
        # Pass the prompt as a one-element list. A bare str is itself a
        # sequence of characters, so a helper that iterates over its
        # prompts would presumably see one prompt per character
        # (TODO confirm against the runner).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output


128
129
@pytest.mark.parametrize("model_args", [
    ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
    ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"),
    ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"),
    ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym",
     "channel"),
])
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
    """Check qkv_proj set-up for W8A8 INT8 models without static input scales.

    ``model_args`` unpacks to ``(model_path, strategy)``. The first layer's
    ``qkv_proj`` must use the INT8 W8A8 scheme with the expected strategy and
    a non-static input scheme; a short greedy generation must then succeed.
    """
    model_path, strategy = model_args
    with vllm_runner(model_path, dtype=torch.float16) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)

            # Dynamic quantization: the input scheme must not be static.
            assert not qkv_proj.scheme.is_static_input_scheme
            assert qkv_proj.scheme.strategy == strategy
            assert qkv_proj.weight.dtype is torch.int8

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output

156

157
158
159
160
161
@pytest.mark.parametrize(
    "wNa16_args",
    [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
     ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
     ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4)])
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
    """Check qkv_proj set-up for weight-only (WNA16) quantized models.

    ``wNa16_args`` unpacks to ``(model_path, strategy, group, pack_factor)``.
    A ``group`` of ``None`` is expected to show up on the scheme as a
    ``group_size`` of ``-1``.
    """
    # Named ``model_path`` so it is not shadowed by ``check_model``'s
    # ``model`` parameter.
    model_path, strategy, group, pack_factor = wNa16_args
    expected_group_size = -1 if group is None else group

    with vllm_runner(model_path) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)

            assert qkv_proj.scheme.strategy == strategy
            assert qkv_proj.scheme.group_size == expected_group_size

            assert qkv_proj.weight_packed.dtype is torch.int32
            assert qkv_proj.weight_scale.dtype is torch.float16
            assert qkv_proj.scheme.pack_factor == pack_factor

        llm.apply_model(check_model)

        # Pass the prompt as a one-element list. A bare str is itself a
        # sequence of characters, so a helper that iterates over its
        # prompts would presumably see one prompt per character
        # (TODO confirm against the runner).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output

187
188
189
190
191

def test_compressed_tensors_w4a16_marlin24(vllm_runner):
    """Check qkv_proj set-up for a 2:4 sparse W4A16 checkpoint.

    The first layer's ``qkv_proj`` must use the
    ``CompressedTensorsW4A16Sparse24`` scheme with int32 packed weights,
    and a short greedy generation must return a truthy result.
    """
    model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
    with vllm_runner(model_path) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
            assert qkv_proj.weight_packed.dtype is torch.int32

        llm.apply_model(check_model)

        # Pass the prompt as a one-element list. A bare str is itself a
        # sequence of characters, so a helper that iterates over its
        # prompts would presumably see one prompt per character
        # (TODO confirm against the runner).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output
206
207
208
209
210
211


def test_compressed_tensors_fp8(vllm_runner):
    """Check qkv_proj set-up for an FP8 compressed-tensors checkpoint.

    Either the W8A8 or the W8A16 FP8 scheme is accepted. The weight dtype
    and the scalar (0-dim) scale shapes are only checked when the W8A8
    scheme was selected.
    """
    model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
    with vllm_runner(model_path) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(
                qkv_proj.scheme,
                (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8))

            assert qkv_proj.input_scale.dtype is torch.float32

            if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
                # Both scales are expected to be 0-dim tensors here.
                assert len(qkv_proj.input_scale.shape) == 0
                assert qkv_proj.weight.dtype is torch.float8_e4m3fn
                assert qkv_proj.weight_scale.dtype is torch.float32
                assert len(qkv_proj.weight_scale.shape) == 0

        llm.apply_model(check_model)

        # Pass the prompt as a one-element list. A bare str is itself a
        # sequence of characters, so a helper that iterates over its
        # prompts would presumably see one prompt per character
        # (TODO confirm against the runner).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output
235
236
237
238
239
240


def test_compressed_tensors_kv_cache(vllm_runner):
    """Run a short greedy generation with ``kv_cache_dtype="fp8"``.

    The checkpoint is loaded with an fp8 KV cache dtype and the generation
    result must be truthy.
    """
    model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
    with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
        # Pass the prompt as a one-element list. A bare str is itself a
        # sequence of characters, so a helper that iterates over its
        # prompts would presumably see one prompt per character
        # (TODO confirm against the runner).
        output = llm.generate_greedy(["Hello world!"], max_tokens=20)
        assert output
242
243


244
def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
    """Assert that ``qkv_proj`` uses a quantized 2:4 sparse scheme.

    This is a shared assertion helper, not a test: its name has no ``test``
    prefix, so pytest does not collect it and a ``skipif`` mark on it would
    never take effect. Skipping belongs on the tests that call it.

    Args:
        qkv_proj: The layer to inspect; its ``quant_method`` and ``scheme``
            attributes are read.
        weight_strategy: Expected ``scheme.weight_quant.strategy``.
        input_strategy: Expected ``scheme.input_quant.strategy``.
    """
    assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
    assert isinstance(qkv_proj.scheme, CompressedTensors24)

    assert qkv_proj.scheme.weight_quant.strategy == weight_strategy
    assert qkv_proj.scheme.input_quant.strategy == input_strategy
    assert qkv_proj.scheme.quantized

    # The sparsity map must be non-empty and describe "Linear" modules as
    # dense-format 2:4 sparse.
    assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
    sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
    assert sparsity_map.get("Linear").format == "dense"
    assert sparsity_map.get("Linear").sparsity_structure == "2:4"


@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel",
     "token"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
     "channel", "tensor"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", "tensor",
     "tensor"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
     "tensor", "token"),
])
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
    """Check qkv_proj set-up for 2:4 sparse FP8-quantized checkpoints.

    ``args_2of4`` unpacks to ``(model_path, weight_strategy,
    input_strategy)``. Skipped unless the current platform reports device
    capability 90.
    """
    # Named ``model_path`` so it is not shadowed by ``check_model``'s
    # ``model`` parameter.
    model_path, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model_path) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
            _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)

        llm.apply_model(check_model)

        # Pass the prompt as a one-element list. A bare str is itself a
        # sequence of characters, so a helper that iterates over its
        # prompts would presumably see one prompt per character
        # (TODO confirm against the runner).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        print(output)
        assert output


289
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse INT8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
     "channel", "token"),
    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing", "tensor",
     "tensor"),
    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
     "tensor", "token"),
])
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
    """Check qkv_proj set-up for 2:4 sparse INT8-quantized checkpoints.

    ``args_2of4`` unpacks to ``(model_path, weight_strategy,
    input_strategy)``. Skipped when ``sparse_cutlass_supported()`` is false;
    the skip reason names INT8 because that is what this test covers.
    """
    # Named ``model_path`` so it is not shadowed by ``check_model``'s
    # ``model`` parameter.
    model_path, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model_path) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.int8
            _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)

        llm.apply_model(check_model)

        # Pass the prompt as a one-element list. A bare str is itself a
        # sequence of characters, so a helper that iterates over its
        # prompts would presumably see one prompt per character
        # (TODO confirm against the runner).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        print(output)
        assert output


317
@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="2of4 Sparse is not yet supported on this GPU type.")
@pytest.mark.parametrize(
    "args_2of4",
    ["nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor"])
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
    """Check qkv_proj set-up for a 2:4 sparse checkpoint without quantization.

    ``args_2of4`` is the model path itself (a plain string, not a tuple).
    The test is currently skipped unconditionally by the ``skip`` mark.
    """
    # Named ``model_path`` so it is not shadowed by ``check_model``'s
    # ``model`` parameter.
    model_path = args_2of4
    with vllm_runner(model_path) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensors24)

            # Sparse-only scheme: no weight or input quantization attached.
            assert qkv_proj.scheme.weight_quant is None
            assert qkv_proj.scheme.input_quant is None
            assert not qkv_proj.scheme.quantized
            assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
            assert sparsity_map.get("Linear").format == "dense"
            assert sparsity_map.get("Linear").sparsity_structure == "2:4"

        llm.apply_model(check_model)

        # Pass the prompt as a one-element list. A bare str is itself a
        # sequence of characters, so a helper that iterates over its
        # prompts would presumably see one prompt per character
        # (TODO confirm against the runner).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        print(output)
        assert output