# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
import textwrap
import traceback
from itertools import product
from typing import Any

import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
from vllm.utils.torch_utils import cuda_device_count_stateless, set_random_seed
from vllm.v1.worker.workspace import init_workspace_manager

from .modular_kernel_tools.common import (
    Config,
    RankTensors,
    WeightTensors,
    reference_moe_impl,
    run_modular_kernel,
)
from .modular_kernel_tools.mk_objects import (
    MK_FUSED_EXPERT_TYPES,
    MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
    MK_QUANT_CONFIGS,
    MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
    TestMoEQuantConfig,
    expert_info,
)
from .modular_kernel_tools.parallel_utils import (
    ProcessGroupInfo,
    parallel_launch_with_config,
)
# True when at least one of the optional packages probed below is available.
has_any_multi_gpu_package = (
    has_deep_ep() or has_deep_gemm() or has_flashinfer_cutlass_fused_moe()
)

# Marker that skips a test when none of the packages above is available.
meets_multi_gpu_requirements = pytest.mark.skipif(
    not has_any_multi_gpu_package,
    reason="Requires deep_ep or deep_gemm or flashinfer packages",
)

# Skip the whole module on platforms that report the fnuz fp8 format.
if current_platform.is_fp8_fnuz():
    pytest.skip(
        "Tests in this file require float8_e4m3fn and platform does not support",
        allow_module_level=True,
    )

def format_result(verbose, msg, ex=None):
    """Print a short PASSED/FAILED report for one test configuration.

    Args:
        verbose: On success, print ``PASSED <msg>`` when true; otherwise
            print a single ``.`` without a trailing newline.
        msg: Description of the configuration being reported.
        ex: The exception that caused the failure, or ``None`` on success.
            On failure the currently handled exception's traceback is
            printed with every line prefixed by ``E`` and a tab, followed
            by a ``FAILED`` line carrying at most the first 16 characters
            of the exception message.

    Note:
        The traceback comes from ``traceback.format_exc()``, so the failure
        path is meant to be called from inside an ``except`` block.
    """
    if ex is None:
        if verbose:
            print(f"PASSED {msg}")
        else:
            print(".", end="")
        return

    full_msg = str(ex).strip(" \n\t")
    short_msg = full_msg[:16]
    # Compare against the stripped message so the ellipsis appears only when
    # characters were really cut, not when only whitespace was removed.
    if len(short_msg) < len(full_msg):
        short_msg = short_msg + " ..."

    prefix = "E\t"
    print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
    print(f"FAILED {msg} - {short_msg}\n")


def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    base_config: Config,
    weights: WeightTensors,
    verbose: bool,
):
    """Run every (m, topk) combination of ``base_config`` on this rank.

    For each combination the modular-kernel output is compared against
    ``reference_moe_impl`` with ``torch.testing.assert_close``. Failures are
    reported through ``format_result`` and collected, so one failing
    combination does not stop the remaining ones from running.

    Args:
        pgi: Process-group info; ``local_rank`` selects the CUDA device and
            ``rank`` seeds the RNG and labels the output.
        vllm_config: Config passed to ``run_modular_kernel`` and installed
            via ``set_current_vllm_config`` around the reference run.
        cpu_group: Not used by this function.
        base_config: Config whose ``Ms`` and ``topks`` must be lists; each
            combination runs on a deep copy with both set to scalars.
        weights: Weights moved to the current device before the sweep.
        verbose: Forwarded to ``format_result``.

    Raises:
        RuntimeError: If at least one combination failed.
    """
    # Initialize workspace manager in child process
    device = torch.device(f"cuda:{pgi.local_rank}")
    init_workspace_manager(device)

    set_random_seed(pgi.rank)

    # get weights to this device
    weights.to_current_device()

    # Local names deliberately differ from the module-level Ms / TOPKs lists.
    m_values = base_config.Ms
    assert isinstance(m_values, list)
    topk_values = base_config.topks
    assert isinstance(topk_values, list)

    exceptions = []
    count = 0

    for m, topk in product(m_values, topk_values):
        # override m and topk
        config = copy.deepcopy(base_config)
        config.Ms = m
        config.topks = topk

        try:
            print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
            count = count + 1

            # inputs for rank
            rank_tensors = RankTensors.make(config, pgi)

            # modular kernel out
            mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors)

            with set_current_vllm_config(vllm_config):
                ref_out = reference_moe_impl(config, weights, rank_tensors)

            # nvfp4 gets looser tolerances, loosened further for K >= 4096;
            # atol and rtol always share the same value.
            if config.quant_dtype == "nvfp4":
                tol = 1e-1 if config.K < 4096 else 2e-1
            else:
                tol = 3e-2

            torch.testing.assert_close(ref_out, mk_out, atol=tol, rtol=tol)
            format_result(verbose, config.describe())
        except Exception as ex:
            # Record the failure and keep going with the next combination.
            format_result(verbose, config.describe(), ex)
            exceptions.append(ex)

    if exceptions:
        raise RuntimeError(
            f"{len(exceptions)} of {count} tests failed in child process, "
            f"rank={pgi.rank}."
        )
    else:
        print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")


def run(config: Config, verbose: bool):
    """Build weights and launch ``rank_worker`` for ``config``.

    Args:
        config: Test configuration. It must pass ``config.is_valid()`` and
            must not be flagged by ``is_nyi_config``.
        verbose: Forwarded to ``rank_worker``.
    """
    # is_valid() returns (valid, reason); put the reason in the failure message.
    valid, reason = config.is_valid()
    assert valid, f"Invalid test config: {reason}"
    assert not is_nyi_config(config), "Test config is not yet implemented"

    weights: WeightTensors = WeightTensors.make(config)

    vllm_config, env_dict = config.make_env_data()
    parallel_launch_with_config(
        config.world_size, rank_worker, vllm_config, env_dict, config, weights, verbose
    )


