test_modular_kernel_combinations.py 12 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
5
6
import textwrap
import traceback
7
from itertools import product
8
from typing import Any
9
10
11
12
13

import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
15
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform
16
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
17
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
18
from vllm.utils.torch_utils import cuda_device_count_stateless, set_random_seed
19
from vllm.v1.worker.workspace import init_workspace_manager
20

21
22
23
24
25
26
27
from .modular_kernel_tools.common import (
    Config,
    RankTensors,
    WeightTensors,
    reference_moe_impl,
    run_modular_kernel,
)
28
from .modular_kernel_tools.mk_objects import (
29
30
31
32
33
34
35
36
37
38
39
    MK_FUSED_EXPERT_TYPES,
    MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
    MK_QUANT_CONFIGS,
    MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
    TestMoEQuantConfig,
    expert_info,
)
from .modular_kernel_tools.parallel_utils import (
    ProcessGroupInfo,
    parallel_launch_with_config,
)
40

41
# True when at least one of the optional-package probes below reports truthy.
# The generator keeps the same left-to-right, short-circuit evaluation order.
has_any_multi_gpu_package = any(
    probe()
    for probe in (
        has_deep_ep,
        has_deep_gemm,
        has_flashinfer_cutlass_fused_moe,
    )
)

# Marker that skips a test when none of the probes above succeeded.
meets_multi_gpu_requirements = pytest.mark.skipif(
    not has_any_multi_gpu_package,
    reason="Requires deep_ep or deep_gemm or flashinfer packages",
)

# Skip the whole module at collection time when the platform reports the
# fnuz FP8 variant; per the message, these tests need float8_e4m3fn.
if current_platform.is_fp8_fnuz():
    pytest.skip(
        "Tests in this file require float8_e4m3fn and platform does not support",
        allow_module_level=True,
    )

56

57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def format_result(verbose, msg, ex=None):
    """Print the outcome of one test configuration to stdout.

    Args:
        verbose: When true, a passing result prints a full ``PASSED`` line;
            otherwise a single ``.`` is printed without a trailing newline.
        msg: Description of the configuration that was run.
        ex: Exception raised by the run, or ``None`` when it passed. When
            given, the exception's traceback (every line prefixed with ``E``
            and a tab) and a ``FAILED`` line carrying at most 16 characters
            of the message are printed, regardless of ``verbose``.
    """
    if ex is not None:
        full_text = str(ex).strip(" \n\t")
        summary = full_text[:16]
        # Compare against the *stripped* text: only append the ellipsis when
        # characters were really cut, not when the message merely carried
        # leading/trailing whitespace.
        if len(summary) < len(full_text):
            summary += " ..."

        # Format the traceback stored on ``ex`` itself so the output does not
        # depend on this function being called from inside an ``except`` block.
        trace = "".join(traceback.format_exception(type(ex), ex, ex.__traceback__))
        print(textwrap.indent(trace, "E\t"))
        print(f"FAILED {msg} - {summary}\n")
    elif verbose:
        print(f"PASSED {msg}")
    else:
        print(".", end="")


73
74
75
76
def _uses_aiter_experts(config: Config) -> bool:
    """Return whether ``config.fused_experts_type`` is named ``AiterExperts``."""
    return getattr(config.fused_experts_type, "__name__", "") == "AiterExperts"


def _is_unsupported_aiter_case(config: Config, topk: int) -> bool:
    """Return whether this (config, topk) pair hits a known AITER limitation.

    AITER block-scaled MoE does not support apply_router_weight_on_input
    (the topk=1 path). https://github.com/ROCm/aiter/issues/2418
    """
    return (
        topk == 1
        and config.supports_apply_weight_on_input()
        and _uses_aiter_experts(config)
        and config.quant_block_shape is not None
    )


