test_modular_kernel_combinations.py 8.29 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
import textwrap
import traceback
from itertools import product
from typing import Optional

import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

from ...utils import multi_gpu_test
from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
                                          reference_moe_impl,
                                          run_modular_kernel)
from .modular_kernel_tools.mk_objects import (
    MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
    MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, TestMoEQuantConfig,
    expert_info)
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
                                                  parallel_launch_with_config)

# True when at least one of the optional packages checked below reports being
# available (deep_ep, deep_gemm, pplx, or the flashinfer cutlass fused-MoE).
has_any_multi_gpu_package = (has_deep_ep() or has_deep_gemm() or has_pplx()
                             or has_flashinfer_cutlass_fused_moe())

# Skip marker for the multi-GPU test: skip when none of the packages above is
# available.
meets_multi_gpu_requirements = pytest.mark.skipif(
    not has_any_multi_gpu_package,
    reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
)


def format_result(verbose, msg, ex=None):
    """Print a pytest-style result line for a single test case.

    Args:
        verbose: on success, print ``PASSED <msg>`` when true, otherwise a
            bare ``.`` with no trailing newline.
        msg: description of the case being reported.
        ex: the exception raised by a failing case, or None on success. On
            failure, the exception's traceback is printed with every line
            prefixed by ``E`` and a tab, followed by
            ``FAILED <msg> - <summary>``. The summary is the first 16
            characters of the whitespace-stripped exception text, with
            `` ...`` appended when it was truncated.
    """
    if ex is not None:
        detail = str(ex).strip(" \n\t")
        summary = detail[:16]
        # Compare against the *stripped* text so that trailing whitespace
        # alone does not produce a misleading " ..." marker.
        if len(summary) < len(detail):
            summary = summary + " ..."

        # Format ex's own traceback. traceback.format_exc() would instead
        # report whatever exception is currently being handled, and prints
        # "NoneType: None" when called outside an except block.
        tb_text = "".join(
            traceback.format_exception(type(ex), ex, ex.__traceback__))
        print(textwrap.indent(tb_text, "E\t"))
        print(f"FAILED {msg} - {summary}\n")
    elif verbose:
        print(f"PASSED {msg}")
    else:
        print(".", end="")


def _get_tolerances(config: Config) -> tuple[float, float]:
    """Return the ``(atol, rtol)`` used to compare kernel vs. reference output.

    nvfp4 gets looser tolerances (presumably because 4-bit quantization is
    lossier), and they are loosened further once ``config.K`` reaches 4096.
    Every other quant dtype uses 3e-2 for both.
    """
    if config.quant_dtype == "nvfp4":
        tol = 1e-1 if config.K < 4096 else 2e-1
        return tol, tol
    return 3e-2, 3e-2


def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    base_config: Config,
    weights: WeightTensors,
    verbose: bool,
):
    """Run every (m, topk) case from ``base_config`` on one rank.

    This is the worker passed to ``parallel_launch_with_config`` by ``run``.
    For each element of ``product(base_config.Ms, base_config.topks)`` it runs
    the modular kernel and checks the output against ``reference_moe_impl``.
    Failures are collected rather than raised immediately, so every case
    runs.

    Args:
        pgi: process-group info for this rank. ``pgi.rank`` seeds the RNG and
            tags the log lines.
        vllm_config: config passed to the kernel run and made current while
            computing the reference output.
        cpu_group: not used by this function. NOTE(review): presumably
            supplied by the launcher as part of the worker signature.
        base_config: test config whose ``Ms`` and ``topks`` must be lists.
        weights: expert weights, moved to the current device here.
        verbose: forwarded to ``format_result``.

    Raises:
        AssertionError: if the chunk-size sanity check fails, or if ``Ms`` or
            ``topks`` is not a list.
        RuntimeError: if any case failed. It is chained to the first
            failure.
    """
    current_platform.seed_everything(pgi.rank)

    # Sanity check: a chunk size requested by the config must match the
    # VLLM_FUSED_MOE_CHUNK_SIZE value seen by this process. envs is imported
    # lazily, presumably so it reflects the environment of this child
    # process. TODO confirm.
    from vllm import envs
    if base_config.fused_moe_chunk_size is not None:
        assert (
            base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)

    # Get the weights onto this rank's device.
    weights.to_current_device()

    ms = base_config.Ms
    assert isinstance(ms, list)
    topks = base_config.topks
    assert isinstance(topks, list)

    exceptions = []
    count = 0
    for m, topk in product(ms, topks):
        # Narrow the list-valued sweep fields down to this single case.
        config = copy.deepcopy(base_config)
        config.Ms = m
        config.topks = topk

        try:
            print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
            count += 1

            rank_tensors = RankTensors.make(config, pgi)

            mk_out = run_modular_kernel(pgi, vllm_config, config, weights,
                                        rank_tensors)

            with set_current_vllm_config(vllm_config):
                ref_out = reference_moe_impl(config, weights, rank_tensors)

            atol, rtol = _get_tolerances(config)
            torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol)
            format_result(verbose, config.describe())
        except Exception as ex:
            # Deliberately broad: record every failure and keep running the
            # remaining cases.
            format_result(verbose, config.describe(), ex)
            exceptions.append(ex)

    if exceptions:
        raise RuntimeError(
            f"{len(exceptions)} of {count} tests failed in child process, "
            f"rank={pgi.rank}.") from exceptions[0]
    print(f"{count} of {count} tests passed in child process, "
          f"rank={pgi.rank}.")


