# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test model set-up and weight loading for quark-quantized models.

Run `pytest tests/quantization/test_quark.py`.

See also `tests/kernels/moe/test_ocp_mx_moe.py`.
"""

10
11
12
import importlib.metadata
import os
from dataclasses import dataclass
13
from importlib.util import find_spec
14
15
16

import huggingface_hub
import lm_eval
17
import pytest
18
import torch
19
from packaging import version
20
21

from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
22
23
24
25
    QuarkLinearMethod,
    QuarkW8A8Fp8,
    QuarkW8A8Int8,
)
26
from vllm.platforms import current_platform
27

28
29
from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch

30
# The MXFP4 tests need the `quark` module to be importable and the installed
# `amd-quark` distribution to be at least version 0.8.99.
# NOTE(review): the skip reasons in this file advertise ">=0.9"; the 0.8.99
# floor presumably lets 0.9 pre-releases through - confirm.
QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
    importlib.metadata.version("amd-quark")
) >= version.parse("0.8.99")

if QUARK_MXFP4_AVAILABLE:
    # Only imported when a suitable amd-quark is installed, so that this module
    # can still be collected (and its MXFP4 tests skipped) without it.
    from quark.torch.export.nn.modules.realquantizer import StaticScaledRealQuantizer
    from quark.torch.kernel import mx as mx_kernel
    from quark.torch.quantization.config.config import FP4PerGroupSpec

# Probe once, at import time, whether a repository under the `amd` organization
# is visible on the Hugging Face Hub. Only "repository not found" is treated as
# "no access"; any other error raised by the call propagates.
try:
    huggingface_hub.list_repo_refs(
        "amd/Llama-3.3-70B-Instruct-WMXFP4-AMXFP4-KVFP8-Scale-UINT8-SQ"
    )
    HF_HUB_AMD_ORG_ACCESS = True
except huggingface_hub.errors.RepositoryNotFoundError:
    HF_HUB_AMD_ORG_ACCESS = False

47

48
@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """Set `VLLM_ALLOW_INSECURE_SERIALIZATION=1` for every test in this module.

    `LLM.apply_model` requires pickling a function. The variable is set through
    `monkeypatch`, so it only lasts for the duration of each test.
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
52
53


54
55
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("tp", [1])
def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
    """Check loading of a Quark FP8 checkpoint with scalar weight/input scales.

    Inspects `qkv_proj` of the first decoder layer, then runs a short greedy
    generation and checks that it returns something.
    """
    model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
    with vllm_runner(
        model_path, kv_cache_dtype=kv_cache_dtype, tensor_parallel_size=tp
    ) as llm:

        def check_model(model):
            layer = model.model.layers[0]
            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

            # The scheme type is asserted just above, so no extra isinstance
            # guard is needed. Both scales must be 0-dim (scalar) tensors.
            assert len(qkv_proj.input_scale.shape) == 0
            assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
            assert len(qkv_proj.weight_scale.shape) == 0

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
79
80


81
@pytest.mark.parametrize("tp", [1])
def test_quark_fp8_w_per_channel_a_per_token(vllm_runner, tp):
    """Check loading of a Quark FP8 checkpoint with a 2-D weight scale.

    Inspects `qkv_proj` of the first decoder layer, then runs a short greedy
    generation and checks that it returns something.
    """
    model_path = "amd/Qwen2.5-1.5B-Instruct-ptpc-Quark-ts"
    with vllm_runner(model_path, tensor_parallel_size=tp) as llm:

        def check_model(model):
            layer = model.model.layers[0]
            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

            # The scheme type is asserted just above, so no extra isinstance
            # guard is needed.
            assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
            # The weight scale is a column: one entry per index along dim 1 of
            # the weight as stored on the layer.
            assert qkv_proj.weight_scale.shape[0] == qkv_proj.weight.shape[1]
            assert qkv_proj.weight_scale.shape[1] == 1

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
103
104


105
@pytest.mark.parametrize("tp", [1])
def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
    """Check that a Quark INT8 checkpoint is loaded with the `QuarkW8A8Int8` scheme.

    Inspects `qkv_proj` of the first decoder layer, then runs a short greedy
    generation and checks that it returns something.
    """
    model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
    with vllm_runner(model_path, tensor_parallel_size=tp) as llm:

        def check_model(model):
            layer = model.model.layers[0]
            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            assert isinstance(qkv_proj.scheme, QuarkW8A8Int8)

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
122
123
124
125
126
127
128
129
130


