test_modular_kernel_combinations.py 9.05 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
5
6
import textwrap
import traceback
7
from itertools import product
8
from typing import Any
9
10
11
12
13

import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
15
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform
16
from vllm.utils import cuda_device_count_stateless, has_deep_ep, has_deep_gemm, has_pplx
17
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
18

19
20
21
22
23
24
25
from .modular_kernel_tools.common import (
    Config,
    RankTensors,
    WeightTensors,
    reference_moe_impl,
    run_modular_kernel,
)
26
from .modular_kernel_tools.mk_objects import (
27
28
29
30
31
32
33
34
35
36
37
    MK_FUSED_EXPERT_TYPES,
    MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
    MK_QUANT_CONFIGS,
    MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
    TestMoEQuantConfig,
    expert_info,
)
from .modular_kernel_tools.parallel_utils import (
    ProcessGroupInfo,
    parallel_launch_with_config,
)
38

39
40
41
# True when at least one of the deep_ep / deep_gemm / pplx / flashinfer-cutlass
# availability checks succeeds. The generator keeps the original left-to-right,
# stop-at-first-hit evaluation order.
has_any_multi_gpu_package = any(
    is_available()
    for is_available in (
        has_deep_ep,
        has_deep_gemm,
        has_pplx,
        has_flashinfer_cutlass_fused_moe,
    )
)

# Marker that skips a test when none of the packages checked above is present.
meets_multi_gpu_requirements = pytest.mark.skipif(
    not has_any_multi_gpu_package,
    reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
)


49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def format_result(verbose, msg, ex=None):
    if ex is not None:
        x = str(ex)
        newx = x.strip(" \n\t")[:16]
        if len(newx) < len(x):
            newx = newx + " ..."

        prefix = "E\t"
        print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
        print(f"FAILED {msg} - {newx}\n")
    elif verbose:
        print(f"PASSED {msg}")
    else:
        print(".", end="")


65
66
67
68
def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    base_config: Config,
    weights: WeightTensors,
    verbose: bool,
):
    """Run every (m, topk) combination of ``base_config`` on this rank.

    For each combination the modular-kernel output is compared against the
    reference MoE implementation with ``torch.testing.assert_close``. A failing
    combination is reported and recorded, and the remaining ones still run.

    Args:
        pgi: Process-group info for this rank; ``pgi.rank`` seeds the RNG.
        vllm_config: vLLM config passed to the kernel run and made current
            while the reference implementation executes.
        cpu_group: Not referenced in this function body.
        base_config: Test config whose ``Ms`` and ``topks`` are lists to sweep.
        weights: Weight tensors; moved to the current device before the sweep.
        verbose: Forwarded to ``format_result`` to select per-case output.

    Raises:
        RuntimeError: If at least one combination raised an exception.
    """
    current_platform.seed_everything(pgi.rank)

    # Sanity check: a requested chunk size must match the environment setting.
    from vllm import envs

    requested_chunk_size = base_config.fused_moe_chunk_size
    if requested_chunk_size is not None:
        assert requested_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE

    # Bring the weights onto this rank's device.
    weights.to_current_device()

    m_values = base_config.Ms
    assert isinstance(m_values, list)
    topk_values = base_config.topks
    assert isinstance(topk_values, list)

    failures = []
    num_run = 0

    for m, topk in product(m_values, topk_values):
        # Work on a private copy with scalar m / topk for this combination.
        config = copy.deepcopy(base_config)
        config.Ms = m
        config.topks = topk

        try:
            print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
            num_run += 1

            # Inputs for this rank.
            rank_tensors = RankTensors.make(config, pgi)

            # Output of the modular kernel under test.
            mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors)

            # Output of the reference implementation.
            with set_current_vllm_config(vllm_config):
                ref_out = reference_moe_impl(config, weights, rank_tensors)

            # nvfp4 gets a looser tolerance, widened further for K >= 4096.
            if config.quant_dtype == "nvfp4":
                tolerance = 1e-1 if config.K < 4096 else 2e-1
            else:
                tolerance = 3e-2

            torch.testing.assert_close(
                ref_out, mk_out, atol=tolerance, rtol=tolerance
            )
            format_result(verbose, config.describe())
        except Exception as ex:
            format_result(verbose, config.describe(), ex)
            failures.append(ex)

    if failures:
        raise RuntimeError(
            f"{len(failures)} of {num_run} tests failed in child process, "
            f"rank={pgi.rank}."
        )

    print(f"{num_run} of {num_run} tests passed in child process, rank={pgi.rank}.")
131
132
133


def run(config: Config, verbose: bool):
    """Build weights and environment for ``config`` and launch ``rank_worker``.

    Args:
        config: A test config; it must be valid and must not be flagged by
            ``is_nyi_config`` (both are asserted).
        verbose: Forwarded to ``rank_worker`` to select per-case output.
    """
    is_valid = config.is_valid()[0]
    assert is_valid
    assert not is_nyi_config(config)

    weights: WeightTensors = WeightTensors.make(config)
    vllm_config, env_dict = config.make_env_data()

    # Everything after the worker callable is passed through to the launcher.
    launch_args = (vllm_config, env_dict, config, weights, verbose)
    parallel_launch_with_config(config.world_size, rank_worker, *launch_args)
143
144
145


Ms = [32, 64]
146
147
148
# hidden sizes, making this too large will cause fp4 tests to fail.
# Also needs to be a multiple of 1024 for deep_gemm.
Ks = [2048]
149
Ns = [1024]
150
151
152
153
154
155
156
157
TOPKs = [4, 1]
Es = [32]
DTYPEs = [torch.bfloat16]
FUSED_MOE_CHUNK_SIZEs = [None, 16]


