# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
import textwrap
import traceback
from itertools import product
from typing import Any

import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
from vllm.utils.torch_utils import cuda_device_count_stateless, set_random_seed
from vllm.v1.worker.workspace import init_workspace_manager

from .modular_kernel_tools.common import (
    Config,
    RankTensors,
    WeightTensors,
    reference_moe_impl,
    run_modular_kernel,
)
from .modular_kernel_tools.mk_objects import (
    MK_FUSED_EXPERT_TYPES,
    MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
    MK_QUANT_CONFIGS,
    MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
    TestMoEQuantConfig,
    expert_info,
)
from .modular_kernel_tools.parallel_utils import (
    ProcessGroupInfo,
    parallel_launch_with_config,
)

# True when at least one of the optional deep_ep / deep_gemm / flashinfer
# cutlass fused-MoE packages is available.
has_any_multi_gpu_package = (
    has_deep_ep() or has_deep_gemm() or has_flashinfer_cutlass_fused_moe()
)

# Skip marker applied to the multi-GPU test when none of the optional packages
# above is installed.
meets_multi_gpu_requirements = pytest.mark.skipif(
    not has_any_multi_gpu_package,
    reason="Requires deep_ep or deep_gemm or flashinfer packages",
)

# Skip the whole module on platforms that report the fp8 "fnuz" format, since
# these tests require float8_e4m3fn.
if current_platform.is_fp8_fnuz():
    pytest.skip(
        "Tests in this file require float8_e4m3fn and platform does not support",
        allow_module_level=True,
    )


def format_result(verbose, msg, ex=None):
    """Print a short summary of a single test-case outcome.

    Args:
        verbose: When True, a passing case prints ``PASSED <msg>``; otherwise
            it prints a single ``.`` with no trailing newline.
        msg: Description of the test case (e.g. ``Config.describe()``).
        ex: The exception raised by the case, or None if it passed. When
            given, the exception's traceback is printed (every line prefixed
            with ``E`` and a tab) followed by ``FAILED <msg> - <short msg>``,
            regardless of ``verbose``.
    """
    if ex is None:
        if verbose:
            print(f"PASSED {msg}")
        else:
            print(".", end="")
        return

    # Keep the FAILED line short: the first 16 characters of the stripped
    # message, with " ..." appended only when something was actually cut off.
    # (Comparing against the stripped length avoids a spurious " ..." for
    # short messages that merely carry surrounding whitespace.)
    full = str(ex).strip(" \n\t")
    short = full[:16]
    if len(short) < len(full):
        short = short + " ..."

    # Format the traceback from the exception object itself instead of
    # traceback.format_exc(), which only works while an exception is being
    # handled and otherwise yields "NoneType: None".
    tb_text = "".join(traceback.format_exception(type(ex), ex, ex.__traceback__))
    prefix = "E\t"
    print(f"{textwrap.indent(tb_text, prefix)}")
    print(f"FAILED {msg} - {short}\n")


