test_modular_kernel_combinations.py 9.88 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
5
6
import textwrap
import traceback
7
from itertools import product
8
from typing import Any
9
10
11
12
13

import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
15
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform
16
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
17
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
18
from vllm.utils.torch_utils import cuda_device_count_stateless
19
from vllm.v1.worker.workspace import init_workspace_manager
20

21
22
23
24
25
26
27
from .modular_kernel_tools.common import (
    Config,
    RankTensors,
    WeightTensors,
    reference_moe_impl,
    run_modular_kernel,
)
28
from .modular_kernel_tools.mk_objects import (
29
30
31
32
33
34
35
36
37
38
39
    MK_FUSED_EXPERT_TYPES,
    MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
    MK_QUANT_CONFIGS,
    MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
    TestMoEQuantConfig,
    expert_info,
)
from .modular_kernel_tools.parallel_utils import (
    ProcessGroupInfo,
    parallel_launch_with_config,
)
40

41
42
43
# True if any of the optional multi-GPU MoE backends checked here reports as
# available; evaluated once at import time.
has_any_multi_gpu_package = (
    has_deep_ep() or has_deep_gemm() or has_pplx() or has_flashinfer_cutlass_fused_moe()
)

# Marker applied to the multi-GPU test: skip it when none of the backends
# above is available.
meets_multi_gpu_requirements = pytest.mark.skipif(
    not has_any_multi_gpu_package,
    reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
)

# Every test in this module assumes the float8_e4m3fn format; platforms whose
# fp8 type is the "fnuz" variant skip the whole module at import time.
if current_platform.is_fp8_fnuz():
    pytest.skip(
        "Tests in this file require float8_e4m3fn and platform does not support",
        allow_module_level=True,
    )

56

57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def format_result(verbose, msg, ex=None):
    """Print a pytest-style outcome line for one sub-test.

    Args:
        verbose: When True a passing sub-test prints ``PASSED <msg>``;
            otherwise a single ``.`` is printed without a newline.
            Failures are always reported in full, regardless of this flag.
        msg: Human-readable description of the sub-test.
        ex: The exception raised by the sub-test, or None if it passed.

    On failure the current traceback (``traceback.format_exc()``) is printed,
    each line prefixed with ``E\\t``. It is therefore meant to be called from
    inside an ``except`` block. A ``FAILED`` line follows, with the first 16
    characters of the error message.
    """
    if ex is not None:
        full_text = str(ex).strip(" \n\t")
        summary = full_text[:16]
        # Compare against the *stripped* text so that surrounding whitespace
        # alone does not produce a misleading " ..." truncation marker.
        if len(summary) < len(full_text):
            summary = summary + " ..."

        prefix = "E\t"
        print(textwrap.indent(traceback.format_exc(), prefix))
        print(f"FAILED {msg} - {summary}\n")
    elif verbose:
        print(f"PASSED {msg}")
    else:
        print(".", end="")


73
74
75
76
def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    base_config: Config,
    weights: WeightTensors,
    verbose: bool,
):
    """Run every (M, topk) sub-case of ``base_config`` on a single rank.

    ``run`` passes this function to ``parallel_launch_with_config``. For each
    combination of ``base_config.Ms`` and ``base_config.topks`` it runs the
    modular kernel and ``reference_moe_impl`` on the same per-rank inputs.
    The outputs are compared with ``torch.testing.assert_close``. Every
    sub-case is attempted. Failures are printed via ``format_result`` and
    reported together after the loop.

    Args:
        pgi: Process-group info for this rank (``rank`` and ``local_rank``
            are used here).
        vllm_config: Passed to ``run_modular_kernel`` and installed as the
            current vLLM config while the reference output is computed.
        cpu_group: Not used by this function; presumably supplied by the
            launcher.
        base_config: Config whose ``Ms`` and ``topks`` are lists to sweep.
        weights: Weight tensors shared by all sub-cases.
        verbose: Print a line per passing sub-case instead of a dot.

    Raises:
        AssertionError: if ``base_config.fused_moe_chunk_size`` disagrees
            with ``envs.VLLM_FUSED_MOE_CHUNK_SIZE``, or if ``Ms``/``topks``
            are not lists.
        RuntimeError: if any sub-case failed; chained to the first failure.
    """
    # Initialize workspace manager in child process
    device = torch.device(f"cuda:{pgi.local_rank}")
    init_workspace_manager(device)

    # Seed the RNGs with the rank so each rank's run is reproducible.
    current_platform.seed_everything(pgi.rank)

    # sanity check: an explicitly requested chunk size must match the value
    # visible to this process through vLLM's environment settings.
    from vllm import envs

    if base_config.fused_moe_chunk_size is not None:
        assert base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE

    # get weights to this device
    weights.to_current_device()

    Ms = base_config.Ms
    assert isinstance(Ms, list)
    TOPKs = base_config.topks
    assert isinstance(TOPKs, list)

    exceptions = []
    count = 0

    for m, topk in product(Ms, TOPKs):
        # Each sub-case gets its own copy of the config with a scalar m/topk.
        config = copy.deepcopy(base_config)
        config.Ms = m
        config.topks = topk

        try:
            print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
            count = count + 1

            # inputs for rank
            rank_tensors = RankTensors.make(config, pgi)

            # modular kernel out
            mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors)

            with set_current_vllm_config(vllm_config):
                ref_out = reference_moe_impl(config, weights, rank_tensors)

            # nvfp4 is compared with looser tolerances, looser still for
            # K >= 4096; everything else uses 3e-2.
            if config.quant_dtype == "nvfp4":
                atol = 1e-1 if config.K < 4096 else 2e-1
                rtol = 1e-1 if config.K < 4096 else 2e-1
            else:
                atol = 3e-2
                rtol = 3e-2

            torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol)
            format_result(verbose, config.describe())
        except Exception as ex:
            # Keep going so that every sub-case is reported, not just the
            # first failing one.
            format_result(verbose, config.describe(), ex)
            exceptions.append(ex)

    if exceptions:
        # Chain the first failure so its traceback survives in the error
        # raised from this child process.
        raise RuntimeError(
            f"{len(exceptions)} of {count} tests failed in child process, "
            f"rank={pgi.rank}."
        ) from exceptions[0]

    print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")
