test_modular_kernel_combinations.py 9.66 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
5
6
import textwrap
import traceback
7
from itertools import product
8
from typing import Any
9
10
11
12
13

import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
15
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform
16
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
17
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
18
from vllm.utils.torch_utils import cuda_device_count_stateless
19

20
21
22
23
24
25
26
from .modular_kernel_tools.common import (
    Config,
    RankTensors,
    WeightTensors,
    reference_moe_impl,
    run_modular_kernel,
)
27
from .modular_kernel_tools.mk_objects import (
28
29
30
31
32
33
34
35
36
37
38
    MK_FUSED_EXPERT_TYPES,
    MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
    MK_QUANT_CONFIGS,
    MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
    TestMoEQuantConfig,
    expert_info,
)
from .modular_kernel_tools.parallel_utils import (
    ProcessGroupInfo,
    parallel_launch_with_config,
)
39

40
41
42
# Truthy when at least one of the optional packages probed here (deep_ep,
# deep_gemm, pplx, flashinfer cutlass fused-MoE) reports itself as available.
# The probes are evaluated left to right and stop at the first truthy result.
has_any_multi_gpu_package = (
    has_deep_ep() or has_deep_gemm() or has_pplx() or has_flashinfer_cutlass_fused_moe()
)
43

44
45
46
# Decorator that skips a test when none of the optional multi-GPU packages
# was detected at import time.
_MULTI_GPU_SKIP_REASON = "Requires deep_ep or deep_gemm or pplx or flashinfer packages"
meets_multi_gpu_requirements = pytest.mark.skipif(
    not has_any_multi_gpu_package, reason=_MULTI_GPU_SKIP_REASON
)

49
50
51
52
53
54
# Skip the entire module at import/collection time when the current platform
# reports the fnuz fp8 variant: the tests in this file require float8_e4m3fn.
if current_platform.is_fp8_fnuz():
    pytest.skip(
        "Tests in this file require float8_e4m3fn and platform does not support",
        allow_module_level=True,
    )

55

56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def format_result(verbose, msg, ex=None):
    """Print a short PASSED/FAILED report for one test configuration.

    Args:
        verbose: When truthy, a passing result prints ``PASSED <msg>``.
            Otherwise a passing result prints a single ``.`` with no newline.
        msg: Description of the configuration that was run.
        ex: The exception raised by the run, or ``None`` when it passed.
            When given, the exception's traceback is printed with every line
            prefixed by ``E`` and a tab, followed by
            ``FAILED <msg> - <summary>`` and a blank line. The summary is the
            first 16 characters of the whitespace-stripped exception text,
            with `` ...`` appended only when characters were cut off.
    """
    if ex is not None:
        text = str(ex).strip(" \n\t")
        summary = text[:16]
        # Compare against the stripped text so the ellipsis marks a real
        # truncation, not merely removed surrounding whitespace.
        if len(summary) < len(text):
            summary += " ..."

        # Format the traceback from the exception object that was passed in
        # rather than from the ambient exception context, so the report is
        # correct even when this is called outside an ``except`` block.
        trace = "".join(traceback.format_exception(type(ex), ex, ex.__traceback__))

        prefix = "E\t"
        print(textwrap.indent(trace, prefix))
        print(f"FAILED {msg} - {summary}\n")
    elif verbose:
        print(f"PASSED {msg}")
    else:
        print(".", end="")


