# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable

import pytest

from vllm.config import PassConfig
from vllm.platforms import current_platform
from vllm.utils.flashinfer import is_flashinfer_fp8_blockscale_gemm_supported

from .common import (
    INDUCTOR_GRAPH_PARTITION,
    AttentionBackendCase,
    Matches,
    custom_ops_combos,
    is_blackwell,
)
from .models import (
    FLASHINFER_ATTN,
    FLASHINFER_MLA_ATTN,
    ROCM_AITER_UNIFIED_ATTN,
    ROCM_ATTN,
    TRITON_ATTN,
    TRITON_MLA_ATTN,
    deepseek_v3_fp8,
    llama3_8b_fp4,
    llama3_8b_fp8,
    llama4_scout_fp4,
    llama4_scout_fp8,
    qwen3_a3b_fp8,
)


@pytest.mark.parametrize(
    "model_name, matches_fn, model_kwargs, hf_overrides, use_deepgemm",
    [
        (*llama3_8b_fp8, False),
        (*qwen3_a3b_fp8, False),
        (*qwen3_a3b_fp8, True),
        (*deepseek_v3_fp8, False),
        (*deepseek_v3_fp8, True),
        pytest.param(
            *llama4_scout_fp8,
            False,
            marks=pytest.mark.skipif(
                not current_platform.is_cuda(),
                reason="Llama4 Scout FP8 only supported on CUDA",
            ),
        ),
    ],
)
@pytest.mark.parametrize(
    "attn_backend",
    [
        TRITON_ATTN,
        FLASHINFER_ATTN,
        ROCM_ATTN,
        ROCM_AITER_UNIFIED_ATTN,
        FLASHINFER_MLA_ATTN,
        TRITON_MLA_ATTN,
    ],
)
@pytest.mark.parametrize("n_layers", [6])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
def test_tp1_fp8_fusions(
    model_name: str,
    matches_fn: Callable[[int], Matches],
    model_kwargs: dict,
    hf_overrides: Callable[[int], dict],
    attn_backend: AttentionBackendCase,
    n_layers: int,
    custom_ops: str,
    inductor_graph_partition: bool,
    use_deepgemm: bool,
    run_e2e_fusion_test,
    monkeypatch,
):
    """End-to-end FP8 fusion test for a single-GPU (TP=1) model.

    Skips unsupported DeepGEMM / group-quant combinations, shrinks the model
    to ``n_layers`` layers with dummy weights, enables the norm-quant,
    act-quant, attn-quant and QK-norm-RoPE fusion passes, and hands the
    expected ``matches`` plus the list of fusion names to check
    (``matches_check``) to the ``run_e2e_fusion_test`` fixture.

    Args:
        model_name: Model identifier from the ``.models`` tuples.
        matches_fn: Returns the expected ``Matches`` for a given layer count.
        model_kwargs: Base engine kwargs for the model. It is copied, not
            mutated, because the same dict is shared by every parametrization.
        hf_overrides: Returns HF config overrides for a given layer count.
        attn_backend: Attention backend case to run with.
        n_layers: Number of layers to truncate the model to.
        custom_ops: Comma-separated custom-op toggles (e.g. ``"+quant_fp8"``).
        inductor_graph_partition: Whether to use inductor graph partitioning.
        use_deepgemm: Whether to exercise the DeepGEMM path.
        run_e2e_fusion_test: Fixture that runs the model and checks matches.
        monkeypatch: pytest fixture (not used directly in this body).
    """
    if use_deepgemm and not current_platform.is_cuda():
        pytest.skip("DeepGemm only supported on CUDA")

    if use_deepgemm and is_flashinfer_fp8_blockscale_gemm_supported():
        # Flashinfer block FP8 GEMM has internal quantization, so it can't
        # be fused with other ops.
        pytest.skip("FlashInfer block FP8 GEMM not supported")

    if use_deepgemm and is_blackwell():
        # TODO(luka) DeepGEMM uses different quants, matching not supported
        #  - on Blackwell, uses a special quant fp8, currently not supported
        pytest.skip("DeepGEMM & quant matching not currently supported")

    matches = matches_fn(n_layers)

    lowered_name = model_name.lower()
    block_fp8 = "qwen" in lowered_name or "deepseek" in lowered_name
    if block_fp8 and "-quant_fp8" in custom_ops:
        # Group (block) quant is not matched when the native (non-custom)
        # QuantFP8 path is used; this is why config forces +quant_fp8 by
        # default.
        pytest.skip("native QuantFP8 matching not supported for group quant")

    # Copy before mutating: pytest hands the same dict object from the
    # module-level model tuple to every generated test instance.
    model_kwargs = dict(model_kwargs)
    # Reduce size of model and skip weight loading time
    model_kwargs["hf_overrides"] = hf_overrides(n_layers)
    model_kwargs["load_format"] = "dummy"
    model_kwargs["max_model_len"] = 1024
    model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}

    compilation_config = dict(
        use_inductor_graph_partition=inductor_graph_partition,
        custom_ops=custom_ops.split(","),
        pass_config=PassConfig(
            fuse_norm_quant=True,
            fuse_act_quant=True,
            fuse_attn_quant=True,
            enable_qk_norm_rope_fusion=True,
        ),
    )

    use_aiter = current_platform.is_rocm() and ("qwen" in lowered_name)

    matches_check = [
        "rms_quant_fusion",
        "act_quant_fusion",
        "norm_rope_fusion",
        "attn_quant_fusion",
    ]

    if use_aiter:
        # With AITER the norm+quant fusion is reported under its own name;
        # expect the same count as the generic rms_quant_fusion.
        matches_check[0] = "aiter_rms_quant_fusion"

        matches = matches._replace(aiter_rms_quant_fusion=matches.rms_quant_fusion)
        # TODO: enable the `norm_rope_fusion` test,
        # On ROCm norm_rope_fusion is only supported without
        # enabling AITER.
        matches_check.remove("norm_rope_fusion")

    run_e2e_fusion_test(
        model_name,
        matches,
        model_kwargs,
        attn_backend,
        compilation_config,
        matches_check,
        use_deepgemm=use_deepgemm,
        use_aiter=use_aiter,
    )