# Values swept by the tests below; each list feeds the matching Config field.
Ms = [32, 64]
# hidden sizes, making this too large will cause fp4 tests to fail.
# Also needs to be a multiple of 1024 for deep_gemm.
Ks = [2048]
Ns = [1024]
TOPKs = [4, 1]
Es = [32]
DTYPEs = [torch.bfloat16]


def is_nyi_config(config: Config) -> bool:
    """Return True when ``config`` is a known not-yet-working combination.

    Experts whose info sets ``needs_matching_quant`` are flagged when exactly
    one of per-act-token and per-out-channel quantization is enabled. All
    other experts are flagged when they do not support an expert map.
    """
    # We know these configs to be legitimate, but they still fail.
    info = expert_info(config.fused_experts_type)
    if info.needs_matching_quant:
        # The triton kernels expect both per-act-token-quant and
        # per-out-ch-quant or neither.
        per_act_token = bool(config.is_per_act_token_quant)
        per_out_ch = bool(config.is_per_out_ch_quant)
        return per_act_token != per_out_ch

    return not info.supports_expert_map


def generate_valid_test_cases(
    world_size: int, prepare_finalize_types
) -> list[tuple[Any, ...]]:
    """Enumerate the parametrize tuples for all runnable configurations.

    Takes the product of the module-level ``Ks``, ``Ns``, ``Es``, ``DTYPEs``,
    ``MK_QUANT_CONFIGS`` and every (prepare/finalize, fused-experts) pair.
    A combination is kept only if ``Config.is_valid()`` accepts it and
    ``is_nyi_config`` does not flag it.

    Args:
        world_size: World size stored in every generated case.
        prepare_finalize_types: Prepare/finalize types to pair with each
            entry of ``MK_FUSED_EXPERT_TYPES``.

    Returns:
        A list of ``(k, n, e, dtype, quant_config, prepare_finalize_type,
        fused_experts_type, world_size)`` tuples.
    """
    # TODO(bnell): figure out how to get verbose flag here.
    verbose = False  # pytestconfig.getoption('verbose') > 0

    cases = []
    total = 0

    for k, n, e, dtype, quant_config, (pf_type, experts_type) in product(
        Ks,
        Ns,
        Es,
        DTYPEs,
        MK_QUANT_CONFIGS,
        product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES),
    ):
        total += 1

        config = Config(
            Ms=Ms,
            K=k,
            N=n,
            E=e,
            topks=TOPKs,
            dtype=dtype,
            quant_config=quant_config,
            prepare_finalize_type=pf_type,
            fused_experts_type=experts_type,
            world_size=world_size,
        )

        valid, reason = config.is_valid()

        if not valid:
            if verbose:
                print(f"Test config {config} is not valid: {reason}")
            continue

        if is_nyi_config(config):
            if verbose:
                print(f"Test config {config} is nyi.")
            continue

        cases.append(
            (
                k,
                n,
                e,
                dtype,
                quant_config,
                pf_type,
                experts_type,
                world_size,
            )
        )

    print(f"{len(cases)} of {total} valid configs generated.")

    return cases


@pytest.mark.parametrize(
    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,world_size",
    generate_valid_test_cases(
        world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
    ),
)
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: TestMoEQuantConfig | None,
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    fused_experts_type: mk.FusedMoEExperts,
    world_size: int,
    pytestconfig,
):
    """Run one multi-GPU prepare/finalize and fused-experts combination.

    The test is skipped when fewer than ``world_size`` CUDA devices are
    visible.
    """
    # Query the device count once and reuse it in the skip message.
    num_gpus = cuda_device_count_stateless()
    if num_gpus < world_size:
        pytest.skip(
            f"Not enough GPUs available to run, got "
            f"{num_gpus} expected "
            f"{world_size}."
        )

    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        world_size=world_size,
    )
    verbosity = pytestconfig.getoption("verbose")
    run(config, verbosity > 0)


@pytest.mark.parametrize(
    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,world_size",
    generate_valid_test_cases(
        world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
    ),
)
def test_modular_kernel_combinations_singlegpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: TestMoEQuantConfig | None,
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    fused_experts_type: mk.FusedMoEExperts,
    world_size: int,
    pytestconfig,
    workspace_init,
):
    """Note: float8_e4m3fn is not supported on CUDA architecture < 89,
    and those tests will be skipped on unsupported hardware."""
    # NOTE(review): workspace_init is never read here; presumably a fixture
    # requested only for its setup side effect. Confirm against conftest.
    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        world_size=world_size,
    )

    uses_fp8_e4m3fn = (
        quant_config is not None and quant_config.quant_dtype == torch.float8_e4m3fn
    )
    if uses_fp8_e4m3fn and not current_platform.has_device_capability(89):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )
    verbosity = pytestconfig.getoption("verbose")
    run(config, verbosity > 0)


if __name__ == "__main__":
    # Ability to test individual PrepareAndFinalize and FusedExperts combination
    # The relative import below needs the module to be run with `python -m`,
    # as shown in the example embedded in the parser description.
    from .modular_kernel_tools.cli_args import make_config, make_config_arg_parser

    parser = make_config_arg_parser(
        description=(
            # Adjacent literals are concatenated, so the separator after
            # "test" keeps the summary and the example from running together.
            "Run single prepare-finalize & fused-experts combination test. "
            "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations "
            "--pf-type DeepEPLLPrepareAndFinalize --experts-type BatchedTritonExperts"
        )
    )
    args = parser.parse_args()
    config = make_config(args)

    run(config, True)