72
73
74
75
def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    base_config: Config,
    weights: WeightTensors,
    verbose: bool,
):
    """Run every (m, topk) combination of ``base_config`` on this rank.

    For each pair drawn from ``base_config.Ms`` x ``base_config.topks`` a deep
    copy of the config is specialised to that pair, the modular kernel output
    is computed and compared against ``reference_moe_impl`` with
    ``torch.testing.assert_close``. Failures are reported through
    ``format_result`` and collected, so one failing combination does not stop
    the remaining ones from running.

    Args:
        pgi: Process-group info for this rank; ``pgi.rank`` seeds the RNG.
        vllm_config: Config passed to the kernel run and made current while
            the reference implementation runs.
        cpu_group: Not used in this function body.
        base_config: Config whose ``Ms`` and ``topks`` must both be lists.
        weights: Weight tensors; ``to_current_device()`` is called on them
            before any combination runs.
        verbose: Forwarded to ``format_result``.

    Raises:
        AssertionError: If a sanity check on ``base_config`` fails.
        RuntimeError: If at least one combination raised.
    """
    current_platform.seed_everything(pgi.rank)

    # sanity check
    from vllm import envs

    chunk_size = base_config.fused_moe_chunk_size
    if chunk_size is not None:
        assert chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE

    # get weights to this device
    weights.to_current_device()

    Ms = base_config.Ms
    TOPKs = base_config.topks
    assert isinstance(Ms, list)
    assert isinstance(TOPKs, list)

    exceptions = []
    count = 0

    for m, topk in product(Ms, TOPKs):
        # Specialise a private copy so base_config keeps its full lists.
        config = copy.deepcopy(base_config)
        config.Ms = m
        config.topks = topk

        try:
            print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
            count += 1

            # Inputs for this rank, then the kernel under test.
            rank_tensors = RankTensors.make(config, pgi)
            mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors)

            with set_current_vllm_config(vllm_config):
                ref_out = reference_moe_impl(config, weights, rank_tensors)

            # nvfp4 gets a looser tolerance, looser still once K reaches 4096.
            # atol and rtol always share the same value.
            if config.quant_dtype == "nvfp4":
                tolerance = 1e-1 if config.K < 4096 else 2e-1
            else:
                tolerance = 3e-2

            torch.testing.assert_close(ref_out, mk_out, atol=tolerance, rtol=tolerance)
            format_result(verbose, config.describe())
        except Exception as ex:
            format_result(verbose, config.describe(), ex)
            exceptions.append(ex)

    if not exceptions:
        print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")
        return

    raise RuntimeError(
        f"{len(exceptions)} of {count} tests failed in child process, "
        f"rank={pgi.rank}."
    )
138
139
140


def run(config: Config, verbose: bool):
    """Build weights for ``config`` and launch ``rank_worker`` on every rank.

    Args:
        config: Test configuration; it must report itself valid and must not
            be flagged by ``is_nyi_config``.
        verbose: Forwarded to each ``rank_worker``.

    Raises:
        AssertionError: If ``config`` is invalid or is a not-yet-implemented
            combination.
    """
    valid, _ = config.is_valid()
    assert valid
    assert not is_nyi_config(config)

    weights: WeightTensors = WeightTensors.make(config)
    vllm_config, env_dict = config.make_env_data()

    # Everything after the worker callable is handed to the launcher as-is.
    worker_args = (vllm_config, env_dict, config, weights, verbose)
    parallel_launch_with_config(config.world_size, rank_worker, *worker_args)
150
151
152


# Parameter sweep shared by all generated test cases. Ms and TOPKs are handed
# to each Config as whole lists; the remaining lists are expanded into one
# test case per element by generate_valid_test_cases.
Ms = [32, 64]
# Hidden sizes. Making this too large will cause the fp4 tests to fail.
# It also needs to be a multiple of 1024 for deep_gemm.
Ks = [2048]
Ns = [1024]
TOPKs = [4, 1]
Es = [32]
DTYPEs = [torch.bfloat16]
# None leaves the fused-MoE chunk size unset in the Config.
FUSED_MOE_CHUNK_SIZEs = [None, 16]


def is_nyi_config(config: Config) -> bool:
    """Return True for configs that are legitimate but are known to fail.

    For expert types whose info sets ``needs_matching_quant``, the config is
    flagged when exactly one of per-act-token and per-out-channel quantization
    is enabled. For all other expert types, the config is flagged when the
    expert type does not support an expert map.
    """
    info = expert_info(config.fused_experts_type)

    if not info.needs_matching_quant:
        return not info.supports_expert_map

    # The triton kernels expect both per-act-token-quant and
    # per-out-ch-quant or neither, so exactly one being set is unsupported.
    quant_flags = (config.is_per_act_token_quant, config.is_per_out_ch_quant)
    return sum(quant_flags) == 1
176
177


