# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
5
6
import textwrap
import traceback
7
8
9
10
11
12
13
from itertools import product
from typing import Optional

import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
15
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform
16
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
17
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
18

19
from ...utils import multi_gpu_test
20
21
22
23
24
25
26
from .modular_kernel_tools.common import (
    Config,
    RankTensors,
    WeightTensors,
    reference_moe_impl,
    run_modular_kernel,
)
27
from .modular_kernel_tools.mk_objects import (
28
29
30
31
32
33
34
35
36
37
38
    MK_FUSED_EXPERT_TYPES,
    MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
    MK_QUANT_CONFIGS,
    MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
    TestMoEQuantConfig,
    expert_info,
)
from .modular_kernel_tools.parallel_utils import (
    ProcessGroupInfo,
    parallel_launch_with_config,
)
39

40
41
42
# True when at least one of the optional packages probed below is available.
has_any_multi_gpu_package = (
    has_deep_ep() or has_deep_gemm() or has_pplx() or has_flashinfer_cutlass_fused_moe()
)

# Skip marker applied to the multi-GPU test: the test is skipped when none of
# the optional packages above could be found.
meets_multi_gpu_requirements = pytest.mark.skipif(
    not has_any_multi_gpu_package,
    reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
)


50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def format_result(verbose, msg, ex=None):
    """Print the outcome of one test configuration to stdout.

    Args:
        verbose: When truthy, a success is reported as ``PASSED <msg>``;
            otherwise a success prints a single ``.`` without a newline.
        msg: Description of the configuration that was run.
        ex: The exception raised by the run, or ``None`` on success. When
            given, its traceback is printed with each line prefixed by
            ``E`` and a tab, followed by a ``FAILED`` line that carries at
            most the first 16 characters of the stripped exception message.
    """
    if ex is not None:
        full = str(ex).strip(" \n\t")
        summary = full[:16]
        # Only flag truncation when message text was actually cut off;
        # whitespace removed by strip() does not count as truncation.
        if len(summary) < len(full):
            summary = summary + " ..."

        prefix = "E\t"
        # Format from the exception object itself so the correct traceback is
        # printed even when this is called outside an active ``except`` block.
        tb_text = "".join(traceback.format_exception(type(ex), ex, ex.__traceback__))
        print(f"{textwrap.indent(tb_text, prefix)}")
        print(f"FAILED {msg} - {summary}\n")
    elif verbose:
        print(f"PASSED {msg}")
    else:
        print(".", end="")


66
67
68
69
def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    base_config: Config,
    weights: WeightTensors,
    verbose: bool,
):
    """Run every (m, topk) combination of ``base_config`` on this rank.

    For each combination the modular-kernel output is compared against the
    reference MoE implementation with ``torch.testing.assert_close``. A
    failing combination is reported and recorded, and the sweep continues.

    Args:
        pgi: Process-group info for this worker; ``pgi.rank`` seeds the RNG
            and labels the log output.
        vllm_config: Config handed to the modular kernel run and made current
            while the reference implementation executes.
        cpu_group: Not used in this function (presumably part of the worker
            signature expected by the launcher -- TODO confirm).
        base_config: Test config whose ``Ms`` and ``topks`` must be lists; a
            deep copy with scalar values is made per combination.
        weights: Weight tensors; moved to the current device before the sweep.
        verbose: Forwarded to ``format_result`` to pick the reporting style.

    Raises:
        AssertionError: If the chunk-size sanity check fails or ``Ms`` /
            ``topks`` are not lists.
        RuntimeError: If at least one combination failed.
    """
    current_platform.seed_everything(pgi.rank)

    # sanity check: a chunk size requested by the config must match the value
    # vLLM reads from its environment settings in this process.
    from vllm import envs

    if base_config.fused_moe_chunk_size is not None:
        assert base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE

    # get weights to this device
    weights.to_current_device()

    ms = base_config.Ms
    assert isinstance(ms, list)
    topks = base_config.topks
    assert isinstance(topks, list)

    exceptions = []
    count = 0

    for m, topk in product(ms, topks):
        # override m and topk on a private copy so base_config stays intact
        config = copy.deepcopy(base_config)
        config.Ms = m
        config.topks = topk

        try:
            print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
            count += 1

            # inputs for rank
            rank_tensors = RankTensors.make(config, pgi)

            # modular kernel out
            mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors)

            with set_current_vllm_config(vllm_config):
                ref_out = reference_moe_impl(config, weights, rank_tensors)

            # nvfp4 is compared with looser tolerances, looser still once
            # the hidden size K reaches 4096.
            if config.quant_dtype == "nvfp4":
                atol = rtol = 1e-1 if config.K < 4096 else 2e-1
            else:
                atol = rtol = 3e-2

            torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol)
            format_result(verbose, config.describe())
        except Exception as ex:
            # Record the failure and keep going so one bad combination does
            # not hide the results of the remaining ones.
            format_result(verbose, config.describe(), ex)
            exceptions.append(ex)

    if exceptions:
        raise RuntimeError(
            f"{len(exceptions)} of {count} tests failed in child process, "
            f"rank={pgi.rank}."
        )

    print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")
