test_quark.py 10.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
"""Test model set-up and weight loading for quark-quantized models.

Run `pytest tests/quantization/test_quark.py`.

See also `tests/kernels/moe/test_ocp_mx_moe.py`.
"""

10
11
import importlib.metadata
from dataclasses import dataclass
12
from importlib.util import find_spec
13
14
15

import huggingface_hub
import lm_eval
16
import pytest
17
import torch
18
from packaging import version
19
20

from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
21
22
23
24
    QuarkLinearMethod,
    QuarkW8A8Fp8,
    QuarkW8A8Int8,
)
25
from vllm.platforms import current_platform
26

27
28
from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch

29
def _quark_mxfp4_available() -> bool:
    """Return True when a recent enough `amd-quark` distribution is installed.

    The check never raises for a missing package: an importable `quark`
    module without `amd-quark` distribution metadata is reported as
    "not available" instead of aborting collection of this test module.
    """
    if find_spec("quark") is None:
        return False

    try:
        installed = version.parse(importlib.metadata.version("amd-quark"))
    except importlib.metadata.PackageNotFoundError:
        # Some `quark` module is importable, but it was not installed as the
        # `amd-quark` distribution, so its version cannot be established.
        return False

    # NOTE(review): the bound is 0.8.99 rather than 0.9, presumably so that
    # 0.9 pre-releases also qualify - confirm against the quark release scheme.
    return installed >= version.parse("0.8.99")


QUARK_MXFP4_AVAILABLE = _quark_mxfp4_available()

if QUARK_MXFP4_AVAILABLE:
    # Only importable with amd-quark installed; the tests that use these names
    # are skipped when QUARK_MXFP4_AVAILABLE is False.
    from quark.torch.export.nn.modules.realquantizer import StaticScaledRealQuantizer
    from quark.torch.kernel import mx as mx_kernel
    from quark.torch.quantization.config.config import FP4PerGroupSpec

# Probe once, at import time, whether this environment can read a gated
# repository of the `amd` organization on the Hugging Face Hub. The result
# gates the tests that need such a checkpoint.
try:
    huggingface_hub.list_repo_refs(
        "amd/Llama-3.3-70B-Instruct-WMXFP4-AMXFP4-KVFP8-Scale-UINT8-SQ"
    )
except (huggingface_hub.errors.HfHubHTTPError, OSError):
    # `HfHubHTTPError` is the base class of `RepositoryNotFoundError` (the
    # "no read access" case) and also covers other hub HTTP failures.
    # `OSError` covers offline mode and connection problems. In all of these
    # cases the gated test is skipped rather than failing collection of the
    # whole module.
    HF_HUB_AMD_ORG_ACCESS = False
else:
    HF_HUB_AMD_ORG_ACCESS = True

46