@pytest.mark.parametrize(
    "model_name, matches_fn, model_kwargs, hf_overrides",
    [llama3_8b_fp4, llama4_scout_fp4],
)
@pytest.mark.parametrize("attn_backend", [FLASHINFER_ATTN])
@pytest.mark.parametrize("n_layers", [6])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
@pytest.mark.skipif(not is_blackwell(), reason="Blackwell required for fp4")
def test_tp1_fp4_fusions(
    model_name: str,
    matches_fn: Callable[[int], Matches],
    model_kwargs: dict,
    hf_overrides: Callable[[int], dict],
    attn_backend: AttentionBackendCase,
    n_layers: int,
    custom_ops: str,
    inductor_graph_partition: bool,
    run_e2e_fusion_test,
):
    """End-to-end FP4 fusion test for a single-GPU (TP=1) model (Blackwell).

    Shrinks the model to ``n_layers`` layers with dummy weights, enables the
    fusion passes, and asks ``run_e2e_fusion_test`` to check only the
    act-quant, attn-quant and norm-rope fusion counts.

    Args:
        model_name: Model identifier from the ``.models`` tuples.
        matches_fn: Returns the expected ``Matches`` for a given layer count.
        model_kwargs: Base engine kwargs for the model. It is copied, not
            mutated, because the same dict is shared by every parametrization.
        hf_overrides: Returns HF config overrides for a given layer count.
        attn_backend: Attention backend case to run with.
        n_layers: Number of layers to truncate the model to.
        custom_ops: Comma-separated custom-op toggles.
        inductor_graph_partition: Whether to use inductor graph partitioning.
        run_e2e_fusion_test: Fixture that runs the model and checks matches.
    """
    matches = matches_fn(n_layers)

    # Copy before mutating: pytest hands the same dict object from the
    # module-level model tuple to every generated test instance.
    model_kwargs = dict(model_kwargs)
    # Reduce size of model and skip weight loading time
    model_kwargs["hf_overrides"] = hf_overrides(n_layers)
    model_kwargs["load_format"] = "dummy"
    model_kwargs["max_model_len"] = 1024
    model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}

    compilation_config = dict(
        use_inductor_graph_partition=inductor_graph_partition,
        custom_ops=custom_ops.split(","),
        pass_config=PassConfig(
            fuse_norm_quant=True,
            fuse_act_quant=True,
            fuse_attn_quant=True,
            enable_qk_norm_rope_fusion=True,
        ),
    )

    matches_check = ["act_quant_fusion", "attn_quant_fusion", "norm_rope_fusion"]

    run_e2e_fusion_test(
        model_name,
        matches,
        model_kwargs,
        attn_backend,
        compilation_config,
        matches_check,
    )