Unverified commit d1135a50, authored by danisereb, committed by GitHub
Browse files

Fix MoE backend selection for LoRA (unquantized MoE) (#40273)


Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
parent 982beae8
......@@ -11,6 +11,11 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
)
from vllm.platforms import current_platform
# Skip marker for tests in this module that should only run when the
# current platform reports itself as CUDA or ROCm.
skipif_not_cuda_rocm = pytest.mark.skipif(
    not current_platform.is_cuda() and not current_platform.is_rocm(),
    reason="Only supported on CUDA/ROCm platforms.",
)
@pytest.mark.parametrize(
"platform_method,expected_backend",
......@@ -190,3 +195,83 @@ def test_select_cuda_flashinfer_cutlass_backend(
assert selected_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS
assert experts_cls is not None
@skipif_not_cuda_rocm
def test_select_lora_backend_prefers_triton():
    """Check that enabling LoRA on a dummy config yields the Triton backend.

    Only ``is_lora_enabled`` is modified; every other field keeps whatever
    ``make_dummy_moe_config()`` produced.
    """
    cfg = make_dummy_moe_config()
    cfg.is_lora_enabled = True

    backend, kernel_cls = select_unquantized_moe_backend(moe_config=cfg)

    assert backend == UnquantizedMoeBackend.TRITON
    # A concrete experts/kernel class must accompany the Triton selection.
    assert kernel_cls is not None
@skipif_not_cuda_rocm
def test_select_lora_explicit_non_triton_backend():
    """Check that LoRA wins over an explicitly requested non-Triton backend.

    The config asks for ``"flashinfer_cutlass"`` while LoRA is enabled; the
    selector is still expected to return Triton.
    """
    cfg = make_dummy_moe_config()
    cfg.is_lora_enabled = True
    # The backend string is taken from the mapping used by
    # map_unquantized_backend().
    cfg.moe_backend = "flashinfer_cutlass"

    backend, kernel_cls = select_unquantized_moe_backend(moe_config=cfg)

    assert backend == UnquantizedMoeBackend.TRITON
    assert kernel_cls is not None
@skipif_not_cuda_rocm
@pytest.mark.parametrize("is_lora_enabled", [False, True])
def test_select_explicit_triton_backend(is_lora_enabled):
    """Check that requesting ``"triton"`` explicitly returns Triton.

    Runs once with LoRA disabled and once with LoRA enabled; the expected
    result is the same in both cases.
    """
    cfg = make_dummy_moe_config()
    cfg.is_lora_enabled = is_lora_enabled
    cfg.moe_backend = "triton"

    backend, kernel_cls = select_unquantized_moe_backend(moe_config=cfg)

    assert backend == UnquantizedMoeBackend.TRITON
    assert kernel_cls is not None
@skipif_not_cuda_rocm
def test_select_explicit_triton_ignores_flashinfer_env(monkeypatch):
    """Check that an explicit ``"triton"`` request beats FlashInfer env vars.

    Both FlashInfer-related environment variables are set before the config
    is built, LoRA is disabled, and the selector must still return Triton.
    """
    # Applied in the listed order; monkeypatch restores them after the test.
    flashinfer_env = (
        ("VLLM_USE_FLASHINFER_MOE_FP16", "1"),
        ("VLLM_FLASHINFER_MOE_BACKEND", "throughput"),
    )
    for env_name, env_value in flashinfer_env:
        monkeypatch.setenv(env_name, env_value)

    cfg = make_dummy_moe_config()
    cfg.is_lora_enabled = False
    cfg.moe_backend = "triton"

    backend, kernel_cls = select_unquantized_moe_backend(moe_config=cfg)

    assert backend == UnquantizedMoeBackend.TRITON
    assert kernel_cls is not None
@skipif_not_cuda_rocm
def test_select_lora_ignores_flashinfer_env(monkeypatch):
    """Check that LoRA still selects Triton when FlashInfer env vars are set.

    Unlike the explicit-backend tests, ``moe_backend`` is left as
    ``make_dummy_moe_config()`` produced it; only LoRA is switched on.
    """
    # Applied in the listed order; monkeypatch restores them after the test.
    flashinfer_env = (
        ("VLLM_USE_FLASHINFER_MOE_FP16", "1"),
        ("VLLM_FLASHINFER_MOE_BACKEND", "throughput"),
    )
    for env_name, env_value in flashinfer_env:
        monkeypatch.setenv(env_name, env_value)

    cfg = make_dummy_moe_config()
    cfg.is_lora_enabled = True

    backend, kernel_cls = select_unquantized_moe_backend(moe_config=cfg)

    assert backend == UnquantizedMoeBackend.TRITON
    assert kernel_cls is not None
......@@ -163,6 +163,11 @@ def select_unquantized_moe_backend(
if current_platform.is_out_of_tree():
return UnquantizedMoeBackend.OOT, None
if moe_config.is_lora_enabled:
return UnquantizedMoeBackend.TRITON, backend_to_kernel_cls(
UnquantizedMoeBackend.TRITON
)
# NOTE: the kernels are selected in the following order.
AVAILABLE_BACKENDS = _get_priority_backends(moe_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment