# test_modular_kernel_combinations.py
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
5
6
import textwrap
import traceback
7
from itertools import product
8
from typing import Any
9
10
11
12
13

import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
15
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform
16
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
17
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
18
from vllm.utils.torch_utils import cuda_device_count_stateless
19

20
21
22
23
24
25
26
from .modular_kernel_tools.common import (
    Config,
    RankTensors,
    WeightTensors,
    reference_moe_impl,
    run_modular_kernel,
)
27
from .modular_kernel_tools.mk_objects import (
28
29
30
31
32
33
34
35
36
37
38
    MK_FUSED_EXPERT_TYPES,
    MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
    MK_QUANT_CONFIGS,
    MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
    TestMoEQuantConfig,
    expert_info,
)
from .modular_kernel_tools.parallel_utils import (
    ProcessGroupInfo,
    parallel_launch_with_config,
)
39

40
41
42
# True when at least one of the optional backends (deep_ep, deep_gemm, pplx or
# flashinfer's CUTLASS fused-MoE) reports that it is available.
has_any_multi_gpu_package = (
    has_deep_ep() or has_deep_gemm() or has_pplx() or has_flashinfer_cutlass_fused_moe()
)

# Skip marker applied to the multi-GPU combination test when none of the
# backends above are available.
meets_multi_gpu_requirements = pytest.mark.skipif(
    not has_any_multi_gpu_package,
    reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
)


50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def format_result(verbose, msg, ex=None):
    """Print a pass/fail summary for a single test configuration.

    Args:
        verbose: Only consulted when ``ex`` is None. If truthy, print
            ``PASSED <msg>``; otherwise print a single ``.`` with no newline.
        msg: Human-readable description of the configuration.
        ex: The exception the configuration raised, or None on success.
            When given, the traceback from ``traceback.format_exc()`` is
            printed with every line prefixed by ``E\\t``, followed by
            ``FAILED <msg> - <first 16 chars of the message>``. Because
            ``format_exc`` reports the exception currently being handled,
            this is intended to be called from inside the ``except`` block.
    """
    if ex is not None:
        full_text = str(ex).strip(" \n\t")
        summary = full_text[:16]
        # Only mark truncation when characters were actually dropped.
        # Comparing against the unstripped text would wrongly flag short
        # messages that merely had surrounding whitespace.
        if len(summary) < len(full_text):
            summary = summary + " ..."

        prefix = "E\t"
        print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
        print(f"FAILED {msg} - {summary}\n")
    elif verbose:
        print(f"PASSED {msg}")
    else:
        print(".", end="")