178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
def generate_valid_test_cases(
    world_size: int, prepare_finalize_types
) -> list[tuple[Any, ...]]:
    """Enumerate the parametrize tuples for every runnable configuration.

    Sweeps the module-level size/dtype/quant/chunk-size lists together with
    every (prepare-finalize type, fused-experts type) pair, builds a Config
    for each, and keeps those that are valid and not flagged by
    ``is_nyi_config``. A summary line with the kept/total counts is printed.

    Args:
        world_size: World size stored in each Config and in each result tuple.
        prepare_finalize_types: Prepare/finalize types to pair with every
            entry of ``MK_FUSED_EXPERT_TYPES``.

    Returns:
        A list of tuples ``(k, n, e, dtype, quant_config,
        prepare_finalize_type, fused_experts_type, chunk_size, world_size)``.
    """
    # TODO(bnell): figure out how to get verbose flag here.
    verbose = False  # pytestconfig.getoption('verbose') > 0

    kernel_pairs = product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES)
    sweep = product(
        Ks, Ns, Es, DTYPEs, MK_QUANT_CONFIGS, kernel_pairs, FUSED_MOE_CHUNK_SIZEs
    )

    cases = []
    total = 0

    for k, n, e, dtype, quant_config, (pf_type, experts_type), chunk_size in sweep:
        total += 1

        config = Config(
            Ms=Ms,
            K=k,
            N=n,
            E=e,
            topks=TOPKs,
            dtype=dtype,
            quant_config=quant_config,
            prepare_finalize_type=pf_type,
            fused_experts_type=experts_type,
            fused_moe_chunk_size=chunk_size,
            world_size=world_size,
        )

        # Validity is checked before the not-yet-implemented filter.
        valid, reason = config.is_valid()
        if not valid:
            if verbose:
                print(f"Test config {config} is not valid: {reason}")
            continue

        if is_nyi_config(config):
            if verbose:
                print(f"Test config {config} is nyi.")
            continue

        case = (
            k,
            n,
            e,
            dtype,
            quant_config,
            pf_type,
            experts_type,
            chunk_size,
            world_size,
        )
        cases.append(case)

    print(f"{len(cases)} of {total} valid configs generated.")

    return cases


243
@pytest.mark.parametrize(
    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
    generate_valid_test_cases(
        world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
    ),
)
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: TestMoEQuantConfig | None,
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
    chunk_size: int | None,
    world_size: int,
    pytestconfig,
):
    """Run one multi-GPU prepare-finalize / fused-experts combination.

    The test is skipped when fewer than ``world_size`` CUDA devices are
    available. Otherwise a Config is built from the parametrized values (with
    the module-level ``Ms`` and ``TOPKs`` lists) and handed to ``run``.
    """
    # Query the device count once so the skip message reports exactly the
    # value that was compared against world_size.
    available_gpus = cuda_device_count_stateless()
    if available_gpus < world_size:
        pytest.skip(
            f"Not enough GPUs available to run, got "
            f"{available_gpus} expected "
            f"{world_size}."
        )

    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        fused_moe_chunk_size=chunk_size,
        world_size=world_size,
    )
    # pytest's -v count decides whether per-config PASSED lines are printed.
    verbosity = pytestconfig.getoption("verbose")
    run(config, verbosity > 0)
284
285
286


@pytest.mark.parametrize(
    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
    generate_valid_test_cases(
        world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
    ),
)
def test_modular_kernel_combinations_singlegpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: TestMoEQuantConfig | None,
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
    chunk_size: int | None,
    world_size: int,
    pytestconfig,
):
    """Run one single-GPU prepare-finalize / fused-experts combination.

    Note: float8_e4m3fn is not supported on CUDA architecture < 89, so
    configurations quantized to that dtype are skipped on such hardware.
    """
    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        fused_moe_chunk_size=chunk_size,
        world_size=world_size,
    )

    # The device capability is only queried for fp8-quantized configs.
    uses_fp8 = (
        quant_config is not None and quant_config.quant_dtype == torch.float8_e4m3fn
    )
    if uses_fp8 and not current_platform.has_device_capability(89):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )

    run(config, pytestconfig.getoption("verbose") > 0)
328
329


330
if __name__ == "__main__":
    # Ability to test individual PrepareAndFinalize and FusedExperts combination
    # from the command line. The relative import means this must be launched
    # as a module (python3 -m ...), as shown in the description below.
    from .modular_kernel_tools.cli_args import make_config, make_config_arg_parser

    parser = make_config_arg_parser(
        description=(
            # Adjacent literals are concatenated verbatim, so the sentence
            # break before "Example" has to be spelled out explicitly.
            "Run single prepare-finalize & fused-experts combination test. "
            "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations "
            "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
        )
    )
    args = parser.parse_args()
    config = make_config(args)

    # Always run verbosely when invoked directly.
    run(config, True)