test_compressed_tensors.py 6.69 KB
Newer Older
"""Test model set-up and weight loading for llmcompressor-quantized models.

Run `pytest tests/quantization/test_compressed_tensors.py`.
"""

6
import pytest
7
8
9
import torch

from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
10
    CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
11
12
    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
    CompressedTensorsWNA16)
13
14
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    QuantizationType)
15
16


17
@pytest.mark.parametrize("model_args", [
    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
     QuantizationType.INT, 2560),
    ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
     QuantizationType.INT, 2560),
])
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
    """Check layer set-up for W8A8 int8 checkpoints with a static input scheme.

    Args:
        vllm_runner: Fixture used as a context manager that yields the loaded
            LLM wrapper.
        model_args: Tuple of ``(model_path, strategy, quant_type, shape_0)``.
            ``strategy`` is compared against the scheme's strategy and
            ``shape_0`` is the expected first dimension of the qkv weight
            scale for the "tensor" strategy. ``quant_type`` is part of the
            parameter tuple but is not asserted on by this test.
    """
    # The quantization type is unpacked only to keep the tuple layout; the
    # leading underscore marks it as intentionally unused.
    model_path, strategy, _quant_type, shape_0 = model_args
    with vllm_runner(model_path, enforce_eager=True) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj
        o_proj = layer.self_attn.o_proj
        gate_up_proj = layer.mlp.gate_up_proj
        down_proj = layer.mlp.down_proj

        # Every projection of the first decoder layer must use the
        # compressed-tensors linear method.
        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(gate_up_proj.quant_method,
                          CompressedTensorsLinearMethod)
        assert isinstance(down_proj.quant_method,
                          CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)

        assert qkv_proj.scheme.strategy == strategy
        assert qkv_proj.scheme.is_static_input_scheme

        expected_type = torch.int8
        assert qkv_proj.weight.dtype is expected_type
        assert o_proj.weight.dtype is expected_type
        assert gate_up_proj.weight.dtype is expected_type

        if qkv_proj.scheme.strategy == "tensor":
            # Make sure it is a channelwise buffer
            # After running process_weights_after_loading
            assert len(qkv_proj.weight_scale.shape) == 2
            assert qkv_proj.weight_scale.shape[0] == shape_0
            assert qkv_proj.weight_scale.shape[1] == 1
        assert qkv_proj.weight_scale.dtype is torch.float32
        assert qkv_proj.input_scale.dtype is torch.float32

        # Smoke test: greedy generation must return a truthy result.
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output

62

63
def test_compressed_tensors_no_enforce_eager(vllm_runner):
    """Check that generation works when ``enforce_eager`` is not passed.

    Loads the static W8A8 test checkpoint with the runner's default settings
    and asserts that greedy generation returns a truthy result.

    Args:
        vllm_runner: Fixture used as a context manager that yields the loaded
            LLM wrapper.
    """
    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
    with vllm_runner(model_path) as llm:
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output


70
71
72
73
74
75
@pytest.mark.parametrize("model_args", [
    ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
    ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"),
])
def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
    """Check layer set-up for W8A8 int8 checkpoints without a static input scheme.

    NOTE: the function name keeps its historical spelling ("dynanmic") so
    that existing test IDs and selections continue to match.

    Args:
        vllm_runner: Fixture used as a context manager that yields the loaded
            LLM wrapper.
        model_args: Tuple of ``(model_path, strategy)`` where ``strategy`` is
            the value expected on the qkv projection's scheme.
    """
    model_path, strategy = model_args
    with vllm_runner(model_path, dtype=torch.float16) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
        # The scheme must report that it is NOT a static input scheme.
        assert not qkv_proj.scheme.is_static_input_scheme
        assert qkv_proj.scheme.strategy == strategy
        assert qkv_proj.weight.dtype is torch.int8

        # Smoke test: greedy generation must return a truthy result.
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output

91

92
93
94
95
96
@pytest.mark.parametrize(
    "wNa16_args",
    [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
     ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
     ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4)])
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
    """Check layer set-up for WNA16 (weight-only quantized) checkpoints.

    Args:
        vllm_runner: Fixture used as a context manager that yields the loaded
            LLM wrapper.
        wNa16_args: Tuple of ``(model_path, strategy, group, pack_factor)``.
            A ``group`` of ``None`` is expected to show up as a scheme
            ``group_size`` of ``-1``; ``pack_factor`` is the expected
            ``pack_factor`` attribute of the packed weight.
    """
    # Keep the checkpoint path and the loaded model under distinct names so
    # the path is not shadowed once the model object is fetched.
    model_path, strategy, group, pack_factor = wNa16_args
    with vllm_runner(model_path) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj
        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)

        assert qkv_proj.scheme.strategy == strategy
        assert qkv_proj.scheme.group_size == (-1 if group is None else group)

        assert qkv_proj.weight_packed.dtype is torch.int32
        assert qkv_proj.weight_scale.dtype is torch.float16
        assert qkv_proj.weight_packed.pack_factor == pack_factor

        # Smoke test: greedy generation must return a truthy result.
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output

117
118
119
120
121
122
123
124
125
126
127
128
129

def test_compressed_tensors_w4a16_marlin24(vllm_runner):
    """Check layer set-up for the W4A16 sparse-2:4 (marlin24) test checkpoint.

    Asserts that the first layer's qkv projection uses the compressed-tensors
    linear method with the ``CompressedTensorsW4A16Sparse24`` scheme and an
    int32 packed weight, then runs a short greedy generation.

    Args:
        vllm_runner: Fixture used as a context manager that yields the loaded
            LLM wrapper.
    """
    model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
    with vllm_runner(model_path) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
        assert qkv_proj.weight_packed.dtype is torch.int32

        # Smoke test: greedy generation must return a truthy result.
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150


def test_compressed_tensors_fp8(vllm_runner):
    """Check layer set-up for the FP8 compressed-tensors test checkpoint.

    Asserts that the first layer's qkv projection uses the
    ``CompressedTensorsW8A8Fp8`` scheme, that its weight is
    ``float8_e4m3fn``, and that both the input and weight scales are
    zero-dimensional float32 values, then runs a short greedy generation.

    Args:
        vllm_runner: Fixture used as a context manager that yields the loaded
            LLM wrapper.
    """
    model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
    with vllm_runner(model_path) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
        assert qkv_proj.weight.dtype is torch.float8_e4m3fn
        assert qkv_proj.input_scale.dtype is torch.float32
        assert qkv_proj.weight_scale.dtype is torch.float32
        # should be scalars after processing
        assert len(qkv_proj.input_scale.shape) == 0
        assert len(qkv_proj.weight_scale.shape) == 0

        # Smoke test: greedy generation must return a truthy result.
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
153
154
155
156
157
158
159


def test_compressed_tensors_kv_cache(vllm_runner):
    """Smoke-test generation for a checkpoint loaded with ``kv_cache_dtype="fp8"``.

    Args:
        vllm_runner: Fixture used as a context manager that yields the loaded
            LLM wrapper.
    """
    checkpoint = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
    with vllm_runner(checkpoint, kv_cache_dtype="fp8") as llm:
        # Greedy generation only needs to return something truthy.
        generated = llm.generate_greedy("Hello world!", max_tokens=20)
        assert generated