66
67
68
69
def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    base_config: Config,
    weights: WeightTensors,
    verbose: bool,
):
    """Run every (m, topk) combination of ``base_config`` on one rank.

    For each pair in ``base_config.Ms`` x ``base_config.topks``, the output of
    ``run_modular_kernel`` is compared against ``reference_moe_impl`` with
    ``torch.testing.assert_close``. Failures are reported through
    ``format_result`` and collected, so every combination runs before the
    worker fails.

    Args:
        pgi: Process-group info for this rank. Its ``rank`` is used for
            seeding and logging.
        vllm_config: Passed to the modular kernel. It is also installed as the
            current vLLM config while computing the reference output.
        cpu_group: Not used in this body. Presumably required by the
            signature that ``parallel_launch_with_config`` calls (TODO confirm).
        base_config: Test configuration whose ``Ms`` and ``topks`` are lists.
        weights: Expert weights, moved to the current device before running.
        verbose: Print ``PASSED`` lines instead of progress dots.

    Raises:
        AssertionError: if the chunk-size sanity check fails or ``Ms`` /
            ``topks`` are not lists.
        RuntimeError: if any combination failed. It is chained to the first
            failure.
    """
    current_platform.seed_everything(pgi.rank)

    # Sanity check: a requested chunk size must match what vllm.envs reports
    # in this worker (presumably set via the env_dict passed by ``run``).
    from vllm import envs

    if base_config.fused_moe_chunk_size is not None:
        assert base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE

    # get weights to this device
    weights.to_current_device()

    Ms = base_config.Ms
    assert isinstance(Ms, list)
    TOPKs = base_config.topks
    assert isinstance(TOPKs, list)

    exceptions = []
    count = 0

    for m, topk in product(Ms, TOPKs):
        # Override m and topk on a private copy so base_config stays intact.
        config = copy.deepcopy(base_config)
        config.Ms = m
        config.topks = topk

        # Broad catch is deliberate: record the failure and keep going so all
        # combinations are reported; the worker fails at the end instead.
        try:
            print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
            count += 1

            # inputs for rank
            rank_tensors = RankTensors.make(config, pgi)

            # modular kernel out
            mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors)

            with set_current_vllm_config(vllm_config):
                ref_out = reference_moe_impl(config, weights, rank_tensors)

            # nvfp4 uses looser tolerances, doubled again for K >= 4096.
            if config.quant_dtype == "nvfp4":
                atol = rtol = 1e-1 if config.K < 4096 else 2e-1
            else:
                atol = rtol = 3e-2

            torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol)
            format_result(verbose, config.describe())
        except Exception as ex:
            format_result(verbose, config.describe(), ex)
            exceptions.append(ex)

    if exceptions:
        # Chain the first failure so its traceback is not lost.
        raise RuntimeError(
            f"{len(exceptions)} of {count} tests failed in child process, "
            f"rank={pgi.rank}."
        ) from exceptions[0]

    print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")
132
133
134


def run(config: Config, verbose: bool):
    """Validate ``config`` and launch ``rank_worker`` on ``config.world_size`` ranks.

    Weights are built once here. They are passed to
    ``parallel_launch_with_config`` together with the vLLM config and the
    environment dict from ``config.make_env_data()``.

    Args:
        config: Fully specified test configuration.
        verbose: Forwarded to ``rank_worker``.

    Raises:
        AssertionError: if the config is invalid (the message carries the
            reason from ``is_valid``) or is a known not-yet-implemented
            combination.
    """
    valid, reason = config.is_valid()
    assert valid, reason
    assert not is_nyi_config(config), f"Test config {config} is nyi."

    weights: WeightTensors = WeightTensors.make(config)

    vllm_config, env_dict = config.make_env_data()
    parallel_launch_with_config(
        config.world_size, rank_worker, vllm_config, env_dict, config, weights, verbose
    )
144
145
146


Ms = [32, 64]
147
148
149
# hidden sizes, making this too large will cause fp4 tests to fail.
# Also needs to be a multiple of 1024 for deep_gemm.
Ks = [2048]
150
Ns = [1024]
151
152
153
154
155
156
157
158
TOPKs = [4, 1]
Es = [32]
DTYPEs = [torch.bfloat16]
FUSED_MOE_CHUNK_SIZEs = [None, 16]


def is_nyi_config(config: Config) -> bool:
    """Return True for configs that are known to fail despite being legitimate.

    Such configs are skipped by ``generate_valid_test_cases`` and rejected
    by ``run``.

    Args:
        config: Test configuration; its ``fused_experts_type`` is looked up
            with ``expert_info``.
    """
    info = expert_info(config.fused_experts_type)

    if info.needs_matching_quant:
        # The triton kernels expect both per-act-token-quant and
        # per-out-ch-quant or neither, so exactly one of them is unsupported.
        unsupported_quant_config = (
            config.is_per_act_token_quant + config.is_per_out_ch_quant
        ) == 1
        return unsupported_quant_config

    return not info.supports_expert_map
170
171


