test_unquantized_backend_selection.py 9.99 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import patch

import pytest

from tests.kernels.moe.utils import make_dummy_moe_config
from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
    UnquantizedMoeBackend,
    select_unquantized_moe_backend,
)
12
from vllm.platforms import current_platform
13

14
15
16
17
18
# Shared marker for the tests below that call the selector on the real
# platform: they are skipped unless the platform reports CUDA or ROCm.
skipif_not_cuda_rocm = pytest.mark.skipif(
    not current_platform.is_cuda() and not current_platform.is_rocm(),
    reason="Only supported on CUDA/ROCm platforms.",
)

19
20
21
22
23

@pytest.mark.parametrize(
    "platform_method,expected_backend",
    [
        ("is_cuda", UnquantizedMoeBackend.TRITON),  # Default CUDA without FlashInfer
        ("is_rocm", UnquantizedMoeBackend.TRITON),  # ROCm without AITER
        ("is_cpu", UnquantizedMoeBackend.CPU),
        ("is_xpu", UnquantizedMoeBackend.XPU),
        ("is_tpu", UnquantizedMoeBackend.TPU),
        ("is_out_of_tree", UnquantizedMoeBackend.OOT),
    ],
)
@patch(
    "vllm.utils.flashinfer.has_flashinfer",
    return_value=False,
)
@patch(
    "vllm.model_executor.layers.fused_moe.oracle.unquantized.rocm_aiter_ops.is_fused_moe_enabled",
    return_value=False,
)
def test_select_default_backend_by_platform(
    mock_aiter_enabled,
    mock_has_flashinfer,
    monkeypatch,
    platform_method,
    expected_backend,
):
    """Test default backend selection per platform with all optional
    accelerators (FlashInfer, AITER) disabled.

    Args:
        mock_aiter_enabled: Mock injected by the innermost ``@patch``
            (``rocm_aiter_ops.is_fused_moe_enabled``, returns False).
        mock_has_flashinfer: Mock injected by the outer ``@patch``
            (``has_flashinfer``, returns False).
        monkeypatch: pytest fixture; not used by this test body.
        platform_method: Name of the ``current_platform`` predicate that
            is forced to return True for this case.
        expected_backend: Backend the selector is expected to return.
    """
    # Force every platform predicate to False, then flip only the one under
    # test. The last patch on the same attribute wins while both are active.
    with (
        patch.object(current_platform, "is_cuda", return_value=False),
        patch.object(current_platform, "is_rocm", return_value=False),
        patch.object(current_platform, "is_cpu", return_value=False),
        patch.object(current_platform, "is_xpu", return_value=False),
        patch.object(current_platform, "is_tpu", return_value=False),
        patch.object(current_platform, "is_out_of_tree", return_value=False),
        patch.object(current_platform, platform_method, return_value=True),
    ):
        moe_config = make_dummy_moe_config()
        selected_backend, expert_cls = select_unquantized_moe_backend(
            moe_config=moe_config
        )

        assert selected_backend == expected_backend
        # This test expects no experts class for CPU, OOT and TPU, and one
        # for every other backend.
        if expected_backend in (
            UnquantizedMoeBackend.CPU,
            UnquantizedMoeBackend.OOT,
            UnquantizedMoeBackend.TPU,
        ):
            assert expert_cls is None
        else:
            assert expert_cls is not None
85
86


87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# NOTE(review): other tests in this file patch
# "vllm.utils.flashinfer.has_flashinfer"; confirm the module-local target
# below still exists, since patch() raises AttributeError when it does not.
@patch(
    "vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
    return_value=False,
)
@patch(
    "vllm.model_executor.layers.fused_moe.oracle.unquantized.rocm_aiter_ops.is_fused_moe_enabled",
    return_value=True,
)
@pytest.mark.skipif(
    not current_platform.is_rocm(), reason="ROCm-specific backend selection test"
)
def test_select_rocm_aiter_backend(mock_aiter_enabled, mock_has_flashinfer):
    """Test ROCm backend selection when AITER is available.

    Args:
        mock_aiter_enabled: Mock injected by the innermost ``@patch``
            (``rocm_aiter_ops.is_fused_moe_enabled``, returns True).
        mock_has_flashinfer: Mock injected by the outer ``@patch``
            (``has_flashinfer``, returns False).
    """
    # Replace the selector module's ``current_platform`` with a mock that
    # reports ROCm only; the selector is called while the patch is active.
    with patch(
        "vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
    ) as mock_platform:
        mock_platform.is_cuda.return_value = False
        mock_platform.is_rocm.return_value = True
        mock_platform.is_cpu.return_value = False
        mock_platform.is_xpu.return_value = False
        mock_platform.is_tpu.return_value = False
        mock_platform.is_out_of_tree.return_value = False

        moe_config = make_dummy_moe_config()
        selected_backend, expert_cls = select_unquantized_moe_backend(
            moe_config=moe_config,
        )

        assert selected_backend == UnquantizedMoeBackend.AITER
        assert expert_cls is not None
117
118


119
@patch(
    "vllm.model_executor.layers.fused_moe.experts.trtllm_bf16_moe.TrtLlmBf16Experts.is_supported_config",
    return_value=(True, None),
)
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="Only supported on NVIDIA platforms."
)
def test_select_cuda_flashinfer_trtllm_backend(mock_is_supported_trtllm, monkeypatch):
    """Test CUDA backend selection when FlashInfer TRTLLM is available and enabled.

    Args:
        mock_is_supported_trtllm: Mock injected by ``@patch`` for
            ``TrtLlmBf16Experts.is_supported_config``; returns ``(True, None)``.
        monkeypatch: pytest fixture used to set the FlashInfer env var.
    """
    # Pin the platform predicates to a CUDA-only view and make every device
    # capability query succeed for the duration of the selection call.
    with (
        patch.object(current_platform, "is_cuda", return_value=True),
        patch.object(current_platform, "is_rocm", return_value=False),
        patch.object(current_platform, "is_cpu", return_value=False),
        patch.object(current_platform, "is_xpu", return_value=False),
        patch.object(current_platform, "is_tpu", return_value=False),
        patch.object(current_platform, "is_out_of_tree", return_value=False),
        patch.object(current_platform, "has_device_capability", return_value=True),
    ):
        # Enable FlashInfer via env var
        monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")

        moe_config = make_dummy_moe_config()
        # TRTLLM requires EP and does not support DP
        moe_config.moe_parallel_config.use_ep = True
        moe_config.moe_parallel_config.use_dp = False

        selected_backend, experts_cls = select_unquantized_moe_backend(
            moe_config=moe_config
        )

        assert selected_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
        assert experts_cls is not None
150
151
152


@patch(
    "vllm.utils.flashinfer.has_flashinfer",
    return_value=True,
)
@patch(
    "vllm.model_executor.layers.fused_moe.experts.trtllm_bf16_moe.TrtLlmBf16Experts.is_supported_config",
    return_value=(False, None),
)
@patch(
    "vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts.is_supported_config",
    return_value=(True, None),
)
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="Only supported on NVIDIA platforms."
)
def test_select_cuda_flashinfer_cutlass_backend(
    mock_is_supported_cutlass,
    mock_is_supported_trtllm,
    mock_has_flashinfer,
    monkeypatch,
):
    """Test CUDA backend selection when FlashInfer TRTLLM is not available
    and FlashInfer CUTLASS is available.

    Stacked ``@patch`` decorators inject their mocks bottom-up, so the
    positional mock parameters are listed innermost decorator first.

    Args:
        mock_is_supported_cutlass: Mock for
            ``FlashInferExperts.is_supported_config``; returns ``(True, None)``.
        mock_is_supported_trtllm: Mock for
            ``TrtLlmBf16Experts.is_supported_config``; returns ``(False, None)``.
        mock_has_flashinfer: Mock for ``has_flashinfer``; returns True.
        monkeypatch: pytest fixture used to set the FlashInfer env var.
    """
    # Pin the platform predicates to a CUDA-only view and make every device
    # capability query succeed for the duration of the selection call.
    with (
        patch.object(current_platform, "is_cuda", return_value=True),
        patch.object(current_platform, "is_rocm", return_value=False),
        patch.object(current_platform, "is_cpu", return_value=False),
        patch.object(current_platform, "is_xpu", return_value=False),
        patch.object(current_platform, "is_tpu", return_value=False),
        patch.object(current_platform, "is_out_of_tree", return_value=False),
        patch.object(current_platform, "has_device_capability", return_value=True),
    ):
        # Enable FlashInfer via env var
        monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")

        moe_config = make_dummy_moe_config()
        # CUTLASS requires EP and does not support DP
        moe_config.moe_parallel_config.use_ep = True
        moe_config.moe_parallel_config.use_dp = False

        selected_backend, experts_cls = select_unquantized_moe_backend(
            moe_config=moe_config
        )

        assert selected_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS
        assert experts_cls is not None
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277


@skipif_not_cuda_rocm
def test_select_lora_backend_prefers_triton():
    """Selector returns Triton when LoRA is enabled on the dummy config."""
    lora_config = make_dummy_moe_config()
    lora_config.is_lora_enabled = True

    backend, experts = select_unquantized_moe_backend(moe_config=lora_config)

    assert backend == UnquantizedMoeBackend.TRITON
    assert experts is not None


@skipif_not_cuda_rocm
def test_select_lora_explicit_non_triton_backend():
    """An explicit non-Triton backend request is replaced by Triton under LoRA."""
    lora_config = make_dummy_moe_config()
    # Backend name is one of the strings accepted by map_unquantized_backend().
    lora_config.moe_backend = "flashinfer_cutlass"
    lora_config.is_lora_enabled = True

    backend, experts = select_unquantized_moe_backend(moe_config=lora_config)

    assert backend == UnquantizedMoeBackend.TRITON
    assert experts is not None


@skipif_not_cuda_rocm
@pytest.mark.parametrize("is_lora_enabled", [False, True])
def test_select_explicit_triton_backend(is_lora_enabled):
    """Requesting "triton" explicitly yields Triton, with or without LoRA."""
    config = make_dummy_moe_config()
    config.is_lora_enabled = is_lora_enabled
    config.moe_backend = "triton"

    backend, experts = select_unquantized_moe_backend(moe_config=config)

    assert backend == UnquantizedMoeBackend.TRITON
    assert experts is not None


@skipif_not_cuda_rocm
def test_select_explicit_triton_ignores_flashinfer_env(monkeypatch):
    """Explicit "triton" wins even when the FlashInfer env vars are set."""
    # Turn on the FlashInfer env switches before building the config.
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")

    config = make_dummy_moe_config()
    config.is_lora_enabled = False
    config.moe_backend = "triton"

    backend, experts = select_unquantized_moe_backend(moe_config=config)

    assert backend == UnquantizedMoeBackend.TRITON
    assert experts is not None


@skipif_not_cuda_rocm
def test_select_lora_ignores_flashinfer_env(monkeypatch):
    """With LoRA enabled, Triton is selected even if FlashInfer env vars are set."""
    # Turn on the FlashInfer env switches before building the config.
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")

    lora_config = make_dummy_moe_config()
    lora_config.is_lora_enabled = True

    backend, experts = select_unquantized_moe_backend(moe_config=lora_config)

    assert backend == UnquantizedMoeBackend.TRITON
    assert experts is not None