"""Test model set-up and weight loading for llmcompressor-quantized models.

Run `pytest tests/quantization/test_compressed_tensors.py`.
"""
from typing import Optional
6

7
import pytest
8
import torch
9
from compressed_tensors.quantization import QuantizationType
10

11
from tests.models.utils import check_logprobs_close
12
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
13
    CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
14
    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
15
    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
16
17


18
19
20
21
22
23
24
25
@pytest.mark.parametrize(
    "model_args",
    [("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
      QuantizationType.INT, 2560, True),
     ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
      QuantizationType.INT, 2560, True),
     ("nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", "tensor",
      QuantizationType.INT, 2560, False)])
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
    """Check layer-0 set-up of W8A8 INT8 checkpoints with static input scales.

    Args:
        vllm_runner: Fixture used as a context manager to load the model.
        model_args: Tuple of ``(model_path, strategy, quant_type, shape_0,
            is_symmetric)``. ``quant_type`` is part of the parameter set but
            is not asserted on in this test.
    """
    model_path, strategy, _quant_type, shape_0, is_symmetric = model_args
    with vllm_runner(model_path, enforce_eager=True) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj
        o_proj = layer.self_attn.o_proj
        gate_up_proj = layer.mlp.gate_up_proj
        down_proj = layer.mlp.down_proj

        # assert zp for symmetric and asymmetric cases
        def zp_valid(zp: Optional[torch.Tensor]):
            """Return True when ``zp`` matches the expected symmetry."""
            if is_symmetric:
                # Symmetric checkpoints are expected to carry no zero point.
                return zp is None

            return zp is not None and zp.dtype is torch.int32

        assert zp_valid(qkv_proj.input_zero_point)
        assert zp_valid(o_proj.input_zero_point)
        assert zp_valid(gate_up_proj.input_zero_point)
        assert zp_valid(down_proj.input_zero_point)

        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(gate_up_proj.quant_method,
                          CompressedTensorsLinearMethod)
        assert isinstance(down_proj.quant_method,
                          CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)

        assert qkv_proj.scheme.strategy == strategy
        assert qkv_proj.scheme.is_static_input_scheme
        expected_type = torch.int8

        assert qkv_proj.weight.dtype is expected_type
        assert o_proj.weight.dtype is expected_type
        assert gate_up_proj.weight.dtype is expected_type

        if qkv_proj.scheme.strategy == "tensor":
            # Make sure it is a channelwise buffer
            # After running process_weights_after_loading
            assert len(qkv_proj.weight_scale.shape) == 2
            assert qkv_proj.weight_scale.shape[0] == shape_0
            assert qkv_proj.weight_scale.shape[1] == 1
        assert qkv_proj.weight_scale.dtype is torch.float32
        assert qkv_proj.input_scale.dtype is torch.float32

        # Smoke test: generation must return a non-empty result.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output

77

78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
@pytest.mark.parametrize(
    "model_path",
    [
        "neuralmagic/Llama-3.2-1B-quantized.w8a8"
        # TODO static & asymmetric
    ])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
def test_compressed_tensors_w8a8_logprobs(hf_runner, vllm_runner,
                                          example_prompts, model_path,
                                          max_tokens, num_logprobs):
    """Compare greedy logprobs from the HF runner and the vLLM runner.

    Both runners load ``model_path`` with the same dtype and generate
    ``max_tokens`` tokens for ``example_prompts``; the two sets of outputs
    are then handed to ``check_logprobs_close``.
    """
    # One shared keyword set so both runners are configured identically.
    runner_kwargs = {"dtype": "bfloat16"}
    generation_args = (example_prompts, max_tokens, num_logprobs)

    with hf_runner(model_path, **runner_kwargs) as reference_model:
        reference_outputs = reference_model.generate_greedy_logprobs_limit(
            *generation_args)

    with vllm_runner(model_path, **runner_kwargs) as candidate_model:
        candidate_outputs = candidate_model.generate_greedy_logprobs(
            *generation_args)

    check_logprobs_close(outputs_0_lst=reference_outputs,
                         outputs_1_lst=candidate_outputs,
                         name_0="hf",
                         name_1="vllm")


107
def test_compressed_tensors_no_enforce_eager(vllm_runner):
    """Smoke-test generation when ``enforce_eager`` is not passed.

    Args:
        vllm_runner: Fixture used as a context manager to load the model.
    """
    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
    with vllm_runner(model_path) as llm:
        # Prompts are passed as a list, as in the W8A8 tests in this file;
        # a bare str would be iterated per character by any helper that
        # loops over its prompts (NOTE(review): confirm against the runner).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output


