test_mxfp4_triton_ep.py 7.06 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests that triton_kernel_moe_forward correctly applies expert_map
remapping when expert parallelism (EP) is enabled.

Previously, legacy_routing was always used and it produced routing data
with global expert IDs that didn't correspond to local weight indices,
causing illegal memory access with EP.  The fix splits routing: when
expert_map is provided, topk selection is performed first, expert_map is
applied to remap global→local IDs, and make_routing_data builds routing
structures from the local IDs.
"""

from unittest.mock import MagicMock, patch

import pytest
import torch

from vllm.model_executor.layers.quantization.mxfp4 import (
    Mxfp4Backend,
    Mxfp4MoEMethod,
)


def _make_mock_moe_config(ep_size: int = 1) -> MagicMock:
    """Create a mock FusedMoEConfig with the given EP size."""
    parallel_config = MagicMock()
    parallel_config.ep_size = ep_size

    moe_config = MagicMock()
    moe_config.ep_size = ep_size
    moe_config.is_lora_enabled = False
    moe_config.moe_parallel_config = parallel_config
    return moe_config


class TestMxfp4TritonIsMonolithic:
    """Check ``Mxfp4MoEMethod.is_monolithic`` per backend and EP size.

    The TRITON backend is expected to report ``True`` for every EP size,
    since triton_kernel_moe_forward now handles expert_map remapping
    internally.  The SM100 and MARLIN rows pin the expected value for the
    other backends listed in the parameter table.
    """

    @pytest.mark.parametrize(
        "backend,ep_size,expected_monolithic",
        [
            # TRITON is always monolithic (handles EP via expert_map remapping)
            pytest.param(Mxfp4Backend.TRITON, 1, True, id="triton-no-ep"),
            pytest.param(Mxfp4Backend.TRITON, 2, True, id="triton-ep2"),
            pytest.param(Mxfp4Backend.TRITON, 4, True, id="triton-ep4"),
            # SM100 backends are always monolithic
            pytest.param(
                Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
                1,
                True,
                id="sm100-trtllm-no-ep",
            ),
            pytest.param(
                Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
                2,
                True,
                id="sm100-trtllm-ep2",
            ),
            pytest.param(
                Mxfp4Backend.SM100_FI_MXFP4_BF16, 1, True, id="sm100-bf16-no-ep"
            ),
            pytest.param(
                Mxfp4Backend.SM100_FI_MXFP4_BF16, 2, True, id="sm100-bf16-ep2"
            ),
            # MARLIN is never monolithic
            pytest.param(Mxfp4Backend.MARLIN, 1, False, id="marlin-no-ep"),
            pytest.param(Mxfp4Backend.MARLIN, 2, False, id="marlin-ep2"),
        ],
    )
    @patch("vllm.model_executor.layers.quantization.mxfp4.get_mxfp4_backend")
    @patch(
        "vllm.model_executor.layers.quantization.mxfp4.get_current_vllm_config"
    )
    def test_is_monolithic(
        self,
        mock_get_config,
        mock_get_backend,
        backend,
        ep_size,
        expected_monolithic,
    ):
        """Build the method with a forced backend and check is_monolithic.

        ``mock_get_config`` and ``mock_get_backend`` are injected by the
        ``patch`` decorators (innermost decorator first); the remaining
        arguments come from the parameter table.
        """
        # Force backend selection to the backend under test.
        mock_get_backend.return_value = backend

        # Stand-in vLLM config carrying only a compilation_config with
        # max_cudagraph_capture_size set.
        mock_get_config.return_value = MagicMock(
            compilation_config=MagicMock(max_cudagraph_capture_size=1024)
        )

        method = Mxfp4MoEMethod(_make_mock_moe_config(ep_size=ep_size))

        assert method.is_monolithic == expected_monolithic, (
            f"Expected is_monolithic={expected_monolithic} for "
            f"backend={backend.name}, ep_size={ep_size}, "
            f"but got {method.is_monolithic}."
        )


class TestTritonMoeForwardExpertMap:
    """Test that triton_kernel_moe_forward applies expert_map remapping
    when expert_map is provided (EP active).

    The routing helpers and the fused-experts kernel are all mocked, so
    these tests verify control flow only, not numerical results.
    """

    @pytest.mark.parametrize("expert_map_present", [False, True])
    def test_routing_path_selection(self, expert_map_present):
        """Verify that the EP-aware routing path is taken when expert_map
        is present, and the legacy_routing path is taken otherwise.

        Args:
            expert_map_present: Whether an expert_map tensor is passed to
                ``triton_kernel_moe_forward``.
        """

        device = "cuda" if torch.cuda.is_available() else "cpu"
        # This is a structural test: we mock the routing functions to
        # verify the correct path is exercised.
        expert_map = (
            torch.tensor([0, -1, 1, -1], device=device) if expert_map_present else None
        )

        with (
            patch(
                "vllm.model_executor.layers.fused_moe."
                "gpt_oss_triton_kernels_moe.legacy_routing"
            ) as mock_legacy,
            patch("triton_kernels.topk.topk") as mock_topk,
            patch(
                "vllm.model_executor.layers.fused_moe."
                "gpt_oss_triton_kernels_moe.make_routing_data"
            ) as mock_make_routing,
            patch(
                "vllm.model_executor.layers.fused_moe."
                "gpt_oss_triton_kernels_moe.triton_kernel_fused_experts"
            ) as mock_fused_experts,
        ):
            from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
                triton_kernel_moe_forward,
            )

            # Both routing helpers are unpacked into the same 3-tuple shape.
            routing_result = (MagicMock(), MagicMock(), MagicMock())

            if expert_map_present:
                # Fake top-k output.  Keep it on the same device as
                # expert_map so indexing expert_map with it never mixes
                # devices.
                sparse_result = MagicMock()
                sparse_result.indx = torch.tensor(
                    [[0, 2]], dtype=torch.int32, device=device
                )
                sparse_result.vals = torch.tensor([[0.6, 0.4]], device=device)
                mock_topk.return_value = sparse_result
                mock_make_routing.return_value = routing_result
            else:
                mock_legacy.return_value = routing_result

            mock_fused_experts.return_value = torch.zeros((1, 8), device=device)

            hidden = torch.randn((1, 8), device=device)
            w1 = torch.randn((2, 8, 16), device=device)
            w2 = torch.randn((2, 8, 8), device=device)
            logits = torch.randn((1, 4), device=device)

            triton_kernel_moe_forward(
                hidden_states=hidden,
                w1=w1,
                w2=w2,
                gating_output=logits,
                topk=2,
                renormalize=True,
                expert_map=expert_map,
            )

            # Either routing path must hand off to the fused-experts kernel.
            mock_fused_experts.assert_called_once()

            if expert_map_present:
                # EP path: should use topk + make_routing_data, NOT
                # legacy_routing
                mock_topk.assert_called_once()
                mock_make_routing.assert_called_once()
                mock_legacy.assert_not_called()

                # expert_map has already been applied during routing, so it
                # must not reach the fused-experts call.  Compare by
                # identity (``==`` on tensors is elementwise) and look at
                # positional and keyword arguments alike, so the check does
                # not depend on how the argument is passed.
                fused_call = mock_fused_experts.call_args
                forwarded = (*fused_call.args, *fused_call.kwargs.values())
                assert all(value is not expert_map for value in forwarded), (
                    "expert_map was forwarded to triton_kernel_fused_experts "
                    "even though it should already have been applied"
                )
                assert fused_call.kwargs.get("expert_map") is None, (
                    "triton_kernel_fused_experts received a non-None "
                    "expert_map keyword argument"
                )
            else:
                # Non-EP path: should use legacy_routing
                mock_legacy.assert_called_once()
                mock_topk.assert_not_called()
                mock_make_routing.assert_not_called()