test_compressed_tensors.py 7.53 KB
Newer Older
"""Test model set-up and weight loading for llmcompressor-quantized models.

Run `pytest tests/quantization/test_compressed_tensors.py`.
"""
from typing import Optional
6

7
import pytest
8
import torch
9
from compressed_tensors.quantization import QuantizationType
10
11

from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
12
    CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
13
    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
14
    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
15
16


17
18
19
20
21
22
23
24
@pytest.mark.parametrize(
    "model_args",
    [("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
      QuantizationType.INT, 2560, True),
     ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
      QuantizationType.INT, 2560, True),
     ("nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", "tensor",
      QuantizationType.INT, 2560, False)])
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
    """Check layer set-up for W8A8 INT8 checkpoints with static input scales.

    Each ``model_args`` tuple is
    ``(model_path, strategy, quant_type, shape_0, is_symmetric)``:

    * ``strategy`` -- expected ``scheme.strategy`` of ``qkv_proj``.
    * ``quant_type`` -- carried in the parameter set but not asserted on.
    * ``shape_0`` -- expected first dimension of ``qkv_proj.weight_scale``
      when the strategy is ``"tensor"``.
    * ``is_symmetric`` -- when true the input zero points must be ``None``;
      otherwise they must be ``torch.int32`` tensors.
    """
    model_path, strategy, _quant_type, shape_0, is_symmetric = model_args
    with vllm_runner(model_path, enforce_eager=True) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        # The four projections of the first decoder layer, keyed by name so
        # a failing assertion reports which one was at fault.
        projections = {
            "qkv_proj": layer.self_attn.qkv_proj,
            "o_proj": layer.self_attn.o_proj,
            "gate_up_proj": layer.mlp.gate_up_proj,
            "down_proj": layer.mlp.down_proj,
        }
        qkv_proj = projections["qkv_proj"]

        # assert zp for symmetric and asymmetric cases
        def zp_valid(zp: Optional[torch.Tensor]):
            if is_symmetric:
                return zp is None

            return zp is not None and zp.dtype is torch.int32

        for name, proj in projections.items():
            assert zp_valid(proj.input_zero_point), name

        for name, proj in projections.items():
            assert isinstance(proj.quant_method,
                              CompressedTensorsLinearMethod), name

        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)

        assert qkv_proj.scheme.strategy == strategy
        assert qkv_proj.scheme.is_static_input_scheme

        # down_proj is included here as well; it is already held to the same
        # zero-point and quant-method expectations above.
        expected_type = torch.int8
        for name, proj in projections.items():
            assert proj.weight.dtype is expected_type, name

        if qkv_proj.scheme.strategy == "tensor":
            # Make sure it is a channelwise buffer
            # After running process_weights_after_loading
            assert len(qkv_proj.weight_scale.shape) == 2
            assert qkv_proj.weight_scale.shape[0] == shape_0
            assert qkv_proj.weight_scale.shape[1] == 1
        assert qkv_proj.weight_scale.dtype is torch.float32
        assert qkv_proj.input_scale.dtype is torch.float32

        # Smoke test: generation must produce a non-empty result.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output

76

77
def test_compressed_tensors_no_enforce_eager(vllm_runner):
    """Smoke-test generation for a W8A8 checkpoint without ``enforce_eager``.

    The runner is created with its default settings (no ``enforce_eager``
    argument) and the test only requires a non-empty generation result.
    """
    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
    with vllm_runner(model_path) as llm:
        # Prompts are passed as a list, matching the list-based
        # generate_greedy calls elsewhere in this file.
        # NOTE(review): a bare string would presumably be iterated character
        # by character by the runner -- confirm against the fixture.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output


84
85
@pytest.mark.parametrize("model_args", [
    ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
    ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"),
    ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"),
    ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym",
     "channel"),
])
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
    """Check layer set-up for W8A8 INT8 checkpoints without static inputs.

    Each ``model_args`` tuple is ``(model_path, strategy)`` where
    ``strategy`` is the expected ``scheme.strategy`` of the first layer's
    ``qkv_proj``.
    """
    model_path, strategy = model_args
    with vllm_runner(model_path, dtype=torch.float16) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
        # The scheme must not be flagged as a static-input scheme.
        assert not qkv_proj.scheme.is_static_input_scheme
        assert qkv_proj.scheme.strategy == strategy
        assert qkv_proj.weight.dtype is torch.int8

        # Smoke test: generation must produce a non-empty result.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output

108

109
110
111
112
113
@pytest.mark.parametrize(
    "wNa16_args",
    [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
     ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
     ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4)])
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
    """Check layer set-up for weight-only (WNA16) quantized checkpoints.

    Each ``wNa16_args`` tuple is
    ``(model_path, strategy, group, pack_factor)``:

    * ``strategy`` -- expected ``scheme.strategy`` of ``qkv_proj``.
    * ``group`` -- expected ``scheme.group_size``; ``None`` means the scheme
      is expected to report ``-1``.
    * ``pack_factor`` -- expected ``scheme.pack_factor``.
    """
    # Keep the checkpoint path in its own name so it is not overwritten by
    # the loaded model object below.
    model_path, strategy, group, pack_factor = wNa16_args
    with vllm_runner(model_path) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj
        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)

        assert qkv_proj.scheme.strategy == strategy
        expected_group_size = -1 if group is None else group
        assert qkv_proj.scheme.group_size == expected_group_size

        assert qkv_proj.weight_packed.dtype is torch.int32
        assert qkv_proj.weight_scale.dtype is torch.float16
        assert qkv_proj.scheme.pack_factor == pack_factor

        # Prompts are passed as a list, matching the list-based
        # generate_greedy calls elsewhere in this file.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output

134
135
136
137
138
139
140
141
142
143
144
145
146

def test_compressed_tensors_w4a16_marlin24(vllm_runner):
    """Check layer set-up for the W4A16 2:4-sparse (marlin24) checkpoint.

    Verifies that the first layer's ``qkv_proj`` uses the
    ``CompressedTensorsW4A16Sparse24`` scheme with ``torch.int32`` packed
    weights, then requires a non-empty generation result.
    """
    model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
    with vllm_runner(model_path) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
        assert qkv_proj.weight_packed.dtype is torch.int32

        # Prompts are passed as a list, matching the list-based
        # generate_greedy calls elsewhere in this file.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output
149
150
151
152
153
154
155
156
157
158
159


def test_compressed_tensors_fp8(vllm_runner):
    """Check layer set-up for the FP8 compressed-tensors checkpoint.

    The first layer's ``qkv_proj`` may resolve to either
    ``CompressedTensorsW8A8Fp8`` or ``CompressedTensorsW8A16Fp8``; the
    weight and scale shape/dtype checks are applied only in the W8A8 case.
    """
    model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
    with vllm_runner(model_path) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(
            qkv_proj.scheme,
            (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8))

        # Checked for both accepted schemes.
        assert qkv_proj.input_scale.dtype is torch.float32

        if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
            # Both scales are expected to be zero-dimensional tensors.
            assert len(qkv_proj.input_scale.shape) == 0
            assert qkv_proj.weight.dtype is torch.float8_e4m3fn
            assert qkv_proj.weight_scale.dtype is torch.float32
            assert len(qkv_proj.weight_scale.shape) == 0

        # Prompts are passed as a list, matching the list-based
        # generate_greedy calls elsewhere in this file.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output
174
175
176
177
178
179


def test_compressed_tensors_kv_cache(vllm_runner):
    """Smoke-test generation for a checkpoint that carries a kv-cache scheme.

    The runner is created with ``kv_cache_dtype="fp8"`` and the test only
    requires a non-empty generation result.
    """
    model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
    with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
        # Prompts are passed as a list, matching the list-based
        # generate_greedy calls elsewhere in this file.
        output = llm.generate_greedy(["Hello world!"], max_tokens=20)
        assert output