def test_quark_fp8_parity(vllm_runner):
    """Check that two tiny FP8 checkpoints load to identical state dicts.

    One checkpoint id ends in "fp8-quark-quant-method" and the other in
    "fp8-quant-method". Both models are kept alive at the same time, each with
    `gpu_memory_utilization=0.1`, and their state dicts are copied to CPU for
    an exact, key-by-key comparison.
    """
    quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
    fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"

    llm_kwargs = {
        "tensor_parallel_size": 1,
        "enforce_eager": True,
        "gpu_memory_utilization": 0.1,
    }
    with (
        vllm_runner(quark_model_id, **llm_kwargs) as quark_handle,
        vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle,
    ):

        def get_state_dict(model):
            return {k: v.cpu() for k, v in model.state_dict().items()}

        # `apply_model` returns a sequence; exactly one entry is expected here.
        (quark_state_dict,) = quark_handle.apply_model(get_state_dict)
        (fp8_state_dict,) = fp8_handle.apply_model(get_state_dict)

    assert fp8_state_dict.keys() == quark_state_dict.keys()

    for key in fp8_state_dict:
        # Name the offending tensor so a failure can be diagnosed directly.
        assert torch.equal(fp8_state_dict[key], quark_state_dict[key]), (
            f"State dict entry {key!r} differs between the two checkpoints"
        )
148
149
150


@dataclass
151
class AccuracyTestConfig:
152
153
154
    model_name: str
    excepted_value: float

155
156
157
    def get_model_args(
        self,
        tp_size: int,
158
159
        model_max_len: int | None = None,
        kwargs: dict | None = None,
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
    ) -> dict:
        if kwargs is None:
            kwargs = {}

        model_args = {
            "pretrained": self.model_name,
            "dtype": "auto",
            "add_bos_token": True,
            "tensor_parallel_size": tp_size,
            "gpu_memory_utilization": 0.7,
            **kwargs,
        }
        if model_max_len is not None:
            model_args["max_model_len"] = model_max_len

        return model_args


# Models evaluated on GSM8K; `excepted_value` is the reference score that the
# measured value is compared against.
GSM8K_ACCURACY_CONFIGS = [
    # Private model.
    AccuracyTestConfig(
        model_name="amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant",
        excepted_value=0.96,
    ),
]

186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
# Models evaluated on wikitext, each paired with its reference value.
WIKITEXT_ACCURACY_CONFIGS = [
    AccuracyTestConfig(model_name=name, excepted_value=reference)
    for name, reference in (
        ("fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3", 11.3),
        ("fxmarty/qwen1.5_moe_a2.7b_chat_w_fp6_e3m2_a_fp6_e3m2", 10.6),
        ("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", 12.4),
    )
]


@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
@pytest.mark.parametrize("tp_size", [1, 2])
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
    """Evaluate wikitext word perplexity and compare it with the reference.

    Skipped when fewer than `tp_size` GPUs are visible. The measured value must
    lie strictly within 0.1 (absolute) of `config.excepted_value`.
    """
    if tp_size > torch.cuda.device_count():
        pytest.skip(
            f"This test requires >={tp_size} gpus, got only {torch.cuda.device_count()}"
        )

    task_name = "wikitext"
    # Applied as an absolute band around the reference value.
    tolerance = 0.1

    # Smaller cuda_graph_sizes to speed up the test.
    model_args = config.get_model_args(
        tp_size=tp_size, kwargs={"cuda_graph_sizes": [16]}
    )
    eval_output = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks=task_name,
        batch_size=64,
    )

    reference = config.excepted_value
    perplexity = eval_output["results"][task_name]["word_perplexity,none"]
    within_band = reference - tolerance < perplexity < reference + tolerance
    assert within_band, f"Expected: {reference} |  Measured: {perplexity}"

230