143
144
145


def run(config: Config, verbose: bool):
    """Build weights and environment for ``config`` and run it on all ranks.

    Launches ``config.world_size`` processes through
    ``parallel_launch_with_config``; each one executes ``rank_worker``.

    Args:
        config: A valid configuration that is not flagged by
            ``is_nyi_config``.
        verbose: Forwarded to ``rank_worker`` to control per-case output.

    Raises:
        AssertionError: if ``config`` is invalid (the reason reported by
            ``Config.is_valid`` is included) or is a known-unsupported combo.
    """
    # Surface the reason: configs built from CLI arguments (see __main__)
    # are not pre-filtered the way the pytest parametrization is.
    valid, reason = config.is_valid()
    assert valid, f"invalid config: {reason}"
    assert not is_nyi_config(config), "config is valid but not yet supported"

    weights: WeightTensors = WeightTensors.make(config)

    vllm_config, env_dict = config.make_env_data()

    parallel_launch_with_config(
        config.world_size, rank_worker, vllm_config, env_dict, config, weights, verbose
    )
155
156
157


Ms = [32, 64]
158
159
160
# hidden sizes, making this too large will cause fp4 tests to fail.
# Also needs to be a multiple of 1024 for deep_gemm.
Ks = [2048]
161
Ns = [1024]
162
163
164
165
166
167
168
169
TOPKs = [4, 1]
Es = [32]
DTYPEs = [torch.bfloat16]
FUSED_MOE_CHUNK_SIZEs = [None, 16]


def is_nyi_config(config: Config) -> bool:
    """Return True for configs that are valid but known not to work yet.

    Such configs pass ``Config.is_valid()`` but are skipped when generating
    test cases and rejected by ``run``.

    Args:
        config: The test configuration to check.

    Returns:
        For fused-experts types whose ``expert_info`` sets
        ``needs_matching_quant``: True when exactly one of per-act-token and
        per-out-channel quantization is enabled. For all other types: True
        when the type does not support an expert map.
    """
    # We know these configs to be legitimate, but they still fail.
    info = expert_info(config.fused_experts_type)

    if info.needs_matching_quant:
        # The triton kernels expect both per-act-token-quant and
        # per-out-ch-quant or neither; a sum of exactly 1 means only one
        # of the two flags is set.
        # NOTE(review): this early return means supports_expert_map is not
        # checked for these kernels — confirm that is intended.
        unsupported_quant_config = (
            config.is_per_act_token_quant + config.is_per_out_ch_quant
        ) == 1
        return unsupported_quant_config

    return not info.supports_expert_map
181
182


