test_pplx_cutlass_moe.py 10.8 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4

5
6
7
import pytest
import torch

8
from tests.kernels.utils import torch_experts
9
10
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
11
from vllm.model_executor.layers.fused_moe import fused_topk
12
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
13
14
15
16
17
18
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    RoutingMethodType,
    fp8_w8a8_moe_quant_config,
)
19
20
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassBatchedExpertsFp8
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
21
from vllm.platforms import current_platform
22
from vllm.utils.math_utils import cdiv
23
from vllm.utils.torch_utils import set_random_seed
24
from vllm.v1.worker.workspace import init_workspace_manager
25

26
from ...utils import multi_gpu_test
bnellnm's avatar
bnellnm committed
27
from .parallel_utils import ProcessGroupInfo, parallel_launch
28

29
30
# pplx_kernels is an optional dependency: record whether it imported so the
# tests in this module can be skipped when it is missing.
try:
    from pplx_kernels import AllToAll
    from pplx_kernels.nvshmem import (
        nvshmem_alloc_empty_unique_id,
        nvshmem_finalize,
        nvshmem_get_unique_id,
        nvshmem_init,
    )

    has_pplx = True
except ImportError:
    has_pplx = False

# Skip marker for tests that need the optional pplx_kernels package.
requires_pplx = pytest.mark.skipif(not has_pplx, reason="Requires PPLX kernels")

# Values used to parametrize the number of experts and the routing top-k.
NUM_EXPERTS = [40, 64]
TOP_KS = [6, 8]


def rank_chunk(num, r, w):
    """Return how many of ``num`` items rank ``r`` owns when split over ``w`` ranks.

    Every rank gets ``num // w`` items; the first ``num % w`` ranks get one
    extra so that the sizes add up to ``num``.
    """
    base, extra = divmod(num, w)
    if r < extra:
        return base + 1
    return base