def _assert_outputs_match(
    config: Config, ref_out: torch.Tensor, mk_out: torch.Tensor
) -> None:
    """Assert that the modular-kernel output agrees with the reference output.

    The base tolerance is 3e-2 (absolute and relative), widened to 1e-1 or
    2e-1 for ``nvfp4`` depending on whether ``config.K`` is below 4096.

    Args:
        config: Configuration of the run that produced both outputs.
        ref_out: Output of the reference implementation.
        mk_out: Output of the modular kernel under test.

    Raises:
        AssertionError: If the outputs differ by more than the tolerances
            selected for ``config`` allow.
    """
    if config.quant_dtype == "nvfp4":
        atol = rtol = 1e-1 if config.K < 4096 else 2e-1
    else:
        atol = rtol = 3e-2

    # On ROCm, AITER FP8 fused MoE uses hardware FP8 dot-product which can
    # produce slightly larger error than dequant+f32 matmul at FP8
    # representable-value boundaries. Allow a small percentage of elements to
    # exceed the base tolerance by a bounded margin.
    # https://github.com/ROCm/aiter/issues/2421
    # NOTE(review): the condition only requires *some* quant config, it does
    # not inspect the quant dtype -- confirm that is intended.
    is_aiter_fp8 = (
        current_platform.is_rocm()
        and _uses_aiter_experts(config)
        and config.quant_config is not None
    )
    if not is_aiter_fp8:
        torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol)
        return

    diff = (ref_out - mk_out).abs()
    n_total = diff.numel()
    max_diff = diff.max().item()
    n_exceed = int((diff > atol).sum().item())
    pct_exceed = n_exceed / n_total * 100
    # FP8 hw matmul vs f32 reference: at most 5% of the elements may exceed
    # the base tolerance, and the largest error must stay within 4x the base
    # tolerance.
    max_pct_allowed = 5.0
    relaxed_atol = atol * 4
    print(
        f"[AITER FP8 precision] "
        f"max_diff={max_diff:.6f}, "
        f"exceed_atol={n_exceed}/{n_total} "
        f"({pct_exceed:.4f}%), "
        f"max_pct_allowed={max_pct_allowed}%, "
        f"relaxed_limit={relaxed_atol}"
    )
    assert pct_exceed <= max_pct_allowed, (
        f"AITER FP8: {pct_exceed:.2f}% elements exceed "
        f"atol={atol} (max allowed {max_pct_allowed}%)"
    )
    assert max_diff <= relaxed_atol, (
        f"AITER FP8: max_diff={max_diff:.6f} exceeds "
        f"relaxed limit {relaxed_atol}"
    )


def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    base_config: Config,
    weights: WeightTensors,
    verbose: bool,
):
    """Run every (m, topk) combination of ``base_config`` on this rank.

    For each combination the modular kernel output is compared against the
    reference implementation. Failures are reported and collected so that the
    remaining combinations still run.

    Args:
        pgi: Process-group info; ``local_rank`` selects the CUDA device and
            ``rank`` seeds the random number generator.
        vllm_config: vLLM config passed to the kernel run and made current
            while the reference implementation executes.
        cpu_group: Not used by this function.
        base_config: Config whose ``Ms`` and ``topks`` must both be lists; a
            deep copy with scalar ``Ms``/``topks`` is made per combination.
        weights: Weight tensors, moved to the current device before use.
        verbose: Forwarded to ``format_result`` for passing combinations.

    Raises:
        RuntimeError: If at least one combination failed.
    """
    # Initialize workspace manager in child process
    device = torch.device(f"cuda:{pgi.local_rank}")
    init_workspace_manager(device)

    set_random_seed(pgi.rank)

    # get weights to this device
    weights.to_current_device()

    m_values = base_config.Ms
    assert isinstance(m_values, list)
    topk_values = base_config.topks
    assert isinstance(topk_values, list)

    exceptions = []
    count = 0

    for m, topk in product(m_values, topk_values):
        # override m and topk
        config = copy.deepcopy(base_config)
        config.Ms = m
        config.topks = topk

        try:
            print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
            count = count + 1

            # inputs for rank
            rank_tensors = RankTensors.make(config, pgi)

            if _is_unsupported_aiter_case(config, topk):
                print(
                    f"Skipping[{pgi.rank}]: m={m}, topk={topk}"
                    " (AITER block-scaled + weight-on-input,"
                    " https://github.com/ROCm/aiter/issues/2418)"
                )
                # A skipped combination must not be counted as a test.
                count -= 1
                continue

            # modular kernel out
            mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors)

            with set_current_vllm_config(vllm_config):
                ref_out = reference_moe_impl(config, weights, rank_tensors)

            _assert_outputs_match(config, ref_out, mk_out)
            format_result(verbose, config.describe())
        except Exception as ex:
            # Keep going so that one failing combination does not hide others.
            format_result(verbose, config.describe(), ex)
            exceptions.append(ex)

    if len(exceptions) > 0:
        raise RuntimeError(
            f"{len(exceptions)} of {count} tests failed in child process, "
            f"rank={pgi.rank}."
        )
    else:
        print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")