def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    base_config: Config,
    weights: WeightTensors,
    verbose: bool,
):
    """Run every (M, topk) combination of ``base_config`` on this rank.

    For each combination the modular-kernel output is compared against
    ``reference_moe_impl``. Failures are collected rather than aborting the
    loop, so every combination gets reported.

    Args:
        pgi: Process-group info for this rank (``rank`` and ``local_rank``
            are used).
        vllm_config: vLLM config passed to ``run_modular_kernel`` and
            installed as the current config while computing the reference.
        cpu_group: Not used by this function (presumably part of the
            launcher's calling convention).
        base_config: Test config whose ``Ms`` and ``topks`` are lists to sweep.
        weights: Expert weights; moved to this rank's current device.
        verbose: Forwarded to ``format_result``.

    Raises:
        RuntimeError: If any combination failed; chained to the first failure.
    """
    # Initialize workspace manager in child process
    device = torch.device(f"cuda:{pgi.local_rank}")
    init_workspace_manager(device)

    set_random_seed(pgi.rank)

    # Sanity check: when the config sets a chunk size, the child's
    # VLLM_FUSED_MOE_CHUNK_SIZE must match it.
    from vllm import envs

    if base_config.fused_moe_chunk_size is not None:
        assert base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE

    # get weights to this device
    weights.to_current_device()

    Ms = base_config.Ms
    assert isinstance(Ms, list)
    TOPKs = base_config.topks
    assert isinstance(TOPKs, list)

    exceptions = []
    count = 0

    for m, topk in product(Ms, TOPKs):
        # Per-case copy in which scalar m/topk replace the swept lists.
        config = copy.deepcopy(base_config)
        config.Ms = m
        config.topks = topk

        try:
            print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
            count = count + 1

            # inputs for rank
            rank_tensors = RankTensors.make(config, pgi)

            # modular kernel out
            mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors)

            with set_current_vllm_config(vllm_config):
                ref_out = reference_moe_impl(config, weights, rank_tensors)

            # Looser tolerance for nvfp4, and looser still when K >= 4096.
            if config.quant_dtype == "nvfp4":
                tol = 1e-1 if config.K < 4096 else 2e-1
            else:
                tol = 3e-2

            torch.testing.assert_close(ref_out, mk_out, atol=tol, rtol=tol)
            format_result(verbose, config.describe())
        except Exception as ex:
            format_result(verbose, config.describe(), ex)
            exceptions.append(ex)

    if exceptions:
        # Chain the first failure so its traceback survives in the parent's
        # error report.
        raise RuntimeError(
            f"{len(exceptions)} of {count} tests failed in child process, "
            f"rank={pgi.rank}."
        ) from exceptions[0]

    print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")


def run(config: Config, verbose: bool):
    """Launch ``rank_worker`` on ``config.world_size`` processes for ``config``.

    The config must be valid and must not be a known-NYI combination (see
    ``is_nyi_config``). Weights are generated once here and handed to
    ``parallel_launch_with_config`` together with the vLLM config and the
    environment dict from ``config.make_env_data()``.
    """
    # Surface the reason from is_valid() instead of a bare assertion failure.
    valid, reason = config.is_valid()
    assert valid, f"Invalid config: {reason}"
    assert not is_nyi_config(config), "Config is a known NYI combination"

    weights: WeightTensors = WeightTensors.make(config)

    vllm_config, env_dict = config.make_env_data()

    parallel_launch_with_config(
        config.world_size, rank_worker, vllm_config, env_dict, config, weights, verbose
    )


# Ms and TOPKs are swept inside rank_worker; the remaining lists are expanded
# into separate test cases by generate_valid_test_cases.
Ms = [32, 64]
# Hidden sizes (K). Making this too large will cause fp4 tests to fail; it also
# needs to be a multiple of 1024 for deep_gemm.
Ks = [2048]
Ns = [1024]
TOPKs = [4, 1]
Es = [32]
DTYPEs = [torch.bfloat16]
# None means no explicit fused-MoE chunk size is configured.
FUSED_MOE_CHUNK_SIZES = [None, 16]


def is_nyi_config(config: Config) -> bool:
    """Return True for configs that are legitimate but known to still fail.

    For expert types with ``needs_matching_quant`` set, a config is NYI when
    exactly one of per-act-token and per-out-channel quantization is enabled.
    For all other expert types, a config is NYI when the experts do not
    support an expert map.
    """
    info = expert_info(config.fused_experts_type)

    if info.needs_matching_quant:
        # The triton kernels expect both per-act-token-quant and
        # per-out-ch-quant or neither.
        # NOTE(review): this branch returns without consulting
        # supports_expert_map — confirm that is intended.
        return (config.is_per_act_token_quant + config.is_per_out_ch_quant) == 1

    return not info.supports_expert_map