132
133
134


def run(config: Config, verbose: bool) -> None:
    """Build weights for ``config`` and launch ``rank_worker`` on every rank.

    Args:
        config: Test configuration; must report itself as valid.
        verbose: Forwarded to each worker to select verbose result output.

    Raises:
        AssertionError: If ``config.is_valid()`` is falsy.
    """
    assert config.is_valid(), f"Invalid test config: {config}"

    weights: WeightTensors = WeightTensors.make(config)

    vllm_config, env_dict = config.make_env_data()
    parallel_launch_with_config(
        config.world_size, rank_worker, vllm_config, env_dict, config, weights, verbose
    )
143
144
145


# Ms and TOPKs are handed to Config as whole lists; rank_worker iterates over
# their cross product inside a single launch.
Ms = [32, 64]
# hidden sizes, making this too large will cause fp4 tests to fail.
# Also needs to be a multiple of 1024 for deep_gemm.
Ks = [2048]
Ns = [1024]
TOPKs = [4, 1]
Es = [32]
DTYPEs = [torch.bfloat16]
# With None, rank_worker skips its chunk-size sanity check.
FUSED_MOE_CHUNK_SIZEs = [None, 16]


def is_nyi_config(config: Config) -> bool:
    """Return True when ``config`` is legitimate but known to still fail.

    Experts that need matching quantization are flagged when exactly one of
    per-act-token quant and per-out-channel quant is enabled. All other
    experts are flagged when they do not support an expert map.
    """
    # We know these configs to be legitimate, but they still fail.
    info = expert_info(config.fused_experts_type)

    if info.needs_matching_quant:
        # The triton kernels expect both per-act-token-quant and
        # per-out-ch-quant or neither; exactly one of them is unsupported.
        per_act_token = bool(config.is_per_act_token_quant)
        per_out_ch = bool(config.is_per_out_ch_quant)
        return per_act_token != per_out_ch

    return not info.supports_expert_map
169
170
171
172
173
174
175
176


@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
    "combination", product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)
)
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2])
@multi_gpu_test(num_gpus=2)
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: Optional[TestMoEQuantConfig],
    combination: tuple[
        mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute
    ],
    fused_moe_chunk_size: Optional[int],
    world_size: int,
    pytestconfig,
):
    """Check multi-GPU prepare-finalize / fused-experts pairs with world size 2.

    ``combination`` is a (prepare-finalize, fused-experts) pair. The
    module-level ``Ms`` and ``TOPKs`` lists are passed whole and swept inside
    each worker. Configs that are invalid or known not-yet-implemented are
    skipped instead of failed.
    """
    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=combination[0],
        fused_experts_type=combination[1],
        fused_moe_chunk_size=fused_moe_chunk_size,
        world_size=world_size,
    )

    if not config.is_valid():
        pytest.skip(f"Tests config {config} is not valid. Skipping ...")

    if is_nyi_config(config):
        pytest.skip(f"Tests config {config} is nyi. Skipping ...")

    # pytest's "verbose" option selects per-config PASSED lines over dots.
    verbosity = pytestconfig.getoption("verbose")
    run(config, verbosity > 0)
218
219
220
221
222
223
224
225


@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
    "combination", product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)
)
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [1])
def test_modular_kernel_combinations_singlegpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: Optional[TestMoEQuantConfig],
    combination: tuple[
        mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute
    ],
    fused_moe_chunk_size: Optional[int],
    world_size: int,
    pytestconfig,
):
    """Check single-GPU prepare-finalize / fused-experts pairs with world size 1.

    ``combination`` is a (prepare-finalize, fused-experts) pair. The
    module-level ``Ms`` and ``TOPKs`` lists are passed whole and swept inside
    the worker. Configs that are invalid or known not-yet-implemented are
    skipped instead of failed.
    """
    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=combination[0],
        fused_experts_type=combination[1],
        fused_moe_chunk_size=fused_moe_chunk_size,
        world_size=world_size,
    )

    if not config.is_valid():
        pytest.skip(f"Tests config {config} is not valid. Skipping ...")

    if is_nyi_config(config):
        pytest.skip(f"Tests config {config} is nyi. Skipping ...")

    # pytest's "verbose" option selects per-config PASSED lines over dots.
    verbosity = pytestconfig.getoption("verbose")
    run(config, verbosity > 0)
265
266


267
if __name__ == "__main__":
    # Ability to test individual PrepareAndFinalize and FusedExperts combination
    # The relative import below needs package context, so launch this file
    # with ``python3 -m ...`` as shown in the help text.
    from .modular_kernel_tools.cli_args import make_config, make_config_arg_parser

    parser = make_config_arg_parser(
        description=(
            # Adjacent literals are concatenated: each piece must end with its
            # own separator or the help text runs the words together.
            "Run single prepare-finalize & fused-experts combination test. "
            "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations "  # noqa: E501
            "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
        )
    )
    args = parser.parse_args()
    config = make_config(args)

    # Always run verbosely from the command line.
    run(config, True)