def is_nyi_config(config: Config) -> bool:
    """Return True for configs that are legitimate but known to still fail.

    Args:
        config: The test config to classify.

    Returns:
        For expert types that need matching quantization: True when exactly
        one of per-act-token and per-out-channel quantization is enabled.
        For all other expert types: True when the expert type does not
        support an expert map.
    """
    info = expert_info(config.fused_experts_type)

    if not info.needs_matching_quant:
        return not info.supports_expert_map

    # The triton kernels expect per-act-token-quant and per-out-ch-quant to be
    # either both enabled or both disabled; exactly one enabled is unsupported.
    enabled_quant_modes = config.is_per_act_token_quant + config.is_per_out_ch_quant
    return enabled_quant_modes == 1
169
170


171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
def generate_valid_test_cases(
    world_size: int, prepare_finalize_types
) -> list[tuple[Any, ...]]:
    """Enumerate the parametrization tuples that are valid and implemented.

    Every combination of the module-level sweep lists, ``MK_QUANT_CONFIGS``,
    ``prepare_finalize_types`` and ``MK_FUSED_EXPERT_TYPES`` is turned into a
    ``Config``. Combinations that are invalid or flagged by ``is_nyi_config``
    are dropped.

    Args:
        world_size: World size stored in every config and every returned tuple.
        prepare_finalize_types: Prepare/finalize types to combine with each
            fused-experts type.

    Returns:
        Tuples of ``(k, n, e, dtype, quant_config, prepare_finalize_type,
        fused_experts_type, chunk_size, world_size)``.
    """
    # TODO(bnell): figure out how to get verbose flag here.
    verbose = False  # pytestconfig.getoption('verbose') > 0

    candidates = list(
        product(
            Ks,
            Ns,
            Es,
            DTYPEs,
            MK_QUANT_CONFIGS,
            product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES),
            FUSED_MOE_CHUNK_SIZEs,
        )
    )

    cases = []
    for (
        k,
        n,
        e,
        dtype,
        quant_config,
        (pf_type, experts_type),
        chunk_size,
    ) in candidates:
        config = Config(
            Ms=Ms,
            K=k,
            N=n,
            E=e,
            topks=TOPKs,
            dtype=dtype,
            quant_config=quant_config,
            prepare_finalize_type=pf_type,
            fused_experts_type=experts_type,
            fused_moe_chunk_size=chunk_size,
            world_size=world_size,
        )

        valid, reason = config.is_valid()
        if not valid:
            if verbose:
                print(f"Test config {config} is not valid: {reason}")
            continue

        if is_nyi_config(config):
            if verbose:
                print(f"Test config {config} is nyi.")
            continue

        case = (
            k,
            n,
            e,
            dtype,
            quant_config,
            pf_type,
            experts_type,
            chunk_size,
            world_size,
        )
        cases.append(case)

    print(f"{len(cases)} of {len(candidates)} valid configs generated.")

    return cases


236
@pytest.mark.parametrize(
    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
    generate_valid_test_cases(
        world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
    ),
)
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: TestMoEQuantConfig | None,
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
    chunk_size: int | None,
    world_size: int,
    pytestconfig,
):
    """Run one multi-GPU prepare-finalize / fused-experts combination.

    The test is skipped when fewer than ``world_size`` CUDA devices are
    available. Otherwise the parameters are assembled into a ``Config``
    (together with the module-level ``Ms`` and ``TOPKs``) and passed to
    ``run``, with verbosity taken from pytest's ``verbose`` option.
    """
    # Query the device count once and reuse it in the skip message.
    available_gpus = cuda_device_count_stateless()
    if available_gpus < world_size:
        pytest.skip(
            f"Not enough GPUs available to run, got "
            f"{available_gpus} expected "
            f"{world_size}."
        )

    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        fused_moe_chunk_size=chunk_size,
        world_size=world_size,
    )
    verbosity = pytestconfig.getoption("verbose")
    run(config, verbosity > 0)
277
278
279


@pytest.mark.parametrize(
    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
    generate_valid_test_cases(
        world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
    ),
)
def test_modular_kernel_combinations_singlegpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: TestMoEQuantConfig | None,
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
    chunk_size: int | None,
    world_size: int,
    pytestconfig,
):
    """Run one single-GPU prepare-finalize / fused-experts combination.

    The parameters are assembled into a ``Config`` (together with the
    module-level ``Ms`` and ``TOPKs``) and passed to ``run``. Per-case output
    is verbose when pytest's ``verbose`` option is greater than zero.
    """
    config = Config(
        Ms=Ms,
        topks=TOPKs,
        K=k,
        N=n,
        E=e,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        fused_moe_chunk_size=chunk_size,
        world_size=world_size,
    )

    be_verbose = pytestconfig.getoption("verbose") > 0
    run(config, be_verbose)
313
314


315
if __name__ == "__main__":
    # Ability to test individual PrepareAndFinalize and FusedExperts combination.
    # The import below is relative, so this file has to be started as a module
    # (``python3 -m ...``, as in the example in the description).
    from .modular_kernel_tools.cli_args import make_config, make_config_arg_parser

    parser = make_config_arg_parser(
        description=(
            # Adjacent literals are concatenated verbatim, so the sentence
            # separator has to be part of the first literal.
            "Run single prepare-finalize & fused-experts combination test. "
            "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations "
            "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
        )
    )
    args = parser.parse_args()
    config = make_config(args)

    # Command-line runs always use verbose per-case output.
    run(config, True)