231
@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.skipif(
    not HF_HUB_AMD_ORG_ACCESS,
    reason="Read access to huggingface.co/amd is required for this test.",
)
def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
    """Evaluate 8-shot GSM8K strict-match accuracy and compare with the reference.

    Skipped when fewer than 8 GPUs are visible. The measured value must lie
    strictly within 0.03 (absolute) of `config.excepted_value`.
    """
    if torch.cuda.device_count() < 8:
        pytest.skip(
            f"This test requires >=8 gpus, got only {torch.cuda.device_count()}"
        )

    task = "gsm8k"
    # Applied as an absolute band around the reference value.
    rtol = 0.03

    # The variable is set only for the evaluation. The previous state is
    # restored in `finally`, so a failed evaluation cannot leak the override
    # into other tests and a value that was already set is not lost.
    env_key = "VLLM_USE_TRITON_FLASH_ATTN"
    previous_value = os.environ.get(env_key)
    os.environ[env_key] = "0"
    try:
        results = lm_eval.simple_evaluate(
            model="vllm",
            model_args=config.get_model_args(tp_size=8, model_max_len=38768),
            tasks=task,
            batch_size=64,
            num_fewshot=8,
        )
    finally:
        if previous_value is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous_value

    EXPECTED_VALUE = config.excepted_value
    measured_value = results["results"][task]["exact_match,strict-match"]
    assert (
        measured_value - rtol < EXPECTED_VALUE
        and measured_value + rtol > EXPECTED_VALUE
    ), f"Expected: {EXPECTED_VALUE} |  Measured: {measured_value}"


266
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_fused_qdq_match_quark(
    float_dtype: torch.dtype, scalings: list[float]
):
    """Compare `qdq_mxfp4_torch` against quark's `qdq_mxfp4_hip` kernel.

    A single row of random values in [-1, 1) is built and each consecutive
    group of 32 elements is multiplied by one of `scalings` (cycled), so that
    groups have different magnitudes. Both implementations must give finite,
    close results on every group.
    """
    torch.manual_seed(0)

    group_size = 32
    num_groups = 64
    hidden_size = num_groups * group_size
    inp = (torch.rand(1, hidden_size, dtype=float_dtype, device="cuda") - 0.5) * 2
    for i in range(num_groups):
        group = slice(i * group_size, (i + 1) * group_size)
        inp[:, group] = inp[:, group] * scalings[i % len(scalings)]

    # Each implementation gets its own copy of the input.
    inp_kernel = inp.clone()
    inp_kernel_clone = inp_kernel.clone()

    res_hip = mx_kernel.qdq_mxfp4_hip(inp_kernel_clone, "even")
    res_torch = qdq_mxfp4_torch(inp_kernel, "even")

    # Compared group by group so that a failure points at a specific group.
    for i in range(num_groups):
        group = slice(i * group_size, (i + 1) * group_size)
        assert torch.all(torch.isfinite(res_hip[:, group]))
        assert torch.all(torch.isfinite(res_torch[:, group]))

        torch.testing.assert_close(res_hip[:, group], res_torch[:, group])
292
293


294
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_dequant_kernel_match_quark(
    float_dtype: torch.dtype, scalings: list[float]
):
    """Compare `dq_mxfp4_torch` against quark's `dq_mxfp4_hip` kernel.

    A random weight is quantized with quark's `StaticScaledRealQuantizer`
    (group size 32, "e8m0" scale format). The quantized weight and its scale
    are then dequantized by both implementations, whose outputs must be
    exactly equal.
    """
    group_size = 32
    qspec = FP4PerGroupSpec(
        ch_axis=-1,
        group_size=group_size,
        scale_format="e8m0",
        scale_calculation_mode="even",
        is_dynamic=False,
    ).to_quantization_spec()

    weight_quantizer = StaticScaledRealQuantizer(
        qspec=qspec,
        quantizer=None,
        reorder=False,
        real_quantized=True,
        float_dtype=float_dtype,
        device="cuda",
    )

    observer = qspec.observer_cls(qspec, device="cuda")

    hidden_size = 512
    shape = (11008, hidden_size)

    w = (torch.rand(shape, device="cuda", dtype=float_dtype) - 0.5) * 2

    # Make it so that different groups have different scales.
    for i in range(hidden_size // group_size):
        group = slice(i * group_size, (i + 1) * group_size)
        w[:, group] = w[:, group] * scalings[i % len(scalings)]

    # Let the observer see the weight, then hand its scale to the quantizer.
    observer(w)
    scale, _ = observer._calculate_qparams()
    weight_quantizer.scale = scale

    w_mxfp4 = weight_quantizer.to_real_quantize_params(w).to("cuda")
    weight_quantizer.maybe_convert_and_transpose_scale()

    # Re-read the scale: the call above may have changed it on the quantizer.
    scale = weight_quantizer.scale

    out_hip = mx_kernel.dq_mxfp4_hip(w_mxfp4, scale, float_dtype)

    out_torch = dq_mxfp4_torch(w_mxfp4, scale, float_dtype)

    assert torch.equal(out_hip, out_torch)