def run(config: Config, verbose: bool):
    """Build the weights for ``config`` and launch ``rank_worker`` on its ranks.

    Args:
        config: test config. ``config.Ms`` and ``config.topks`` must be lists,
            which ``rank_worker`` asserts and sweeps over.
        verbose: forwarded to ``rank_worker`` to control per-case output.

    Raises:
        AssertionError: if ``config.is_valid()`` is false.
    """
    assert config.is_valid()

    # Weights are built once here and handed to the launcher together with
    # the config. Each rank moves them to its own device.
    weights: WeightTensors = WeightTensors.make(config)

    # env_dict goes to the launcher, not to rank_worker. It is presumably
    # applied as environment variables in each child process. TODO confirm.
    vllm_config, env_dict = config.make_env_data()
    parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
                                env_dict, config, weights, verbose)


# Sweep values for the Config fields of the same name. Ms and TOPKs are
# passed to Config as whole lists and iterated inside each rank worker. The
# others are pytest-parametrized, one value per test instance.
Ms = [32, 64]
# Hidden sizes. Making this too large causes the fp4 tests to fail, and it
# must also be a multiple of 1024 for deep_gemm.
Ks = [2048]
Ns = [1024]
TOPKs = [4, 1]
Es = [32]
DTYPEs = [torch.bfloat16]
# None leaves the fused-MoE chunk size unset. 16 forces a small chunk size.
FUSED_MOE_CHUNK_SIZEs = [None, 16]


def is_nyi_config(config: Config) -> bool:
    """Return True for configs known to be legitimate but that still fail.

    The tests call this after ``Config.is_valid()`` succeeds, and skip the
    configs it flags as not-yet-implemented instead of reporting them as
    failures.
    """
    info = expert_info(config.fused_experts_type)

    if info.needs_matching_quant:
        # The triton kernels expect both per-act-token-quant and
        # per-out-ch-quant, or neither. Having exactly one is unsupported.
        # NOTE(review): supports_expert_map is not consulted on this path.
        # Confirm that this is intended.
        unsupported_quant_config = ((config.is_per_act_token_quant +
                                     config.is_per_out_ch_quant) == 1)
        return unsupported_quant_config

    return not info.supports_expert_map


@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
    "combination",
    product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2])
@multi_gpu_test(num_gpus=2)
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
        k: int, n: int, e: int, dtype: torch.dtype,
        quant_config: Optional[TestMoEQuantConfig],
        combination: tuple[mk.FusedMoEPrepareAndFinalize,
                           mk.FusedMoEPermuteExpertsUnpermute],
        fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
    """Check one multi-GPU prepare/finalize x fused-experts combination.

    ``combination`` is a (prepare-finalize type, fused-experts type) pair.
    Configs that are invalid or known NYI are skipped. Otherwise the
    combination runs on ``world_size`` ranks and is compared against the
    reference MoE for every (m, topk) in the module-level sweeps.
    """
    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=combination[0],
        fused_experts_type=combination[1],
        fused_moe_chunk_size=fused_moe_chunk_size,
        world_size=world_size,
    )

    if not config.is_valid():
        pytest.skip(f"Tests config {config} is not valid. Skipping ...")

    if is_nyi_config(config):
        pytest.skip(f"Tests config {config} is nyi. Skipping ...")

    # Mirror pytest's -v flag: verbose output when run with -v or higher.
    verbosity = pytestconfig.getoption('verbose')
    run(config, verbosity > 0)


@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
    "combination",
    product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [1])
def test_modular_kernel_combinations_singlegpu(
        k: int, n: int, e: int, dtype: torch.dtype,
        quant_config: Optional[TestMoEQuantConfig],
        combination: tuple[mk.FusedMoEPrepareAndFinalize,
                           mk.FusedMoEPermuteExpertsUnpermute],
        fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
    """Check one single-GPU prepare/finalize x fused-experts combination.

    ``combination`` is a (prepare-finalize type, fused-experts type) pair.
    Configs that are invalid or known NYI are skipped. Otherwise the
    combination runs on a single rank and is compared against the reference
    MoE for every (m, topk) in the module-level sweeps.
    """
    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=combination[0],
        fused_experts_type=combination[1],
        fused_moe_chunk_size=fused_moe_chunk_size,
        world_size=world_size,
    )

    if not config.is_valid():
        pytest.skip(f"Tests config {config} is not valid. Skipping ...")

    if is_nyi_config(config):
        pytest.skip(f"Tests config {config} is nyi. Skipping ...")

    # Mirror pytest's -v flag: verbose output when run with -v or higher.
    verbosity = pytestconfig.getoption('verbose')
    run(config, verbosity > 0)


if __name__ == '__main__':
    # Test an individual PrepareAndFinalize x FusedExperts combination from
    # the command line, e.g.:
    #   python3 -m tests.kernels.moe.test_modular_kernel_combinations \
    #       --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts
    from .modular_kernel_tools.cli_args import (make_config,
                                                make_config_arg_parser)
    # The trailing ". " separates the summary from the example. Without it
    # the adjacent literals fuse into "...combination testExample : ...".
    parser = make_config_arg_parser(description=(
        "Run single prepare-finalize & fused-experts combination test. "
        "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations "  # noqa: E501
        "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
    ))
    args = parser.parse_args()
    config = make_config(args)

    run(config, True)