test_compressed_tensors.py 7.21 KB
Newer Older
"""Test model set-up and weight loading for llmcompressor-quantized models.

Run `pytest tests/quantization/test_compressed_tensors.py`.
"""

6
import pytest
7
import torch
8
import os
9
10

from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
11
    CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
12
    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
13
    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
14
15
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    QuantizationType)
16
from ..utils import models_path_prefix
17
18


19
@pytest.mark.parametrize("model_args", [
    (os.path.join(
        models_path_prefix,
        "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"),
     "tensor", QuantizationType.INT, 2560),
    (os.path.join(
        models_path_prefix,
        "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor"),
     "channel", QuantizationType.INT, 2560),
])
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
    """Check layer set-up for W8A8 INT8 checkpoints with static input scales.

    ``model_args`` is a ``(model_path, strategy, quant_type, shape_0)`` tuple:
    the checkpoint location, the expected quantization strategy of the qkv
    scheme, a quantization type (part of the parametrization but not asserted
    on here), and the expected first dimension of the qkv weight scale when
    the strategy is ``"tensor"``.
    """
    # The quantization type is unpacked only to keep the tuple shape; it is
    # not used by any assertion below.
    model_path, strategy, _quant_type, shape_0 = model_args

    with vllm_runner(model_path, enforce_eager=True) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj
        o_proj = layer.self_attn.o_proj
        gate_up_proj = layer.mlp.gate_up_proj
        down_proj = layer.mlp.down_proj

        # Every linear projection of the first decoder layer must be driven
        # by the compressed-tensors linear method.
        for proj in (qkv_proj, o_proj, gate_up_proj, down_proj):
            assert isinstance(proj.quant_method, CompressedTensorsLinearMethod)

        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
        assert qkv_proj.scheme.strategy == strategy
        assert qkv_proj.scheme.is_static_input_scheme

        expected_type = torch.int8
        for proj in (qkv_proj, o_proj, gate_up_proj):
            assert proj.weight.dtype is expected_type

        if qkv_proj.scheme.strategy == "tensor":
            # Make sure it is a channelwise buffer
            # after running process_weights_after_loading.
            assert len(qkv_proj.weight_scale.shape) == 2
            assert qkv_proj.weight_scale.shape[0] == shape_0
            assert qkv_proj.weight_scale.shape[1] == 1
        assert qkv_proj.weight_scale.dtype is torch.float32
        assert qkv_proj.input_scale.dtype is torch.float32

        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output

64

65
def test_compressed_tensors_no_enforce_eager(vllm_runner):
    """Smoke-test generation when ``enforce_eager`` is not passed to the runner.

    Loads the static W8A8 test checkpoint with the runner's defaults and only
    checks that greedy generation returns a truthy result.
    """
    model_path = os.path.join(
        models_path_prefix,
        "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change")
    with vllm_runner(model_path) as llm:
        # NOTE(review): the prompt is a bare str here, while other tests in
        # this file pass a one-element list -- confirm the runner treats both
        # the same way.
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output


72
@pytest.mark.parametrize("model_args", [
    (os.path.join(
        models_path_prefix,
        "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"),
     "tensor"),
    (os.path.join(
        models_path_prefix,
        "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"),
     "channel"),
])
def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
    """Check layer set-up for W8A8 INT8 checkpoints without static input scales.

    ``model_args`` is a ``(model_path, strategy)`` tuple: the checkpoint
    location and the expected quantization strategy of the qkv scheme.

    The misspelling in the function name is kept on purpose so that existing
    test IDs and ``-k`` selections keep working.
    """
    model_path, strategy = model_args

    with vllm_runner(model_path, dtype=torch.float16) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        qkv_proj = model.model.layers[0].self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
        # The scheme must not be the static-input variant.
        assert not qkv_proj.scheme.is_static_input_scheme
        assert qkv_proj.scheme.strategy == strategy
        assert qkv_proj.weight.dtype is torch.int8

        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output

93

