test_compressed_tensors.py 18.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Test model set-up and weight loading for llmcompressor-quantized models.
3
4
5

Run `pytest tests/quantization/test_compressed_tensors.py`.
"""
6

7
from typing import Optional
8

9
import pytest
10
import torch
11
from compressed_tensors.quantization import QuantizationType
12

13
from tests.models.utils import check_logprobs_close
14
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
15
16
17
18
    CompressedTensors24, CompressedTensorsLinearMethod,
    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
    CompressedTensorsWNA16)
19
20
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    sparse_cutlass_supported)
21
from vllm.platforms import current_platform
22
23


24
25
26
27
28
29
30
31
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """Pin the V0 engine for the tests in this module.

    The tests here rely on V0 internals, so ``VLLM_USE_V1`` is set to ``"0"``
    through ``monkeypatch``. The fixture is autouse with function scope, so
    no test has to request it explicitly.
    """
    env_name, env_value = "VLLM_USE_V1", "0"
    monkeypatch.setenv(env_name, env_value)


32
33
@pytest.mark.parametrize(
    "model_args",
    [
        (
            "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
            "tensor",
            QuantizationType.INT,
            2560,
            True,
        ),
        (
            "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
            "channel",
            QuantizationType.INT,
            2560,
            True,
        ),
        (
            "nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama",
            "tensor",
            QuantizationType.INT,
            2560,
            False,
        ),
    ],
)
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
    """Check the set-up of static-input W8A8 INT8 models, then generate.

    ``model_args`` is ``(model_path, strategy, quant_type, shape_0,
    is_symmetric)``. ``quant_type`` is unpacked but not used by any check.
    """
    model_path, strategy, _quant_type, shape_0, is_symmetric = model_args

    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            first_layer = model.model.layers[0]
            attn, mlp = first_layer.self_attn, first_layer.mlp

            qkv_proj = attn.qkv_proj
            projections = (
                qkv_proj,
                attn.o_proj,
                mlp.gate_up_proj,
                mlp.down_proj,
            )

            def zp_valid(zp: Optional[torch.Tensor]):
                # Asymmetric models must carry an int32 zero point;
                # symmetric models must not carry one at all.
                if not is_symmetric:
                    return zp is not None and zp.dtype is torch.int32
                return zp is None

            for proj in projections:
                assert zp_valid(proj.input_zero_point)

            for proj in projections:
                assert isinstance(proj.quant_method,
                                  CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)

            scheme = qkv_proj.scheme
            assert scheme.strategy == strategy
            assert scheme.is_static_input_scheme

            # Only qkv, o and gate_up weights are dtype-checked (not down).
            for proj in projections[:3]:
                assert proj.weight.dtype is torch.int8

            if scheme.strategy == "tensor":
                # The scale is expected to be a channelwise (shape_0, 1)
                # buffer after process_weights_after_loading has run.
                scale_shape = qkv_proj.weight_scale.shape
                assert len(scale_shape) == 2
                assert scale_shape[0] == shape_0
                assert scale_shape[1] == 1
            assert qkv_proj.weight_scale.dtype is torch.float32
            assert qkv_proj.input_scale.dtype is torch.float32

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output

114

115
116
117
118
119
120
121
122
123
@pytest.mark.parametrize(
    "model_path",
    [
        "neuralmagic/Llama-3.2-1B-quantized.w8a8",
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym",
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym",
    ],
)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
def test_compressed_tensors_w8a8_logprobs(
    hf_runner,
    vllm_runner,
    example_prompts,
    model_path,
    max_tokens,
    num_logprobs,
):
    """Compare greedy logprobs from the HF runner and the vLLM runner.

    Both runners load ``model_path`` in bfloat16 and see the same prompts;
    the two result lists are handed to ``check_logprobs_close``.
    """
    dtype = "bfloat16"
    static_asym_model = (
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym")

    # For the static per-tensor asym model the final prompt (the language
    # translation prompt) is skipped.
    prompts = example_prompts
    if model_path == static_asym_model:
        prompts = example_prompts[:-1]

    request = (prompts, max_tokens, num_logprobs)

    with hf_runner(model_path, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(*request)

    with vllm_runner(model_path, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(*request)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )


158
def test_compressed_tensors_no_enforce_eager(vllm_runner):
    """Generate with a W8A8 model without passing ``enforce_eager``."""
    prompt = "Hello my name is"
    with vllm_runner(
            "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
    ) as llm:
        # A non-empty result is all this smoke test requires.
        assert llm.generate_greedy(prompt, max_tokens=20)


165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
@pytest.mark.parametrize(
    "model_args",
    [
        ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
        ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"),
        (
            "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
            "channel",
        ),
        (
            "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym",
            "channel",
        ),
    ],
)
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
    """Check dynamic-input W8A8 INT8 models, then run a short generation.

    ``model_args`` is ``(model_path, strategy)``; ``strategy`` is the value
    expected on the qkv projection's scheme.
    """
    model_path, expected_strategy = model_args

    with vllm_runner(model_path, dtype=torch.float16) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)

            scheme = qkv_proj.scheme
            # Dynamic quantization: the input scheme must not be static.
            assert not scheme.is_static_input_scheme
            assert scheme.strategy == expected_strategy
            assert qkv_proj.weight.dtype is torch.int8

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output

201

202
203
@pytest.mark.parametrize(
    "wNa16_args",
    [
        ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
        ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
        ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
    ],
)
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
    """Check strategy, group size and pack factor of WNA16 models.

    ``wNa16_args`` is ``(model, strategy, group, pack_factor)``. A ``group``
    of ``None`` is compared against a scheme ``group_size`` of ``-1``.
    """
    model_path, strategy, group, pack_factor = wNa16_args
    expected_group_size = group if group is not None else -1

    with vllm_runner(model_path) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)

            scheme = qkv_proj.scheme
            assert scheme.strategy == strategy
            assert scheme.group_size == expected_group_size
            assert scheme.pack_factor == pack_factor

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output

233
234
235
236
237

def test_compressed_tensors_w4a16_marlin24(vllm_runner):
    """Check a W4A16 2:4 model uses the Sparse24 scheme, then generate."""
    with vllm_runner(
            "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t") as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
            # The packed weights are expected to be stored as int32.
            assert qkv_proj.weight_packed.dtype is torch.int32

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
252
253
254
255
256
257


def test_compressed_tensors_fp8(vllm_runner):
    """Check an FP8 compressed-tensors model's scheme, then generate.

    Either FP8 scheme class is accepted for the qkv projection; the weight
    and scale checks only run when the scheme is ``CompressedTensorsW8A8Fp8``.
    """
    fp8_schemes = (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8)

    with vllm_runner(
            "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test") as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, fp8_schemes)

            assert qkv_proj.input_scale.dtype is torch.float32

            if not isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
                return

            # W8A8 FP8: both scales are expected to be 0-dim tensors.
            assert len(qkv_proj.input_scale.shape) == 0
            assert qkv_proj.weight.dtype is torch.float8_e4m3fn
            assert qkv_proj.weight_scale.dtype is torch.float32
            assert len(qkv_proj.weight_scale.shape) == 0

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
282
283
284
285
286
287


def test_compressed_tensors_kv_cache(vllm_runner):
    """Generate with a kv-cache-scheme model using ``kv_cache_dtype="fp8"``."""
    runner_kwargs = {"kv_cache_dtype": "fp8"}
    with vllm_runner(
            "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme",
            **runner_kwargs) as llm:
        assert llm.generate_greedy("Hello world!", max_tokens=20)
289
290


291
292
293
294
295
296
297
298
def _test_2of4_quant_models(qkv_proj,
                            weight_strategy,
                            input_strategy,
                            format="dense"):
    """Assert that ``qkv_proj`` is set up as a quantized 2:4 sparse layer.

    This is a shared helper, not a test: its leading underscore keeps pytest
    from collecting it, so it carries no skip marker. The tests that call it
    declare their own skip conditions.

    Args:
        qkv_proj: Layer under inspection; must expose ``quant_method`` and
            ``scheme``.
        weight_strategy: Expected ``scheme.weight_quant.strategy``.
        input_strategy: Expected ``scheme.input_quant.strategy``.
        format: Expected ``format`` of the "Linear" sparsity scheme. The
            name shadows the builtin but is kept because callers pass it
            by keyword.

    Raises:
        AssertionError: If any expectation does not hold.
    """
    assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
    assert isinstance(qkv_proj.scheme, CompressedTensors24)

    scheme = qkv_proj.scheme
    assert scheme.weight_quant.strategy == weight_strategy
    assert scheme.input_quant.strategy == input_strategy
    assert scheme.quantized

    sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
    assert sparsity_map

    # Fail with a readable message rather than an AttributeError on None
    # when the map has no entry for "Linear".
    linear_sparsity = sparsity_map.get("Linear")
    assert linear_sparsity is not None, (
        'sparsity_scheme_map has no "Linear" entry')
    assert linear_sparsity.format == format
    assert linear_sparsity.sparsity_structure == "2:4"


311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing",
            "channel",
            "token",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing",
            "tensor",
            "tensor",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
            "tensor",
            "token",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
    """Check 2:4 sparse FP8 models stored in the default "dense" format.

    ``args_2of4`` is ``(model, weight_strategy, input_strategy)``.
    """
    model_name, *strategies = args_2of4

    with vllm_runner(model_name) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
            # strategies == [weight_strategy, input_strategy]
            _test_2of4_quant_models(qkv_proj, *strategies)

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM",
            "tensor",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM",
            "tensor",
            "tensor",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4):
    """Check 2:4 sparse FP8 models stored in "sparse-24-bitmask" format.

    ``args_2of4`` is ``(model, weight_strategy, input_strategy)``.
    """
    model_name, *strategies = args_2of4

    with vllm_runner(model_name) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
            # strategies == [weight_strategy, input_strategy]
            _test_2of4_quant_models(qkv_proj,
                                    *strategies,
                                    format="sparse-24-bitmask")

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM",
            "tensor",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM",
            "tensor",
            "tensor",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4):
    """Check 2:4 sparse INT8 models stored in "sparse-24-bitmask" format.

    ``args_2of4`` is ``(model, weight_strategy, input_strategy)``.
    """
    model_name, *strategies = args_2of4

    with vllm_runner(model_name) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.int8
            # strategies == [weight_strategy, input_strategy]
            _test_2of4_quant_models(qkv_proj,
                                    *strategies,
                                    format="sparse-24-bitmask")

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    # This test covers INT8 models, so the skip report should not mention
    # FP8.
    reason="Sparse INT8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing",
            "tensor",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
            "tensor",
            "token",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
    """Check 2:4 sparse INT8 models stored in the default "dense" format.

    ``args_2of4`` is ``(model, weight_strategy, input_strategy)``. The qkv
    projection of the first layer must report int8 weights and pass the
    shared 2of4 helper checks; a short greedy generation must then return
    a non-empty result.
    """
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.int8
            _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


504
505
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="2of4 Sparse is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    ["nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor"],
)
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
    """Check an unquantized 2:4 sparse model stored in "dense" format.

    ``args_2of4`` is the model name itself (a single string).
    """
    model_name = args_2of4

    with vllm_runner(model_name) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensors24)

            # Sparse-only: no weight or input quantization is configured.
            scheme = qkv_proj.scheme
            assert scheme.weight_quant is None
            assert scheme.input_quant is None
            assert not scheme.quantized

            quant_config = qkv_proj.quant_method.quantization_config
            sparsity_map = quant_config.sparsity_scheme_map
            assert sparsity_map

            assert sparsity_map.get("Linear").format == "dense"
            assert sparsity_map.get("Linear").sparsity_structure == "2:4"

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    ["nm-testing/llama2.c-stories42M-pruned2.4-compressed"],
)
def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4):
    """Check an unquantized 2:4 sparse model in "sparse-24-bitmask" format.

    ``args_2of4`` is the model name itself (a single string).
    """
    model_name = args_2of4

    with vllm_runner(model_name) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensors24)

            # Sparse-only: no weight or input quantization is configured.
            scheme = qkv_proj.scheme
            assert scheme.weight_quant is None
            assert scheme.input_quant is None
            assert not scheme.quantized

            quant_config = qkv_proj.quant_method.quantization_config
            sparsity_map = quant_config.sparsity_scheme_map
            assert sparsity_map

            assert sparsity_map.get("Linear").format == "sparse-24-bitmask"
            assert sparsity_map.get("Linear").sparsity_structure == "2:4"

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output