183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
def generate_valid_test_cases(
    world_size: int, prepare_finalize_types
) -> list[tuple[Any, ...]]:
    """Enumerate parametrization tuples for valid, supported configs.

    Takes the cartesian product of the module-level K/N/E/dtype grids,
    ``MK_QUANT_CONFIGS``, every (prepare/finalize, fused-experts) pairing and
    ``FUSED_MOE_CHUNK_SIZEs``, and builds a ``Config`` for each combination.
    It keeps only those that pass ``Config.is_valid()`` and are not flagged
    by ``is_nyi_config``.

    Args:
        world_size: Number of ranks each generated config targets.
        prepare_finalize_types: Candidate prepare/finalize types, each paired
            with every entry of ``MK_FUSED_EXPERT_TYPES``.

    Returns:
        Tuples ordered ``(k, n, e, dtype, quant_config,
        prepare_finalize_type, fused_experts_type, chunk_size, world_size)``,
        matching the argument string used in ``pytest.mark.parametrize``.
    """
    # TODO(bnell): figure out how to get verbose flag here.
    verbose = False  # pytestconfig.getoption('verbose') > 0

    valid_cases: list[tuple[Any, ...]] = []
    n_seen = 0

    grid = product(
        Ks,
        Ns,
        Es,
        DTYPEs,
        MK_QUANT_CONFIGS,
        product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES),
        FUSED_MOE_CHUNK_SIZEs,
    )
    for k, n, e, dtype, quant_config, (pf_type, experts_type), chunk_size in grid:
        n_seen += 1

        config = Config(
            Ms=Ms,
            K=k,
            N=n,
            E=e,
            topks=TOPKs,
            dtype=dtype,
            quant_config=quant_config,
            prepare_finalize_type=pf_type,
            fused_experts_type=experts_type,
            fused_moe_chunk_size=chunk_size,
            world_size=world_size,
        )

        ok, why = config.is_valid()
        if not ok:
            if verbose:
                print(f"Test config {config} is not valid: {why}")
            continue

        if is_nyi_config(config):
            if verbose:
                print(f"Test config {config} is nyi.")
            continue

        valid_cases.append(
            (
                k,
                n,
                e,
                dtype,
                quant_config,
                pf_type,
                experts_type,
                chunk_size,
                world_size,
            )
        )

    print(f"{len(valid_cases)} of {n_seen} valid configs generated.")

    return valid_cases


248
@pytest.mark.parametrize(
    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
    generate_valid_test_cases(
        world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
    ),
)
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: TestMoEQuantConfig | None,
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
    chunk_size: int | None,
    world_size: int,
    pytestconfig,
):
    """Run one multi-GPU prepare/finalize + fused-experts combination.

    Skipped when fewer than ``world_size`` CUDA devices are visible, or (via
    ``meets_multi_gpu_requirements``) when no multi-GPU backend is available.
    Per-case output is verbose when pytest runs with ``-v``.
    """
    # Query once so the check and the skip message agree.
    available_gpus = cuda_device_count_stateless()
    if available_gpus < world_size:
        pytest.skip(
            f"Not enough GPUs available to run, got "
            f"{available_gpus}, expected "
            f"{world_size}."
        )

    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        fused_moe_chunk_size=chunk_size,
        world_size=world_size,
    )
    verbosity = pytestconfig.getoption("verbose")
    run(config, verbosity > 0)
289
290
291


@pytest.mark.parametrize(
    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
    generate_valid_test_cases(
        world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
    ),
)
def test_modular_kernel_combinations_singlegpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: TestMoEQuantConfig | None,
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
    chunk_size: int | None,
    world_size: int,
    pytestconfig,
    workspace_init,
):
    """Run one single-GPU prepare/finalize + fused-experts combination.

    Note: float8_e4m3fn is not supported on CUDA architecture < 89, so cases
    quantized to that dtype are skipped on such hardware. ``workspace_init``
    is a pytest fixture requested for its setup side effect (presumably
    workspace initialization); it is not referenced in the body.
    """
    # Decide on the skip before doing any setup work.
    if (
        quant_config is not None and quant_config.quant_dtype == torch.float8_e4m3fn
    ) and not current_platform.has_device_capability(89):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )

    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        fused_moe_chunk_size=chunk_size,
        world_size=world_size,
    )

    verbosity = pytestconfig.getoption("verbose")
    run(config, verbosity > 0)
334
335


336
if __name__ == "__main__":
    # Ability to test individual PrepareAndFinalize and FusedExperts combination.
    # The relative import below requires running this file as a module
    # (``python3 -m ...``), as in the example in the description.
    from .modular_kernel_tools.cli_args import make_config, make_config_arg_parser

    parser = make_config_arg_parser(
        description=(
            # Trailing ". " keeps this sentence from running into "Example".
            "Run single prepare-finalize & fused-experts combination test. "
            "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations "
            "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
        )
    )
    args = parser.parse_args()
    config = make_config(args)

    run(config, True)