def generate_valid_test_cases(
    world_size: int, prepare_finalize_types, *, verbose: bool = False
) -> list[tuple[Any, ...]]:
    """Enumerate valid parameter tuples for the modular-kernel tests.

    Takes the cartesian product of the module-level Ks, Ns, Es, DTYPEs,
    MK_QUANT_CONFIGS, ``prepare_finalize_types``, MK_FUSED_EXPERT_TYPES and
    FUSED_MOE_CHUNK_SIZES. Keeps only combinations whose ``Config`` is valid
    and not known-NYI (see ``is_nyi_config``).

    Args:
        world_size: Number of ranks the generated configs target.
        prepare_finalize_types: Candidate prepare/finalize types.
        verbose: Print why each rejected combination was skipped. Defaults to
            False (see the TODO below).

    Returns:
        Tuples ``(k, n, e, dtype, quant_config, prepare_finalize_type,
        fused_experts_type, chunk_size, world_size)``, in the order used by
        the ``parametrize`` argument strings.
    """
    # TODO(bnell): figure out how to get the pytest verbose flag here
    # (pytestconfig.getoption('verbose') > 0).
    cases = []
    total = 0

    # Flat product; same iteration order as nesting product(pf, experts).
    for k, n, e, dtype, quant_config, pf_type, experts_type, chunk_size in product(
        Ks,
        Ns,
        Es,
        DTYPEs,
        MK_QUANT_CONFIGS,
        prepare_finalize_types,
        MK_FUSED_EXPERT_TYPES,
        FUSED_MOE_CHUNK_SIZES,
    ):
        total += 1

        config = Config(
            Ms=Ms,
            K=k,
            N=n,
            E=e,
            topks=TOPKs,
            dtype=dtype,
            quant_config=quant_config,
            prepare_finalize_type=pf_type,
            fused_experts_type=experts_type,
            fused_moe_chunk_size=chunk_size,
            world_size=world_size,
        )

        valid, reason = config.is_valid()
        if not valid:
            if verbose:
                print(f"Test config {config} is not valid: {reason}")
            continue

        if is_nyi_config(config):
            if verbose:
                print(f"Test config {config} is nyi.")
            continue

        cases.append(
            (
                k,
                n,
                e,
                dtype,
                quant_config,
                pf_type,
                experts_type,
                chunk_size,
                world_size,
            )
        )

    print(f"{len(cases)} of {total} valid configs generated.")

    return cases


@pytest.mark.parametrize(
    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
    generate_valid_test_cases(
        world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
    ),
)
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: TestMoEQuantConfig | None,
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    fused_experts_type: mk.FusedMoEExperts,
    chunk_size: int | None,
    world_size: int,
    pytestconfig,
):
    """Run one multi-GPU prepare/finalize + fused-experts combination.

    Skipped when fewer than ``world_size`` CUDA devices are available.
    """
    # Query the device count once; it is used for both the check and message.
    num_gpus = cuda_device_count_stateless()
    if num_gpus < world_size:
        pytest.skip(
            f"Not enough GPUs available to run, got "
            f"{num_gpus} expected "
            f"{world_size}."
        )

    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        fused_moe_chunk_size=chunk_size,
        world_size=world_size,
    )
    verbosity = pytestconfig.getoption("verbose")
    run(config, verbosity > 0)


@pytest.mark.parametrize(
    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
    generate_valid_test_cases(
        world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
    ),
)
def test_modular_kernel_combinations_singlegpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: TestMoEQuantConfig | None,
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    fused_experts_type: mk.FusedMoEExperts,
    chunk_size: int | None,
    world_size: int,
    pytestconfig,
    workspace_init,  # fixture; not referenced in the body
):
    """Run one single-GPU prepare/finalize + fused-experts combination.

    Note: float8_e4m3fn is not supported on CUDA architecture < 89,
    and those tests will be skipped on unsupported hardware.
    """
    # Decide on the skip before building the config; it only depends on the
    # quant dtype and the device capability.
    if (
        quant_config is not None and quant_config.quant_dtype == torch.float8_e4m3fn
    ) and not current_platform.has_device_capability(89):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )

    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        fused_moe_chunk_size=chunk_size,
        world_size=world_size,
    )

    verbosity = pytestconfig.getoption("verbose")
    run(config, verbosity > 0)


if __name__ == "__main__":
    # Ability to test an individual PrepareAndFinalize and FusedExperts
    # combination from the command line. The relative import below means this
    # must be run as a module (python3 -m ...), not as a plain script.
    from .modular_kernel_tools.cli_args import make_config, make_config_arg_parser

    # The implicitly concatenated pieces need an explicit separator; without
    # it the text read "...combination testExample : ...".
    parser = make_config_arg_parser(
        description=(
            "Run single prepare-finalize & fused-experts combination test. "
            "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations "
            "--pf-type DeepEPLLPrepareAndFinalize --experts-type BatchedTritonExperts"
        )
    )
    args = parser.parse_args()
    config = make_config(args)

    run(config, True)