test_quark.py 9.12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
"""Test model set-up and weight loading for Quark-quantized models.

Run `pytest tests/quantization/test_quark.py`.

See also `tests/kernels/moe/test_mxfp4_moe.py`.
"""

10
11
12
import importlib.metadata
import os
from dataclasses import dataclass
13
from importlib.util import find_spec
14
15
16

import huggingface_hub
import lm_eval
17
import pytest
18
import torch
19
from packaging import version
20
21

from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
22
23
24
25
    QuarkLinearMethod,
    QuarkW8A8Fp8,
    QuarkW8A8Int8,
)
26
from vllm.platforms import current_platform
27

28
29
from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch

30
# The MXFP4 tests below are skipped unless amd-quark>=0.9 is installed; the
# "0.8.99" threshold also admits 0.9 pre-/dev releases.
try:
    QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
        importlib.metadata.version("amd-quark")
    ) >= version.parse("0.8.99")
except importlib.metadata.PackageNotFoundError:
    # A `quark` module is importable but no `amd-quark` distribution is
    # installed: report MXFP4 support as unavailable instead of failing
    # collection of the whole module.
    QUARK_MXFP4_AVAILABLE = False

if QUARK_MXFP4_AVAILABLE:
    # These names are only used by tests that are skipped when the check
    # above fails, so importing them conditionally is safe.
    from quark.torch.export.nn.modules.realquantizer import StaticScaledRealQuantizer
    from quark.torch.kernel import mx as mx_kernel
    from quark.torch.quantization.config.config import FP4PerGroupSpec

# Probe whether the current Hugging Face credentials can see repositories in
# the `amd` org; the GSM8K accuracy test is skipped when they cannot.
try:
    huggingface_hub.list_repo_refs(
        "amd/Llama-3.3-70B-Instruct-WMXFP4-AMXFP4-KVFP8-Scale-UINT8-SQ"
    )
    HF_HUB_AMD_ORG_ACCESS = True
except huggingface_hub.errors.RepositoryNotFoundError:
    # Repository not visible with the current credentials.
    HF_HUB_AMD_ORG_ACCESS = False
except (huggingface_hub.errors.HfHubHTTPError, OSError):
    # Hub unreachable or erroring (offline mode, connection failure, server
    # error): treat as "no access" so only the dependent test is skipped,
    # rather than the whole module failing to import.
    HF_HUB_AMD_ORG_ACCESS = False

47

48
@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """Allow insecure serialization for every test in this module.

    ``LLM.apply_model`` needs to pickle the function it is given, so each test
    runs with ``VLLM_ALLOW_INSECURE_SERIALIZATION=1``; ``monkeypatch`` undoes
    the change once the test finishes.
    """
    env_name, env_value = "VLLM_ALLOW_INSECURE_SERIALIZATION", "1"
    monkeypatch.setenv(env_name, env_value)
52
53


54
55
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("tp", [1])
def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
    """Load a per-tensor FP8 Quark checkpoint and check the QKV layer setup.

    The first decoder layer's ``qkv_proj`` must use ``QuarkLinearMethod`` with
    the ``QuarkW8A8Fp8`` scheme, hold FP8 weights, and carry scalar (0-d)
    input and weight scales. Greedy generation must then return output.
    """
    model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
    with vllm_runner(
        model_path, kv_cache_dtype=kv_cache_dtype, tensor_parallel_size=tp
    ) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            # The scheme is asserted here, so the FP8-specific checks below
            # need no isinstance guard (it could never be false).
            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

            # Per-tensor quantization: both scales are 0-d tensors.
            assert len(qkv_proj.input_scale.shape) == 0
            assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
            assert len(qkv_proj.weight_scale.shape) == 0

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
79
80


81
@pytest.mark.parametrize("tp", [1])
def test_quark_fp8_w_per_channel_a_per_token(vllm_runner, tp):
    """Load a per-channel-weight / per-token-activation FP8 Quark checkpoint.

    The first decoder layer's ``qkv_proj`` must use ``QuarkLinearMethod`` with
    the ``QuarkW8A8Fp8`` scheme, hold FP8 weights, and have a weight scale of
    shape ``(weight.shape[1], 1)``. Greedy generation must then return output.
    """
    model_path = "amd/Qwen2.5-1.5B-Instruct-ptpc-Quark-ts"
    with vllm_runner(model_path, tensor_parallel_size=tp) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            # The scheme is asserted here, so the FP8-specific checks below
            # need no isinstance guard (it could never be false).
            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

            assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
            # One scale per entry of weight.shape[1], stored as a column.
            # NOTE(review): presumably the loaded weight is kept transposed,
            # making this one scale per output channel — confirm.
            assert qkv_proj.weight_scale.shape[0] == qkv_proj.weight.shape[1]
            assert qkv_proj.weight_scale.shape[1] == 1

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
103
104