172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
def generate_valid_test_cases(
    world_size: int, prepare_finalize_types
) -> list[tuple[Any, ...]]:
    """Enumerate parameter tuples for the modular-kernel combination tests.

    Builds the cartesian product of the module-level sweep lists (``Ks``,
    ``Ns``, ``Es``, ``DTYPEs``, ``MK_QUANT_CONFIGS``, prepare/finalize x
    fused-experts types, ``FUSED_MOE_CHUNK_SIZEs``). It keeps only configs
    that pass ``Config.is_valid()`` and are not flagged by ``is_nyi_config``.

    Args:
        world_size: Number of ranks each config is built for.
        prepare_finalize_types: Candidate prepare/finalize types to pair with
            every entry of ``MK_FUSED_EXPERT_TYPES``.

    Returns:
        Tuples ordered as ``(k, n, e, dtype, quant_config,
        prepare_finalize_type, fused_experts_type, chunk_size, world_size)``,
        matching the ``parametrize`` argument string of the tests.
    """
    # TODO(bnell): figure out how to get verbose flag here.
    verbose = False  # pytestconfig.getoption('verbose') > 0

    cases = []
    total = 0

    combos = product(
        Ks,
        Ns,
        Es,
        DTYPEs,
        MK_QUANT_CONFIGS,
        product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES),
        FUSED_MOE_CHUNK_SIZEs,
    )
    for k, n, e, dtype, quant_config, (pf_type, experts_type), chunk_size in combos:
        total += 1

        config = Config(
            Ms=Ms,
            K=k,
            N=n,
            E=e,
            topks=TOPKs,
            dtype=dtype,
            quant_config=quant_config,
            prepare_finalize_type=pf_type,
            fused_experts_type=experts_type,
            fused_moe_chunk_size=chunk_size,
            world_size=world_size,
        )

        valid, reason = config.is_valid()
        if not valid:
            if verbose:
                print(f"Test config {config} is not valid: {reason}")
            continue

        if is_nyi_config(config):
            if verbose:
                print(f"Test config {config} is nyi.")
            continue

        cases.append(
            (
                k,
                n,
                e,
                dtype,
                quant_config,
                pf_type,
                experts_type,
                chunk_size,
                world_size,
            )
        )

    print(f"{len(cases)} of {total} valid configs generated.")

    return cases


237
@pytest.mark.parametrize(
    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
    generate_valid_test_cases(
        world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
    ),
)
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: TestMoEQuantConfig | None,
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
    chunk_size: int | None,
    world_size: int,
    pytestconfig,
):
    """Run one multi-GPU prepare/finalize + fused-experts combination.

    Skipped when fewer than ``world_size`` GPUs are visible. Otherwise a
    ``Config`` is built from the parameters and executed via ``run``, with
    verbose output when pytest's ``-v`` flag is set.
    """
    num_gpus = cuda_device_count_stateless()
    if num_gpus < world_size:
        pytest.skip(
            f"Not enough GPUs available to run, got "
            f"{num_gpus} expected "
            f"{world_size}."
        )

    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        fused_moe_chunk_size=chunk_size,
        world_size=world_size,
    )
    verbosity = pytestconfig.getoption("verbose")
    run(config, verbosity > 0)
278
279
280


@pytest.mark.parametrize(
    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
    generate_valid_test_cases(
        world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
    ),
)
def test_modular_kernel_combinations_singlegpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: TestMoEQuantConfig | None,
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
    chunk_size: int | None,
    world_size: int,
    pytestconfig,
):
    """Run one single-GPU prepare/finalize + fused-experts combination.

    Builds a ``Config`` from the parameters and executes it via ``run``,
    with verbose output when pytest's ``-v`` flag is set.
    """
    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        fused_moe_chunk_size=chunk_size,
        world_size=world_size,
    )

    verbosity = pytestconfig.getoption("verbose")
    run(config, verbosity > 0)
314
315


316
if __name__ == "__main__":
    # Ability to test individual PrepareAndFinalize and FusedExperts combination.
    # The relative import below requires running this file as a module
    # (python3 -m ...), as shown in the parser description.
    from .modular_kernel_tools.cli_args import make_config, make_config_arg_parser

    parser = make_config_arg_parser(
        description=(
            # The trailing ". " separates the two implicitly concatenated
            # sentences, which previously ran together ("testExample").
            "Run single prepare-finalize & fused-experts combination test. "
            "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations "
            "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
        )
    )
    args = parser.parse_args()
    config = make_config(args)

    run(config, True)