94
95
@pytest.mark.parametrize("wNa16_args", [
    (os.path.join(
        models_path_prefix,
        "nm-testing/tinyllama-oneshot-w4a16-channel-v2"),
     "channel", None, 8),
    (os.path.join(
        models_path_prefix,
        "nm-testing/tinyllama-oneshot-w4a16-group128-v2"),
     "group", 128, 8),
    (os.path.join(
        models_path_prefix,
        "nm-testing/tinyllama-oneshot-w8a16-per-channel"),
     "channel", None, 4),
])
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
    """Check layer set-up for weight-only (WNA16) packed checkpoints.

    ``wNa16_args`` is a ``(model_path, strategy, group, pack_factor)`` tuple:
    the checkpoint location, the expected quantization strategy, the expected
    group size (``None`` meaning the scheme should report ``-1``), and the
    expected pack factor of the qkv scheme.
    """
    # The path gets its own name so it is not shadowed by the loaded model
    # object fetched inside the ``with`` block.
    model_path, strategy, group, pack_factor = wNa16_args
    expected_group_size = -1 if group is None else group

    with vllm_runner(model_path) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        qkv_proj = model.model.layers[0].self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)

        assert qkv_proj.scheme.strategy == strategy
        assert qkv_proj.scheme.group_size == expected_group_size

        assert qkv_proj.weight_packed.dtype is torch.int32
        assert qkv_proj.weight_scale.dtype is torch.float16
        assert qkv_proj.scheme.pack_factor == pack_factor

        # NOTE(review): the prompt is a bare str here, while other tests in
        # this file pass a one-element list -- confirm the runner treats both
        # the same way.
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output

119
120

def test_compressed_tensors_w4a16_marlin24(vllm_runner):
    """Check layer set-up for the W4A16 2:4-sparse (marlin24) test checkpoint.

    Verifies that the first layer's qkv projection uses the compressed-tensors
    linear method with the ``CompressedTensorsW4A16Sparse24`` scheme and int32
    packed weights, then checks that greedy generation returns a truthy
    result.
    """
    model_path = os.path.join(
        models_path_prefix,
        "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t")

    with vllm_runner(model_path) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        qkv_proj = model.model.layers[0].self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
        assert qkv_proj.weight_packed.dtype is torch.int32

        # NOTE(review): the prompt is a bare str here, while other tests in
        # this file pass a one-element list -- confirm the runner treats both
        # the same way.
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
134
135
136


def test_compressed_tensors_fp8(vllm_runner):
    """Check layer set-up for the FP8 compressed-tensors test checkpoint.

    The qkv scheme may be either ``CompressedTensorsW8A8Fp8`` or
    ``CompressedTensorsW8A16Fp8``. The scalar-scale and FP8 weight dtype
    checks are applied only when the W8A8 scheme is the one in use.
    """
    model_path = os.path.join(
        models_path_prefix,
        "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test")

    with vllm_runner(model_path) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        qkv_proj = model.model.layers[0].self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(
            qkv_proj.scheme,
            (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8))

        assert qkv_proj.input_scale.dtype is torch.float32

        if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
            # Both scales are expected to be 0-dim (scalar) tensors.
            assert len(qkv_proj.input_scale.shape) == 0
            assert qkv_proj.weight.dtype is torch.float8_e4m3fn
            assert qkv_proj.weight_scale.dtype is torch.float32
            assert len(qkv_proj.weight_scale.shape) == 0

        # NOTE(review): the prompt is a bare str here, while other tests in
        # this file pass a one-element list -- confirm the runner treats both
        # the same way.
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
159
160
161


def test_compressed_tensors_kv_cache(vllm_runner):
    """Smoke-test generation with ``kv_cache_dtype="fp8"`` on the kv-cache checkpoint.

    Only checks that greedy generation returns a truthy result.
    """
    model_path = os.path.join(
        models_path_prefix,
        "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme")
    with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
        # NOTE(review): the prompt is a bare str here, while other tests in
        # this file pass a one-element list -- confirm the runner treats both
        # the same way.
        output = llm.generate_greedy("Hello world!", max_tokens=20)
        assert output