test_modular_kernel_combinations.py 8.27 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
5
6
import textwrap
import traceback
7
8
9
10
11
12
13
14
15
from itertools import product
from typing import Optional

import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, current_platform, set_current_vllm_config
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
16
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
17

18
from ...utils import multi_gpu_test
19
20
21
22
23
from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
                                          reference_moe_impl,
                                          run_modular_kernel)
from .modular_kernel_tools.mk_objects import (
    MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
24
25
    MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, TestMoEQuantConfig,
    expert_info)
26
27
28
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
                                                  parallel_launch_with_config)

# True when at least one of the optional packages probed below is available;
# the multi-GPU combinations cannot run without one of them.
has_any_multi_gpu_package = (has_deep_ep() or has_deep_gemm() or has_pplx()
                             or has_flashinfer_cutlass_fused_moe())

# Marker that skips (rather than fails) multi-GPU tests on machines where
# none of those packages is installed.
meets_multi_gpu_requirements = pytest.mark.skipif(
    not has_any_multi_gpu_package,
    reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
)


def format_result(verbose, msg, ex=None, max_len: int = 16):
    """Print a one-line result for a single test case.

    Args:
        verbose: on success, print ``PASSED <msg>`` instead of a bare dot.
        msg: human-readable description of the case.
        ex: the exception raised by the case, or ``None`` if it passed.
        max_len: maximum number of characters of ``str(ex)`` shown in the
            ``FAILED`` summary line; longer messages are cut and suffixed
            with `` ...``.

    On failure, the full traceback of ``ex`` is printed first, each line
    prefixed with ``E`` followed by a tab, then a ``FAILED`` summary line.
    """
    if ex is None:
        if verbose:
            print(f"PASSED {msg}")
        else:
            print(".", end="")
        return

    # Only flag truncation when real characters were dropped, not when the
    # surrounding whitespace was merely stripped.
    full = str(ex).strip(" \n\t")
    summary = full[:max_len]
    if len(summary) < len(full):
        summary += " ..."

    # Format from the exception object itself so the traceback is correct
    # even when this is called outside the ``except`` that caught ``ex``.
    tb = "".join(traceback.format_exception(type(ex), ex, ex.__traceback__))
    print(textwrap.indent(tb, "E\t"))
    print(f"FAILED {msg} - {summary}\n")


def _moe_tolerances(config: Config) -> tuple[float, float]:
    """Return the ``(atol, rtol)`` used to compare kernel vs. reference output.

    nvfp4 configs get looser tolerances (looser still when ``K >= 4096``);
    every other config uses 3e-2 for both.
    """
    if config.quant_dtype == "nvfp4":
        tol = 1e-1 if config.K < 4096 else 2e-1
        return tol, tol
    return 3e-2, 3e-2


def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    base_config: Config,
    weights: WeightTensors,
    verbose: bool,
):
    """Per-rank worker: run every (m, topk) case and compare to the reference.

    For each pair in ``base_config.Ms`` x ``base_config.topks`` this builds
    the rank's inputs, runs the modular kernel, computes the reference MoE
    output and checks they agree. Failures are collected rather than aborting
    so that every case gets reported.

    Args:
        pgi: process-group info for this rank (its ``rank`` seeds the RNG and
            tags log lines).
        vllm_config: vLLM config passed to the kernel and made current while
            computing the reference output.
        cpu_group: not used in this function; presumably supplied by the
            parallel launcher — confirm against ``parallel_launch_with_config``.
        base_config: test config whose ``Ms`` and ``topks`` must be lists.
        weights: expert weights, moved to this rank's current device.
        verbose: print ``PASSED ...`` per case instead of a dot.

    Raises:
        RuntimeError: if any case failed; chained to the first failure.
    """
    # Seed with the rank so each rank is deterministic but distinct.
    current_platform.seed_everything(pgi.rank)

    # Sanity check: a requested chunk size must be what the environment of
    # this process reports (presumably set via the env_dict passed by run()).
    from vllm import envs
    if base_config.fused_moe_chunk_size is not None:
        assert (base_config.fused_moe_chunk_size ==
                envs.VLLM_FUSED_MOE_CHUNK_SIZE), (
                    "fused_moe_chunk_size does not match "
                    "VLLM_FUSED_MOE_CHUNK_SIZE")

    weights.to_current_device()

    m_values = base_config.Ms
    topk_values = base_config.topks
    assert isinstance(m_values, list)
    assert isinstance(topk_values, list)

    exceptions: list[Exception] = []
    count = 0
    for m, topk in product(m_values, topk_values):
        # Narrow a private copy of the config down to this single case.
        config = copy.deepcopy(base_config)
        config.Ms = m
        config.topks = topk

        try:
            print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
            count += 1

            rank_tensors = RankTensors.make(config, pgi)
            mk_out = run_modular_kernel(pgi, vllm_config, config, weights,
                                        rank_tensors)
            with set_current_vllm_config(vllm_config):
                ref_out = reference_moe_impl(config, weights, rank_tensors)

            atol, rtol = _moe_tolerances(config)
            torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol)
            format_result(verbose, config.describe())
        except Exception as ex:
            # Deliberately broad: record every failing case and keep going.
            format_result(verbose, config.describe(), ex)
            exceptions.append(ex)

    if exceptions:
        raise RuntimeError(
            f"{len(exceptions)} of {count} tests failed in child process, "
            f"rank={pgi.rank}.") from exceptions[0]
    print(f"{count} of {count} tests passed in child process, "
          f"rank={pgi.rank}.")