def chunk_by_rank(t, r, w):
    """Return rank ``r``'s contiguous slice of ``t`` along dim 0 for ``w`` ranks.

    The split matches ``rank_chunk``: the first ``t.shape[0] % w`` ranks own
    one more row than the remaining ranks, and slices are laid out in rank
    order without gaps.

    Args:
        t: object exposing ``shape[0]``, slicing and ``.contiguous()``.
        r: rank whose slice is requested.
        w: number of ranks the rows are divided among.

    Returns:
        ``t[start : start + chunk].contiguous()`` for this rank.
    """
    num = t.shape[0]
    chunk = rank_chunk(num, r, w)
    rem = num % w
    # Every preceding rank contributes num // w rows, and each preceding
    # "long" rank (there are min(r, rem) of them) contributes one more.
    start = r * (num // w) + min(r, rem)
    return t[start : start + chunk].contiguous()
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82


def pplx_cutlass_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    a1_scale: torch.Tensor,
    out_dtype,
    per_act_token: bool,
    per_out_ch: bool,
    group_name: str | None,
):
    """Run this rank's share of a batched CUTLASS fp8 MoE over a pplx all-to-all.

    Builds an ``AllToAll`` (internode when ``group_name`` is None, otherwise
    intranode bound to that group), wraps it in ``PplxPrepareAndFinalize``,
    pairs it with ``CutlassBatchedExpertsFp8`` inside a
    ``FusedMoEModularKernel`` and invokes the kernel on the rows of the
    inputs that ``chunk_by_rank`` assigns to ``pgi.rank``.

    Args:
        pgi: process-group info; ``rank``, ``world_size``, ``local_rank`` and
            ``device`` are read.
        dp_size: forwarded to ``AllToAll``; ``world_size // dp_size`` is used
            as the number of dispatchers.
        a: 2-D activations for all ranks, unpacked as
            ``(num_tokens, hidden_dim)``.
        w1: expert weights; ``w1.shape[0]`` is taken as the expert count.
        w2: expert weights; ``w2.shape[2]`` is taken as the intermediate size.
        w1_scale: scales for ``w1``, sliced per rank along dim 0.
        w2_scale: scales for ``w2``, sliced per rank along dim 0.
        topk_weights: routing weights, sliced per rank along dim 0.
        topk_ids: routing ids, sliced per rank along dim 0 and cast to
            ``torch.uint32``; ``topk_ids.shape[1]`` is taken as top-k.
        a1_scale: activation scale; sliced per rank when ``per_act_token``,
            otherwise indexed by ``rank``.
        out_dtype: not used by this function.
        per_act_token: forwarded as ``per_act_token_quant``.
        per_out_ch: forwarded as ``per_out_ch_quant``.
        group_name: process-group name for the intranode all-to-all, or None
            to use the internode all-to-all.

    Returns:
        The first ``rank_chunk(num_tokens, rank, world_size)`` rows of the
        kernel output.
    """
    # Imported lazily, as in the original code.
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize,
    )

    init_workspace_manager(torch.cuda.current_device())

    assert torch.cuda.current_device() == pgi.local_rank

    num_tokens, hidden_dim = a.shape
    intermediate_dim = w2.shape[2]
    num_experts = w1.shape[0]
    block_size = hidden_dim  # TODO support more cases
    device = pgi.device
    rank = pgi.rank
    world_size = pgi.world_size
    rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
    # Rank 0 always owns the largest chunk, so its size bounds every rank.
    max_num_tokens = rank_chunk(num_tokens, 0, world_size)
    topk = topk_ids.shape[1]

    if block_size == hidden_dim:
        scale_elems = 4  # hack to circumvent pplx data format requirements
    else:
        # Ceil-divide: one scale element per block of the hidden dimension.
        scale_elems = (hidden_dim + block_size - 1) // block_size

    args = dict(
        max_num_tokens=max_num_tokens,
        num_experts=num_experts,
        experts_per_token=topk,
        rank=rank,
        world_size=world_size,
        dp_size=dp_size,
        hidden_dim=hidden_dim,
        # One byte per hidden element. NOTE(review): the test in this file
        # passes fp16 activations, so this presumably describes the data
        # after fp8 quantization in the prepare step -- TODO confirm.
        hidden_dim_bytes=hidden_dim,  # because a.dtype.itemsize == 1
        hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
    )

    if group_name is None:
        ata = AllToAll.internode(**args)
    else:
        args["group_name"] = group_name
        ata = AllToAll.intranode(**args)

    w1 = w1.to(device)
    w2 = w2.to(device)
    w1_scale = w1_scale.to(device)
    w2_scale = w2_scale.to(device)
    a1_scale = a1_scale.to(device)

    assert num_experts % world_size == 0
    num_local_experts = cdiv(num_experts, world_size)
    num_dispatchers = pgi.world_size // dp_size

    prepare_finalize = PplxPrepareAndFinalize(
        ata,
        max_num_tokens=max_num_tokens,
        num_local_experts=num_local_experts,
        num_dispatchers=num_dispatchers,
    )

    def make_moe_config() -> FusedMoEConfig:
        return FusedMoEConfig(
            num_experts=num_experts,
            experts_per_token=topk,
            hidden_dim=hidden_dim,
            intermediate_size_per_partition=intermediate_dim,
            num_local_experts=num_local_experts,
            num_logical_experts=num_experts,
            moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
            activation=MoEActivation.SILU,
            in_dtype=torch.bfloat16,
            device="cuda",
            routing_method=RoutingMethodType.Llama4,
        )

    experts = CutlassBatchedExpertsFp8(
        moe_config=make_moe_config(),
        quant_config=fp8_w8a8_moe_quant_config(
            per_act_token_quant=per_act_token,
            per_out_ch_quant=per_out_ch,
            w1_scale=chunk_by_rank(w1_scale, rank, world_size),
            w2_scale=chunk_by_rank(w2_scale, rank, world_size),
            # Per-token scales are split like the tokens; otherwise the
            # caller supplies one row per rank and this rank picks its own.
            a1_scale=chunk_by_rank(a1_scale, rank, world_size)
            if per_act_token
            else a1_scale[rank],
        ),
        max_num_tokens=max_num_tokens,
        num_dispatchers=num_dispatchers,
    )

    fused_cutlass_experts = FusedMoEModularKernel(
        prepare_finalize,
        experts,
        inplace=False,
    )

    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
    chunk_topk_weight = chunk_by_rank(topk_weights, rank, world_size).to(device)
    chunk_topk_ids = (
        chunk_by_rank(topk_ids, rank, world_size).to(torch.uint32).to(device)
    )

    out = fused_cutlass_experts(
        a_chunk,
        chunk_by_rank(w1, rank, world_size),
        chunk_by_rank(w2, rank, world_size),
        chunk_topk_weight,
        chunk_topk_ids,
        global_num_experts=num_experts,
        expert_map=None,  # TODO
    )

    # Wait for outstanding GPU work before tearing down the all-to-all.
    torch.cuda.synchronize()

    ata.destroy()

    # Keep only the rows that belong to this rank's real tokens.
    return out[:rank_num_tokens]


# Default-constructed config shared by this module; it is installed with
# set_current_vllm_config() in both test_cutlass_moe_pplx and _pplx_moe.
vllm_config = VllmConfig()