195
196
197


def run(config: Config, verbose: bool):
    """Build the weights and environment for ``config`` and launch rank workers.

    Args:
        config: Test configuration; it must be valid and must not be one of
            the configurations that ``is_nyi_config`` filters out.
        verbose: Forwarded to ``rank_worker``.

    Raises:
        AssertionError: If ``config`` is invalid (the message carries the
            reason returned by ``config.is_valid()``) or is filtered out by
            ``is_nyi_config``.
    """
    # Surface the reason returned by is_valid() instead of a bare assert.
    valid, reason = config.is_valid()
    assert valid, f"Invalid test config: {reason}"
    assert not is_nyi_config(config), "Test config is nyi."

    weights: WeightTensors = WeightTensors.make(config)

    vllm_config, env_dict = config.make_env_data()
    parallel_launch_with_config(
        config.world_size, rank_worker, vllm_config, env_dict, config, weights, verbose
    )
207
208
209


# Sweep values shared by every generated test case. ``Ms`` and ``TOPKs`` are
# handed to ``Config`` as lists; the others are iterated one value at a time.
Ms: list[int] = [32, 64]
# hidden sizes, making this too large will cause fp4 tests to fail.
# Also needs to be a multiple of 1024 for deep_gemm.
Ks: list[int] = [2048]
Ns: list[int] = [1024]
TOPKs: list[int] = [4, 1]
Es: list[int] = [32]
DTYPEs: list[torch.dtype] = [torch.bfloat16]


def is_nyi_config(config: Config) -> bool:
    """Return whether ``config`` is a known-failing ("nyi") configuration.

    These configurations are known to be legitimate but still fail.
    """
    info = expert_info(config.fused_experts_type)

    if not info.needs_matching_quant:
        return not info.supports_expert_map

    # The triton kernels expect both per-act-token-quant and
    # per-out-ch-quant or neither, so exactly one of them being set
    # is the unsupported case.
    flags_set = sum((config.is_per_act_token_quant, config.is_per_out_ch_quant))
    return flags_set == 1
231
232