105
@pytest.mark.parametrize("tp", [1])
def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
    """Load an INT8 W8A8 Quark checkpoint and smoke-test generation.

    Only the quantization method and scheme of the first layer's ``qkv_proj``
    are inspected; greedy generation must return a non-empty result.
    """
    with vllm_runner(
        "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test",
        tensor_parallel_size=tp,
    ) as llm:

        def check_model(model):
            first_layer = model.model.layers[0]
            qkv = first_layer.self_attn.qkv_proj
            assert isinstance(qkv.quant_method, QuarkLinearMethod)
            assert isinstance(qkv.scheme, QuarkW8A8Int8)

        llm.apply_model(check_model)
        assert llm.generate_greedy("Hello my name is", max_tokens=20)
122
123
124
125
126
127
128
129
130


def test_quark_fp8_parity(vllm_runner):
    """Two tiny-llama FP8 checkpoints must load into identical state dicts.

    One checkpoint uses the Quark quant method, the other the plain FP8 quant
    method; after loading, both must expose the same parameter names and
    exactly equal tensors.
    """
    runner_kwargs = dict(
        tensor_parallel_size=1,
        enforce_eager=True,
        gpu_memory_utilization=0.1,
    )
    quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
    fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"

    def snapshot_state_dict(model):
        # Copy every tensor to the CPU so it can still be compared after the
        # runners below have exited.
        return {name: tensor.cpu() for name, tensor in model.state_dict().items()}

    with (
        vllm_runner(quark_model_id, **runner_kwargs) as quark_handle,
        vllm_runner(fp8_model_id, **runner_kwargs) as fp8_handle,
    ):
        (quark_state_dict,) = quark_handle.apply_model(snapshot_state_dict)
        (fp8_state_dict,) = fp8_handle.apply_model(snapshot_state_dict)

    assert fp8_state_dict.keys() == quark_state_dict.keys()
    for name, fp8_tensor in fp8_state_dict.items():
        assert torch.equal(fp8_tensor, quark_state_dict[name])
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171


@dataclass
class ModelCase:
    """A model identifier paired with a parallelism setting for a test case."""

    # NOTE(review): presumably a Hugging Face repo id — confirm against users.
    model_id: str
    # NOTE(review): presumably the tensor-parallel size — confirm.
    tp: int


@dataclass
class GSM8KAccuracyTestConfig:
    """GSM8K accuracy target for one model plus the vLLM args used to run it.

    The target field is spelled ``excepted_value`` (sic, meaning "expected");
    the name is kept because configs and tests refer to it by name.
    """

    model_name: str
    excepted_value: float
    # Defaults reproduce the previously hard-coded lm_eval vLLM settings.
    tensor_parallel_size: int = 8
    gpu_memory_utilization: float = 0.7
    max_model_len: int = 38768

    def get_model_args(self) -> str:
        """Return the comma-separated ``model_args`` string for lm_eval.

        The string is passed as ``model_args`` to
        ``lm_eval.simple_evaluate(model="vllm", ...)``.
        """
        return (
            f"pretrained={self.model_name},"
            "dtype=auto,add_bos_token=True,"
            f"tensor_parallel_size={self.tensor_parallel_size},"
            f"gpu_memory_utilization={self.gpu_memory_utilization},"
            f"max_model_len={self.max_model_len}"
        )


# Each entry is one parametrization of the GSM8K accuracy test. The model
# below is private: read access to huggingface.co/amd is needed to run it.
ACCURACY_CONFIGS = [
    GSM8KAccuracyTestConfig(
        excepted_value=0.96,
        model_name="amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant",
    ),
]