47
@pytest.fixture(autouse=True, scope="function")
def enable_pickle(monkeypatch):
    """Set `VLLM_ALLOW_INSECURE_SERIALIZATION=1` around every test.

    `LLM.apply_model` requires pickling a function. The variable is set
    through `monkeypatch`, so it is restored once the test finishes.
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
51
52


53
54
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("tp", [1])
def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
    """Load a Quark FP8 checkpoint and inspect the first layer's qkv_proj.

    The projection must use `QuarkLinearMethod` with the `QuarkW8A8Fp8`
    scheme, hold an FP8 weight, and have 0-dim input and weight scales.
    A short greedy generation must then return a non-empty result.
    """
    model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"

    def check_model(model):
        qkv_proj = model.model.layers[0].self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
        assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

        # The scheme type is asserted just above, so the FP8-specific checks
        # run unconditionally.
        assert len(qkv_proj.input_scale.shape) == 0
        assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
        assert len(qkv_proj.weight_scale.shape) == 0

    runner_kwargs = {
        "enforce_eager": True,
        "kv_cache_dtype": kv_cache_dtype,
        "tensor_parallel_size": tp,
    }
    with vllm_runner(model_path, **runner_kwargs) as llm:
        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output
81
82


83
@pytest.mark.parametrize("tp", [1])
def test_quark_fp8_w_per_channel_a_per_token(vllm_runner, tp):
    """Load a Quark FP8 checkpoint and inspect the first layer's qkv_proj.

    The projection must use `QuarkLinearMethod` with the `QuarkW8A8Fp8`
    scheme, hold an FP8 weight, and carry a weight scale whose first
    dimension equals `weight.shape[1]` and whose second dimension is 1.
    A short greedy generation must then return a non-empty result.
    """
    model_path = "amd/Qwen2.5-1.5B-Instruct-ptpc-Quark-ts"

    def check_model(model):
        qkv_proj = model.model.layers[0].self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
        assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

        # The scheme type is asserted just above, so the FP8-specific checks
        # run unconditionally.
        assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
        assert qkv_proj.weight_scale.shape[0] == qkv_proj.weight.shape[1]
        assert qkv_proj.weight_scale.shape[1] == 1

    with vllm_runner(model_path, enforce_eager=True, tensor_parallel_size=tp) as llm:
        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output
105
106


107
@pytest.mark.parametrize("tp", [1])
def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
    """Load a Quark INT8 checkpoint and inspect the first layer's qkv_proj.

    The projection must use `QuarkLinearMethod` with the `QuarkW8A8Int8`
    scheme. A short greedy generation must then return a non-empty result.
    """
    model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"

    def check_model(model):
        qkv_proj = model.model.layers[0].self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
        assert isinstance(qkv_proj.scheme, QuarkW8A8Int8)

    with vllm_runner(model_path, enforce_eager=True, tensor_parallel_size=tp) as llm:
        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output
124
125
126
127
128
129
130
131
132


def test_quark_fp8_parity(vllm_runner):
    """Check that two tiny FP8 checkpoints load to identical state dicts.

    Both models are loaded side by side, their state dicts are copied to the
    CPU, and the two dicts must have the same keys and bit-identical tensors.
    """
    quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
    fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"

    # Both runners are alive at the same time, hence the small memory share.
    llm_kwargs = dict(
        tensor_parallel_size=1,
        enforce_eager=True,
        gpu_memory_utilization=0.1,
    )

    def get_state_dict(model):
        return {name: tensor.cpu() for name, tensor in model.state_dict().items()}

    with (
        vllm_runner(quark_model_id, **llm_kwargs) as quark_handle,
        vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle,
    ):
        # `apply_model` returns a sequence; exactly one entry is expected.
        (quark_state_dict,) = quark_handle.apply_model(get_state_dict)
        (fp8_state_dict,) = fp8_handle.apply_model(get_state_dict)

    assert fp8_state_dict.keys() == quark_state_dict.keys()

    for key, fp8_tensor in fp8_state_dict.items():
        assert torch.equal(fp8_tensor, quark_state_dict[key])
150
151
152


@dataclass
153
class AccuracyTestConfig:
154
155
156
    model_name: str
    excepted_value: float

157
158
159
    def get_model_args(
        self,
        tp_size: int,
160
161
        model_max_len: int | None = None,
        kwargs: dict | None = None,
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
    ) -> dict:
        if kwargs is None:
            kwargs = {}

        model_args = {
            "pretrained": self.model_name,
            "dtype": "auto",
            "add_bos_token": True,
            "tensor_parallel_size": tp_size,
            "gpu_memory_utilization": 0.7,
            **kwargs,
        }
        if model_max_len is not None:
            model_args["max_model_len"] = model_max_len

        return model_args


# Models evaluated on GSM8K, each paired with its reference score.
GSM8K_ACCURACY_CONFIGS = [
    # Private model.
    AccuracyTestConfig(
        excepted_value=0.96,
        model_name="amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant",
    ),
]

188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
# Models evaluated on wikitext, each paired with its reference value.
WIKITEXT_ACCURACY_CONFIGS = [
    AccuracyTestConfig(model_name=repo_id, excepted_value=reference)
    for repo_id, reference in (
        ("fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3", 11.3),
        ("fxmarty/qwen1.5_moe_a2.7b_chat_w_fp6_e3m2_a_fp6_e3m2", 10.6),
        ("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", 12.4),
    )
]


@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
@pytest.mark.parametrize("tp_size", [1, 2])
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
    """Evaluate `config`'s model on wikitext and compare the word perplexity.

    The test is skipped when fewer than `tp_size` CUDA devices are visible.
    The measured value must lie strictly within 0.1 (absolute) of the
    reference stored in `config`.
    """
    gpu_count = torch.cuda.device_count()
    if gpu_count < tp_size:
        pytest.skip(f"This test requires >={tp_size} gpus, got only {gpu_count}")

    task = "wikitext"
    # Absolute, not relative, margin around the reference value.
    tolerance = 0.1

    # Smaller cuda_graph_sizes to speed up the test.
    model_args = config.get_model_args(
        tp_size=tp_size, kwargs={"cuda_graph_sizes": [16]}
    )
    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks=task,
        batch_size=64,
    )

    reference = config.excepted_value
    measured_value = results["results"][task]["word_perplexity,none"]
    assert reference - tolerance < measured_value < reference + tolerance, (
        f"Expected: {reference} |  Measured: {measured_value}"
    )

232

233
@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.skipif(
    not HF_HUB_AMD_ORG_ACCESS,
    reason="Read access to huggingface.co/amd is required for this test.",
)
def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
    """Evaluate `config`'s model on 8-shot GSM8K and compare the score.

    The test is skipped when fewer than 8 CUDA devices are visible. The
    strict-match exact-match score must lie strictly within 0.03 (absolute)
    of the reference stored in `config`.
    """
    gpu_count = torch.cuda.device_count()
    if gpu_count < 8:
        pytest.skip(f"This test requires >=8 gpus, got only {gpu_count}")

    task = "gsm8k"
    # Absolute, not relative, margin around the reference value.
    tolerance = 0.03

    model_args = config.get_model_args(tp_size=8, model_max_len=38768)
    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks=task,
        batch_size=64,
        num_fewshot=8,
    )

    reference = config.excepted_value
    measured_value = results["results"][task]["exact_match,strict-match"]
    assert measured_value - tolerance < reference < measured_value + tolerance, (
        f"Expected: {reference} |  Measured: {measured_value}"
    )
262
263


264
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[float]):
    """Compare quark's `qdq_mxfp4_hip` kernel with the torch reference.

    A `(1, 2048)` random input is built on the GPU and every group of 32
    columns is multiplied by one of `scalings` (cycled), so that groups differ
    in magnitude. Both implementations must return finite values that are
    close to each other, group by group.

    Args:
        float_dtype: dtype of the input tensor.
        scalings: Per-group multipliers (floats), reused cyclically.
    """
    torch.manual_seed(0)

    group_size = 32
    hidden_size = 64 * group_size
    num_groups = hidden_size // group_size

    inp = (torch.rand(1, hidden_size, dtype=float_dtype, device="cuda") - 0.5) * 2
    for i in range(num_groups):
        cols = slice(i * group_size, (i + 1) * group_size)
        inp[:, cols] = inp[:, cols] * scalings[i % len(scalings)]

    # Each implementation gets its own copy of the input.
    # NOTE(review): presumably a guard against in-place modification by the
    # kernel - confirm.
    inp_kernel = inp.clone()
    inp_kernel_clone = inp_kernel.clone()

    res_hip = mx_kernel.qdq_mxfp4_hip(inp_kernel_clone, "even")
    res_torch = qdq_mxfp4_torch(inp_kernel, "even")

    for i in range(num_groups):
        cols = slice(i * group_size, (i + 1) * group_size)
        assert torch.all(torch.isfinite(res_hip[:, cols]))
        assert torch.all(torch.isfinite(res_torch[:, cols]))

        torch.testing.assert_close(res_hip[:, cols], res_torch[:, cols])
290
291


292
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_dequant_kernel_match_quark(
    float_dtype: torch.dtype, scalings: list[float]
):
    """Compare quark's `dq_mxfp4_hip` kernel with the torch reference.

    A random `(11008, 512)` weight is quantized with quark's
    `StaticScaledRealQuantizer` (FP4 per-group spec, group size 32, `e8m0`
    scales). The quantized weight is then dequantized by both
    implementations and the results must be exactly equal.

    Args:
        float_dtype: dtype of the weight and of the dequantized outputs.
        scalings: Per-group multipliers (floats), reused cyclically.
    """
    # Seed the RNG, as the fused-qdq test in this module does, so that an
    # exact-equality failure can be reproduced.
    torch.manual_seed(0)

    group_size = 32

    qspec = FP4PerGroupSpec(
        ch_axis=-1,
        group_size=group_size,
        scale_format="e8m0",
        scale_calculation_mode="even",
        is_dynamic=False,
    ).to_quantization_spec()

    weight_quantizer = StaticScaledRealQuantizer(
        qspec=qspec,
        quantizer=None,
        reorder=False,
        real_quantized=True,
        float_dtype=float_dtype,
        device="cuda",
    )

    observer = qspec.observer_cls(qspec, device="cuda")

    hidden_size = 512
    shape = (11008, hidden_size)

    w = (torch.rand(shape, device="cuda", dtype=float_dtype) - 0.5) * 2

    # Make it so that different groups have different scales.
    for i in range(hidden_size // group_size):
        cols = slice(i * group_size, (i + 1) * group_size)
        w[:, cols] = w[:, cols] * scalings[i % len(scalings)]

    # Let the observer see the weight, then hand its scale to the quantizer.
    observer(w)
    scale, _ = observer._calculate_qparams()
    weight_quantizer.scale = scale

    w_mxfp4 = weight_quantizer.to_real_quantize_params(w).to("cuda")
    weight_quantizer.maybe_convert_and_transpose_scale()

    # Re-read the scale: the call above may have replaced it on the quantizer.
    scale = weight_quantizer.scale

    out_hip = mx_kernel.dq_mxfp4_hip(w_mxfp4, scale, float_dtype)
    out_torch = dq_mxfp4_torch(w_mxfp4, scale, float_dtype)

    assert torch.equal(out_hip, out_torch)