233
234
235
236
237
238
def generate_valid_test_cases(
    world_size: int, prepare_finalize_types
) -> list[tuple[Any, ...]]:
    """Enumerate the parameter tuples that form a runnable test configuration.

    Every combination of ``Ks``, ``Ns``, ``Es``, ``DTYPEs``,
    ``MK_QUANT_CONFIGS`` and each (prepare-finalize, fused-experts) pair is
    turned into a ``Config``. Combinations that are invalid, or that
    ``is_nyi_config`` flags, are dropped.

    Args:
        world_size: World size stored in each config and in each returned case.
        prepare_finalize_types: Prepare/finalize types to pair with every
            entry of ``MK_FUSED_EXPERT_TYPES``.

    Returns:
        Tuples of ``(k, n, e, dtype, quant_config, prepare_finalize_type,
        fused_experts_type, world_size)`` for the kept configurations.
    """
    # TODO(bnell): figure out how to get verbose flag here.
    verbose = False  # pytestconfig.getoption('verbose') > 0

    pf_and_experts = product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES)
    candidates = product(Ks, Ns, Es, DTYPEs, MK_QUANT_CONFIGS, pf_and_experts)

    accepted: list[tuple[Any, ...]] = []
    n_seen = 0
    for n_seen, candidate in enumerate(candidates, start=1):
        k, n, e, dtype, quant_config, (pf_type, experts_type) = candidate

        config = Config(
            Ms=Ms,
            K=k,
            N=n,
            E=e,
            topks=TOPKs,
            dtype=dtype,
            quant_config=quant_config,
            prepare_finalize_type=pf_type,
            fused_experts_type=experts_type,
            world_size=world_size,
        )

        valid, reason = config.is_valid()
        if not valid:
            if verbose:
                print(f"Test config {config} is not valid: {reason}")
            continue

        if is_nyi_config(config):
            if verbose:
                print(f"Test config {config} is nyi.")
            continue

        accepted.append(
            (k, n, e, dtype, quant_config, pf_type, experts_type, world_size)
        )

    print(f"{len(accepted)} of {n_seen} valid configs generated.")

    return accepted


295
@pytest.mark.parametrize(
    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,world_size",
    generate_valid_test_cases(
        world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
    ),
)
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: TestMoEQuantConfig | None,
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    fused_experts_type: mk.FusedMoEExperts,
    world_size: int,
    pytestconfig,
):
    """Run one multi-GPU prepare-finalize / fused-experts combination.

    The test is skipped when fewer than ``world_size`` CUDA devices are
    reported.
    """
    available_gpus = cuda_device_count_stateless()
    if available_gpus < world_size:
        pytest.skip(
            f"Not enough GPUs available to run, got "
            f"{available_gpus} expected "
            f"{world_size}."
        )

    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        world_size=world_size,
    )
    # Any -v flag on the pytest command line turns on per-config PASSED lines.
    run(config, pytestconfig.getoption("verbose") > 0)
334
335
336


@pytest.mark.parametrize(
    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,world_size",
    generate_valid_test_cases(
        world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
    ),
)
def test_modular_kernel_combinations_singlegpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: TestMoEQuantConfig | None,
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    fused_experts_type: mk.FusedMoEExperts,
    world_size: int,
    pytestconfig,
    workspace_init,
):
    """Run one single-GPU prepare-finalize / fused-experts combination.

    Note: float8_e4m3fn is not supported on CUDA architecture < 89, so
    configurations quantized to that dtype are skipped on such hardware.
    """
    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        world_size=world_size,
    )

    uses_fp8_e4m3fn = (
        quant_config is not None and quant_config.quant_dtype == torch.float8_e4m3fn
    )
    if uses_fp8_e4m3fn and not current_platform.has_device_capability(89):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )

    # Any -v flag on the pytest command line turns on per-config PASSED lines.
    run(config, pytestconfig.getoption("verbose") > 0)
377
378


379
if __name__ == "__main__":
    # Ability to test individual PrepareAndFinalize and FusedExperts combination
    from .modular_kernel_tools.cli_args import make_config, make_config_arg_parser

    parser = make_config_arg_parser(
        description=(
            # Adjacent literals are concatenated verbatim, so the sentence
            # needs its own trailing separator before "Example".
            "Run single prepare-finalize & fused-experts combination test. "
            "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations "
            "--pf-type DeepEPLLPrepareAndFinalize --experts-type BatchedTritonExperts"
        )
    )
    args = parser.parse_args()
    config = make_config(args)

    # Command-line runs always print a PASSED line per configuration.
    run(config, True)