def _pplx_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    a1_scale: torch.Tensor,
    out_dtype,
    a_full: torch.Tensor,
    w1_full: torch.Tensor,
    w2_full: torch.Tensor,
    per_act_token: bool,
    per_out_ch: bool,
    use_internode: bool,
):
    """Per-process worker: compare ``pplx_cutlass_moe`` with ``torch_experts``.

    Sets up the communication backend (NVSHMEM when ``use_internode`` is
    true, otherwise a new gloo process group over all ranks), computes the
    reference output with ``torch_experts`` on ``a_full`` / ``w1_full`` /
    ``w2_full``, computes the pplx output from the remaining arguments, and
    asserts that this rank's slice of the reference matches the pplx output
    with ``atol=0.05, rtol=0``.

    Args:
        pgi: process-group info for this worker.
        dp_size: forwarded to ``pplx_cutlass_moe``.
        a, w1, w2, w1_scale, w2_scale, topk_weights, topk_ids, a1_scale,
            out_dtype, per_act_token, per_out_ch: forwarded to
            ``pplx_cutlass_moe``; ``topk_weights`` and ``topk_ids`` are also
            passed to ``torch_experts``.
        a_full, w1_full, w2_full: inputs for the ``torch_experts`` reference.
        use_internode: select the NVSHMEM-initialised internode path.

    Raises:
        AssertionError: if the two outputs differ beyond the tolerance.
    """
    # pplx_cutlass_moe uses the internode all-to-all when group_name is None.
    # Initialising it here keeps the internode path from hitting an unbound
    # local, since only the intranode branch assigns a real group name.
    group_name: str | None = None
    try:
        if use_internode:
            # Rank 0 creates the NVSHMEM unique id; other ranks allocate an
            # empty one that the broadcast fills in.
            uid = (
                nvshmem_get_unique_id()
                if pgi.rank == 0
                else nvshmem_alloc_empty_unique_id()
            )
            torch.distributed.broadcast(uid, src=0)
            nvshmem_init(uid, pgi.rank, pgi.world_size)
        else:
            group_ranks = list(range(pgi.world_size))
            cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
            group_name = cpu_group.group_name

        with set_current_vllm_config(vllm_config):
            torch_output = torch_experts(
                a_full, w1_full, w2_full, topk_weights, topk_ids
            )
            pplx_output = pplx_cutlass_moe(
                pgi,
                dp_size,
                a,
                w1,
                w2,
                w1_scale,
                w2_scale,
                topk_weights,
                topk_ids,
                a1_scale,
                out_dtype,
                per_act_token,
                per_out_ch,
                group_name,
            )

            # The reference covers all tokens; keep only this rank's rows.
            torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(
                pplx_output.device
            )

        # Uncomment if more debugging is needed
        # print("PPLX OUT:", pplx_output)
        # print("TORCH OUT:", torch_output)

        torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
    finally:
        if use_internode:
            nvshmem_finalize()
272
273
274
275
276
277
278
279
280


@pytest.mark.parametrize("m", [2, 224])
@pytest.mark.parametrize("n", [3072])
@pytest.mark.parametrize("k", [1536])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])  # , [4, 2]])
@pytest.mark.parametrize("use_internode", [False])
@multi_gpu_test(num_gpus=2)
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
@requires_pplx
def test_cutlass_moe_pplx(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    world_dp_size: tuple[int, int],
    use_internode: bool,
):
    """Check the pplx + CUTLASS fp8 MoE path against a torch reference.

    Random fp16 activations and expert weights are generated, the weights are
    quantized to fp8 per expert and dequantized again for the reference, and
    ``_pplx_moe`` is launched on ``world_size`` processes to compare both
    paths.

    Args:
        m: number of tokens.
        n: half of the first weight's output dimension (``w1`` is
            ``(e, 2 * n, k)``, ``w2`` is ``(e, k, n)``).
        k: hidden dimension of the activations.
        e: number of experts.
        topk: number of experts selected per token.
        per_act_token: use one activation scale per token instead of one
            per rank.
        per_out_ch: use one weight scale per output row instead of one per
            expert.
        world_dp_size: ``(world_size, dp_size)`` pair.
        use_internode: forwarded to ``_pplx_moe``.
    """
    set_random_seed(7)

    with set_current_vllm_config(vllm_config):
        dtype = torch.half

        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10.0
        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10.0
        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10.0

        # One scale per output row when quantizing per channel, else one
        # scale per expert.
        n_b_scales = 2 * n if per_out_ch else 1
        k_b_scales = k if per_out_ch else 1

        w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn)
        w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
        w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
        w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)

        for expert in range(e):
            w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
                w1[expert], use_per_token_if_dynamic=per_out_ch
            )
            w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
                w2[expert], use_per_token_if_dynamic=per_out_ch
            )

        # Dequantized weights feed the torch reference so that both paths
        # see the same quantization error in the weights.
        w1_d = torch.empty_like(w1)
        w2_d = torch.empty_like(w2)
        for expert in range(e):
            w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half()
            w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

        world_size, dp_size = world_dp_size
        a_scale1 = (
            torch.randn(
                (m if per_act_token else 1, 1), device="cuda", dtype=torch.float32
            )
            / 10.0
        )
        if not per_act_token:
            # pplx_cutlass_moe indexes the scale by rank in this mode, so
            # provide one identical row per rank.
            a_scale1 = a_scale1.repeat(world_size, 1)

        parallel_launch(
            world_size,
            _pplx_moe,
            dp_size,
            a,
            w1_q,
            w2_q,
            w1_scale,
            w2_scale,
            topk_weights,
            topk_ids,
            a_scale1,
            dtype,
            a,
            w1_d,
            w2_d,
            per_act_token,
            per_out_ch,
            use_internode,
        )