114
115
@pytest.mark.parametrize("model_args", [
    ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
    ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"),
    ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"),
    ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym",
     "channel"),
])
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
    """Check layer-0 set-up of W8A8 INT8 checkpoints without static inputs.

    Args:
        vllm_runner: Fixture used as a context manager to load the model.
        model_args: Tuple of ``(model_path, strategy)`` where ``strategy`` is
            the value expected on ``qkv_proj.scheme.strategy``.
    """
    model_path, strategy = model_args
    with vllm_runner(model_path, dtype=torch.float16) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
        # These checkpoints must not use the static input scheme.
        assert not qkv_proj.scheme.is_static_input_scheme
        assert qkv_proj.scheme.strategy == strategy
        assert qkv_proj.weight.dtype is torch.int8

        # Smoke test: generation must return a non-empty result.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output

138

139
140
141
142
143
@pytest.mark.parametrize(
    "wNa16_args",
    [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
     ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
     ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4)])
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
    """Check layer-0 set-up of weight-only (WNA16) quantized checkpoints.

    Args:
        vllm_runner: Fixture used as a context manager to load the model.
        wNa16_args: Tuple of ``(model_path, strategy, group, pack_factor)``.
            A ``group`` of ``None`` is compared against a scheme
            ``group_size`` of ``-1``.
    """
    # Keep the checkpoint path under its own name so it is not rebound to
    # the loaded model object below.
    model_path, strategy, group, pack_factor = wNa16_args
    with vllm_runner(model_path) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj
        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)

        assert qkv_proj.scheme.strategy == strategy
        assert qkv_proj.scheme.group_size == (-1 if group is None else group)

        assert qkv_proj.weight_packed.dtype is torch.int32
        assert qkv_proj.weight_scale.dtype is torch.float16
        assert qkv_proj.scheme.pack_factor == pack_factor

        # Prompts are passed as a list, as in the W8A8 tests in this file;
        # a bare str would be iterated per character by any helper that
        # loops over its prompts (NOTE(review): confirm against the runner).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output

164
165
166
167
168
169
170
171
172
173
174
175
176

def test_compressed_tensors_w4a16_marlin24(vllm_runner):
    """Check layer-0 set-up of the W4A16 2:4-sparse checkpoint.

    Args:
        vllm_runner: Fixture used as a context manager to load the model.
    """
    model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
    with vllm_runner(model_path) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
        assert qkv_proj.weight_packed.dtype is torch.int32

        # Prompts are passed as a list, as in the W8A8 tests in this file;
        # a bare str would be iterated per character by any helper that
        # loops over its prompts (NOTE(review): confirm against the runner).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output
179
180
181
182
183
184
185
186
187
188
189


def test_compressed_tensors_fp8(vllm_runner):
    """Check layer-0 set-up of the FP8 compressed-tensors checkpoint.

    The scheme may be either ``CompressedTensorsW8A8Fp8`` or
    ``CompressedTensorsW8A16Fp8``; the weight and scale-shape checks are only
    applied in the W8A8 case.

    Args:
        vllm_runner: Fixture used as a context manager to load the model.
    """
    model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
    with vllm_runner(model_path) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(
            qkv_proj.scheme,
            (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8))

        assert qkv_proj.input_scale.dtype is torch.float32

        if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
            # Zero-length shapes mean the scales are 0-dim tensors.
            assert len(qkv_proj.input_scale.shape) == 0
            assert qkv_proj.weight.dtype is torch.float8_e4m3fn
            assert qkv_proj.weight_scale.dtype is torch.float32
            assert len(qkv_proj.weight_scale.shape) == 0

        # Prompts are passed as a list, as in the W8A8 tests in this file;
        # a bare str would be iterated per character by any helper that
        # loops over its prompts (NOTE(review): confirm against the runner).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output
204
205
206
207
208
209


def test_compressed_tensors_kv_cache(vllm_runner):
    """Smoke-test generation with ``kv_cache_dtype="fp8"``.

    Args:
        vllm_runner: Fixture used as a context manager to load the model.
    """
    model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
    with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
        # Prompts are passed as a list, as in the W8A8 tests in this file;
        # a bare str would be iterated per character by any helper that
        # loops over its prompts (NOTE(review): confirm against the runner).
        output = llm.generate_greedy(["Hello world!"], max_tokens=20)
        assert output