@pytest.mark.parametrize("config", ACCURACY_CONFIGS)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.skipif(
    not HF_HUB_AMD_ORG_ACCESS,
    reason="Read access to huggingface.co/amd is required for this test.",
)
def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
    """Run 8-shot GSM8K through lm_eval's vLLM backend and check accuracy.

    The strict-match exact-match score must lie strictly within ``tolerance``
    of ``config.excepted_value``. Skipped on machines with fewer than 8 GPUs.
    """
    from unittest import mock

    num_gpus = torch.cuda.device_count()
    if num_gpus < 8:
        pytest.skip(f"This test requires >=8 gpus, got only {num_gpus}")

    task = "gsm8k"
    # Absolute tolerance on the accuracy score (open interval).
    tolerance = 0.03

    # Scope the env override to the evaluation: patch.dict restores the
    # previous environment (including any pre-existing value) even if
    # lm_eval raises, so the setting cannot leak into later tests.
    with mock.patch.dict(os.environ, {"VLLM_USE_TRITON_FLASH_ATTN": "0"}):
        results = lm_eval.simple_evaluate(
            model="vllm",
            model_args=config.get_model_args(),
            tasks=task,
            batch_size=64,
            num_fewshot=8,
        )

    expected_value = config.excepted_value
    measured_value = results["results"][task]["exact_match,strict-match"]
    assert abs(measured_value - expected_value) < tolerance, (
        f"Expected: {expected_value} |  Measured: {measured_value}"
    )


212
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[float]):
    """Compare Quark's HIP MXFP4 quantize-dequantize with the torch reference.

    The input is split into 32-element groups, each multiplied by a different
    factor from ``scalings`` so that different groups need different scales.
    Both implementations run with the ``"even"`` scale mode. Each group's
    result must be finite and match between the two implementations.
    """
    torch.manual_seed(0)

    group_size = 32
    num_groups = 64
    hidden_size = num_groups * group_size

    inp = (torch.rand(1, hidden_size, dtype=float_dtype, device="cuda") - 0.5) * 2
    for i in range(num_groups):
        group = slice(i * group_size, (i + 1) * group_size)
        inp[:, group] = inp[:, group] * scalings[i % len(scalings)]

    # Give each implementation its own copy in case either mutates its input.
    inp_kernel = inp.clone()
    inp_kernel_clone = inp_kernel.clone()

    res_hip = mx_kernel.qdq_mxfp4_hip(inp_kernel_clone, "even")
    res_torch = qdq_mxfp4_torch(inp_kernel, "even")

    # Compare group by group so a failure points at the offending group.
    for i in range(num_groups):
        group = slice(i * group_size, (i + 1) * group_size)
        assert torch.all(torch.isfinite(res_hip[:, group]))
        assert torch.all(torch.isfinite(res_torch[:, group]))

        torch.testing.assert_close(res_hip[:, group], res_torch[:, group])
238
239


240
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_dequant_kernel_match_quark(
    float_dtype: torch.dtype, scalings: list[float]
):
    """Check Quark's HIP MXFP4 dequantize kernel against the torch reference.

    A random weight is quantized to MXFP4 through Quark's static real
    quantizer (group size 32, e8m0 scales, "even" scale mode). It is then
    dequantized by both implementations, and the outputs must be exactly
    equal.
    """
    # Seeded for reproducibility, matching the fused qdq test.
    torch.manual_seed(0)

    group_size = 32

    qspec = FP4PerGroupSpec(
        ch_axis=-1,
        group_size=group_size,
        scale_format="e8m0",
        scale_calculation_mode="even",
        is_dynamic=False,
    ).to_quantization_spec()

    weight_quantizer = StaticScaledRealQuantizer(
        qspec=qspec,
        quantizer=None,
        reorder=False,
        real_quantized=True,
        float_dtype=float_dtype,
        device="cuda",
    )

    observer = qspec.observer_cls(qspec, device="cuda")

    hidden_size = 512
    shape = (11008, hidden_size)

    w = (torch.rand(shape, device="cuda", dtype=float_dtype) - 0.5) * 2

    # Make it so that different groups have different scales.
    for i in range(hidden_size // group_size):
        group = slice(i * group_size, (i + 1) * group_size)
        w[:, group] = w[:, group] * scalings[i % len(scalings)]

    # Let the observer compute the scales from `w`, then install them on the
    # quantizer before producing the real-quantized weight.
    observer(w)
    scale, _ = observer._calculate_qparams()
    weight_quantizer.scale = scale

    w_mxfp4 = weight_quantizer.to_real_quantize_params(w).to("cuda")
    # NOTE(review): presumably converts the scale to the layout/format the
    # dequant kernels expect — confirm against amd-quark.
    weight_quantizer.maybe_convert_and_transpose_scale()

    scale = weight_quantizer.scale

    out_hip = mx_kernel.dq_mxfp4_hip(w_mxfp4, scale, float_dtype)

    out_torch = dq_mxfp4_torch(w_mxfp4, scale, float_dtype)

    assert torch.equal(out_hip, out_torch)