test_blackwell_moe.py 5.67 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import os
6
from typing import Any
7
8
9
10
11
12

import pytest

from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform

13
# This module targets Blackwell GPUs (SM10x) only; on any other device
# capability family the whole module is skipped at collection time.
if not current_platform.is_device_capability_family(100):
    pytest.skip(
        "This test only runs on Blackwell GPUs (SM10x).", allow_module_level=True
    )
17

18
19
20
21
22
23
24
25
26

@pytest.fixture(scope="module", autouse=True)
def set_test_environment():
    """Set the environment variables this module needs, then restore them.

    The variables stay set while the module's tests run. On module teardown
    the previous environment is restored, so the values do not leak into test
    modules that run later in the same session.
    """
    # A module-scoped fixture cannot request the function-scoped
    # ``monkeypatch`` fixture, so a standalone MonkeyPatch context is used.
    with pytest.MonkeyPatch.context() as env_patch:
        # Make sure TRTLLM attention is available
        env_patch.setenv("VLLM_HAS_FLASHINFER_CUBIN", "1")
        # Set compilation threads to 16 to speed up startup
        env_patch.setenv("FLASHINFER_NVCC_THREADS", "16")
        yield

27

28
29
30
31
32
33
34
35
# Override the backbone layers to 4 for faster startup
HF_OVERRIDE_TEXT = dict.fromkeys(("num_layers", "num_hidden_layers"), 4)
# Same overrides, nested under the "text_config" key (as an independent copy).
HF_OVERRIDE_MM = {"text_config": dict(HF_OVERRIDE_TEXT)}
36
37


38
39
40
41
42
def can_initialize(
    model: str,
    hf_overrides: dict[str, Any] | None = None,
    extra_args: list[str] | None = None,
) -> None:
    """Launch a server for ``model`` and check that it answers one completion.

    Args:
        model: Model name passed to ``RemoteOpenAIServer`` and used as the
            ``model`` of the completion request.
        hf_overrides: Forwarded to ``RemoteOpenAIServer`` as
            ``override_hf_configs``.
        extra_args: Additional server CLI arguments, appended after the
            default arguments built here.

    Raises:
        AssertionError: If the first completion choice has no text.
    """
    # Server arguments
    extra_args = extra_args if extra_args is not None else []
    server_args = [
        "--max-model-len",
        "2048",
        "--max-num-batched-tokens",
        "256",
        "--load-format",
        "dummy",
        "--trust-remote-code",
        "--limit-mm-per-prompt",
        json.dumps({"image": 0}),
        *extra_args,
    ]

    # Launch server and make a simple request
    with RemoteOpenAIServer(
        model,
        server_args,
        max_wait_seconds=1500,  # Due to FlashInfer compile
        override_hf_configs=hf_overrides,
    ) as server:
        client = server.get_client()
        # Make a simple request to verify the server works
        completion = client.completions.create(
            model=model,
            prompt=["Hello, World!"],
            temperature=0,
            max_tokens=2,
        )
        print(completion)
        assert completion.choices[0].text is not None


## Llama4 ##