def run(config: Config, verbose: bool):
    """Run ``config`` on ``config.world_size`` ranks using ``rank_worker``.

    Weights are created once here and passed to every rank together with the
    vLLM config and environment produced by ``config.make_env_data()``.

    Args:
        config: test config; must satisfy ``config.is_valid()``.
        verbose: forwarded to ``rank_worker`` for per-case output.
    """
    assert config.is_valid(), f"invalid test config: {config}"

    weights: WeightTensors = WeightTensors.make(config)

    vllm_config, env_dict = config.make_env_data()
    parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
                                env_dict, config, weights, verbose)


# Values swept by the parametrized tests below. Ms and TOPKs are passed to
# Config as whole lists; each worker iterates over every (m, topk) pair.
Ms = [32, 64]
# hidden sizes, making this too large will cause fp4 tests to fail.
# Also needs to be a multiple of 1024 for deep_gemm.
Ks = [2048]
Ns = [1024]
TOPKs = [4, 1]
Es = [32]
DTYPEs = [torch.bfloat16]
# None means no explicit chunk size is requested for that test case.
FUSED_MOE_CHUNK_SIZEs = [None, 16]


def is_nyi_config(config: Config) -> bool:
    """Return True for configs that are legitimate but not yet implemented.

    We know these configs to be legitimate, but they still fail, so the
    tests skip them instead of reporting failures.

    - If the fused-experts type needs matching quantization, enabling exactly
      one of per-act-token and per-out-channel quantization is NYI.
    - Otherwise, fused-experts types without expert-map support are NYI.
    """
    info = expert_info(config.fused_experts_type)

    if info.needs_matching_quant:
        # The triton kernels expect both per-act-token-quant and
        # per-out-ch-quant or neither, so "exactly one" (XOR) is NYI.
        return (bool(config.is_per_act_token_quant) !=
                bool(config.is_per_out_ch_quant))

    return not info.supports_expert_map


@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
    "combination",
    product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2])
@multi_gpu_test(num_gpus=2)
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
        k: int, n: int, e: int, dtype: torch.dtype,
        quant_config: Optional[TestMoEQuantConfig],
        combination: tuple[mk.FusedMoEPrepareAndFinalize,
                           mk.FusedMoEPermuteExpertsUnpermute],
        fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
    """Check multi-GPU prepare/finalize x fused-experts combinations on 2 ranks.

    Configs rejected by ``Config.is_valid()`` or flagged by
    ``is_nyi_config`` are skipped rather than failed.
    """
    prepare_finalize_type, fused_experts_type = combination
    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        fused_moe_chunk_size=fused_moe_chunk_size,
        world_size=world_size,
    )

    if not config.is_valid():
        pytest.skip(f"Tests config {config} is not valid. Skipping ...")

    if is_nyi_config(config):
        pytest.skip(f"Tests config {config} is nyi. Skipping ...")

    verbosity = pytestconfig.getoption('verbose')
    run(config, verbosity > 0)


@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
    "combination",
    product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [1])
def test_modular_kernel_combinations_singlegpu(
        k: int, n: int, e: int, dtype: torch.dtype,
        quant_config: Optional[TestMoEQuantConfig],
        combination: tuple[mk.FusedMoEPrepareAndFinalize,
                           mk.FusedMoEPermuteExpertsUnpermute],
        fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
    """Check single-GPU prepare/finalize x fused-experts combinations.

    Configs rejected by ``Config.is_valid()`` or flagged by
    ``is_nyi_config`` are skipped rather than failed.
    """
    prepare_finalize_type, fused_experts_type = combination
    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        fused_moe_chunk_size=fused_moe_chunk_size,
        world_size=world_size,
    )

    if not config.is_valid():
        pytest.skip(f"Tests config {config} is not valid. Skipping ...")

    if is_nyi_config(config):
        pytest.skip(f"Tests config {config} is nyi. Skipping ...")

    verbosity = pytestconfig.getoption('verbose')
    run(config, verbosity > 0)


if __name__ == '__main__':
    # Run one PrepareAndFinalize / FusedExperts combination from the command
    # line. The relative import below means this must be launched as a
    # module (python3 -m ...), not as a plain script.
    from .modular_kernel_tools.cli_args import (make_config,
                                                make_config_arg_parser)
    parser = make_config_arg_parser(description=(
        "Run single prepare-finalize & fused-experts combination test. "
        "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations "  # noqa: E501
        "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
    ))
    args = parser.parse_args()
    config = make_config(args)

    run(config, True)