80
81
82
83
84
85
86
87
@pytest.mark.skip(
    reason=(
        "RuntimeError: run_moe() Expected a value of type "
        "'Optional[List[Tensor]]' for argument '_9' but instead found type "
        "'list'."
    )
)
def test_llama4_fp8_tensor_moe_flashinfer_cutlass(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Llama-4 Scout FP8 with ``--moe-backend=flashinfer_cutlass``.

    Currently skipped; the skip reason records the observed RuntimeError.
    """
    can_initialize(
        "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
        hf_overrides=HF_OVERRIDE_MM,
        extra_args=["--moe-backend=flashinfer_cutlass"],
    )
93
94


95
def test_llama4_fp8_tensor_moe_flashinfer_trtllm(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Llama-4 Scout FP8 with ``--moe-backend=flashinfer_trtllm``."""
    can_initialize(
        "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
        hf_overrides=HF_OVERRIDE_MM,
        extra_args=["--moe-backend=flashinfer_trtllm"],
    )
101
102
103


def test_llama4_nvfp4_moe_flashinfer_cutlass(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Llama-4 Scout FP4 with ``--moe-backend=flashinfer_cutlass``."""
    can_initialize(
        "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
        hf_overrides=HF_OVERRIDE_MM,
        extra_args=["--moe-backend=flashinfer_cutlass"],
    )
109
110
111


def test_llama4_nvfp4_moe_flashinfer_trtllm(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Llama-4 Scout FP4 with ``--moe-backend=flashinfer_trtllm``."""
    can_initialize(
        "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
        hf_overrides=HF_OVERRIDE_MM,
        extra_args=["--moe-backend=flashinfer_trtllm"],
    )
117
118
119
120
121
122


## DeepSeekV3 ##


def test_deepseek_fp8_block_moe_deep_gemm(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """DeepSeek-V3.1 with ``--moe-backend=deep_gemm``."""
    can_initialize(
        "deepseek-ai/DeepSeek-V3.1",
        hf_overrides=HF_OVERRIDE_TEXT,
        extra_args=["--moe-backend=deep_gemm"],
    )
128
129


130
131
132
133
134
135
136
@pytest.mark.skip(
    reason=(
        "Known issue: lack of kernel support. "
        "Expected failure: assert self.block_quant is None"
    )
)
def test_deepseek_fp8_block_moe_flashinfer_cutlass(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """DeepSeek-V3.1 with ``--moe-backend=flashinfer_cutlass``.

    Currently skipped; the skip reason records the known failure.
    """
    can_initialize(
        "deepseek-ai/DeepSeek-V3.1",
        hf_overrides=HF_OVERRIDE_TEXT,
        extra_args=["--moe-backend=flashinfer_cutlass"],
    )
142
143


144
def test_deepseek_fp8_block_moe_flashinfer_trtllm(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """DeepSeek-V3.1 with ``--moe-backend=flashinfer_trtllm``."""
    can_initialize(
        "deepseek-ai/DeepSeek-V3.1",
        hf_overrides=HF_OVERRIDE_TEXT,
        extra_args=["--moe-backend=flashinfer_trtllm"],
    )
150
151


152
def test_deepseek_nvfp4_moe_flashinfer_cutlass(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """DeepSeek-R1-0528 FP4 (v2) with ``--moe-backend=flashinfer_cutlass``."""
    can_initialize(
        "nvidia/DeepSeek-R1-0528-FP4-v2",
        hf_overrides=HF_OVERRIDE_TEXT,
        extra_args=["--moe-backend=flashinfer_cutlass"],
    )
158
159
160


def test_deepseek_nvfp4_moe_flashinfer_trtllm(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """DeepSeek-R1-0528 FP4 (v2) with ``--moe-backend=flashinfer_trtllm``."""
    can_initialize(
        "nvidia/DeepSeek-R1-0528-FP4-v2",
        hf_overrides=HF_OVERRIDE_TEXT,
        extra_args=["--moe-backend=flashinfer_trtllm"],
    )
166
167
168
169
170
171
172


## GPT-OSS ##


def test_gptoss_mxfp4bf16_moe_flashinfer(monkeypatch: pytest.MonkeyPatch) -> None:
    """gpt-oss-20b with ``VLLM_USE_FLASHINFER_MOE_MXFP4_BF16=1`` set."""
    # The env var is set before the server is launched by can_initialize.
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1")
    can_initialize("openai/gpt-oss-20b", hf_overrides=HF_OVERRIDE_TEXT)
174
175


176
def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """gpt-oss-20b with ``VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1`` set."""
    # The env var is set before the server is launched by can_initialize.
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "1")
    can_initialize("openai/gpt-oss-20b", hf_overrides=HF_OVERRIDE_TEXT)
179
180


181
def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """gpt-oss-20b with ``VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1`` set."""
    # The env var is set before the server is launched by can_initialize.
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
    can_initialize("openai/gpt-oss-20b", hf_overrides=HF_OVERRIDE_TEXT)
184
185


186
def test_gptoss_eager(monkeypatch: pytest.MonkeyPatch) -> None:
    """gpt-oss-20b with the ``--enforce-eager`` server argument."""
    can_initialize(
        "openai/gpt-oss-20b",
        hf_overrides=HF_OVERRIDE_TEXT,
        extra_args=["--enforce-eager"],
    )
192
193
194
195
196
197


## Qwen3 Next ##


def test_qwen3_next_bf16_moe_flashinfer_trtllm(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Qwen3-Next-80B-A3B-Instruct with ``--moe-backend=flashinfer_trtllm``."""
    can_initialize(
        "Qwen/Qwen3-Next-80B-A3B-Instruct",
        hf_overrides=HF_OVERRIDE_TEXT,
        extra_args=["--moe-backend=flashinfer_trtllm"],
    )