# test_pplx_cutlass_moe.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4

5
6
7
import pytest
import torch

8
from tests.kernels.utils import torch_experts
9
10
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
11
from vllm.model_executor.layers.fused_moe import fused_topk
12
13
14
15
16
17
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    RoutingMethodType,
    fp8_w8a8_moe_quant_config,
)
18
19
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassBatchedExpertsFp8
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
20
from vllm.platforms import current_platform
21
from vllm.utils.math_utils import cdiv
22
from vllm.utils.torch_utils import set_random_seed
23
from vllm.v1.worker.workspace import init_workspace_manager
24

25
from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
27

28
29
# ``pplx_kernels`` is an optional dependency; record whether it is importable
# so the test can be skipped cleanly instead of failing at collection time.
try:
    from pplx_kernels import AllToAll
    from pplx_kernels.nvshmem import (
        nvshmem_alloc_empty_unique_id,
        nvshmem_finalize,
        nvshmem_get_unique_id,
        nvshmem_init,
    )

    has_pplx = True
except ImportError:
    has_pplx = False

# Skip marker for tests that need the optional ``pplx_kernels`` package.
requires_pplx = pytest.mark.skipif(
    not has_pplx,
    reason="Requires PPLX kernels",
)

# Parametrization values: total number of experts and experts chosen per token.
NUM_EXPERTS = [40, 64]
TOP_KS = [6, 8]


def rank_chunk(num: int, r: int, w: int) -> int:
    """Return how many of ``num`` items rank ``r`` owns when split over ``w`` ranks.

    Items are split as evenly as possible. The first ``num % w`` ranks each
    receive one extra item, so the per-rank counts sum to ``num``.
    """
    rem = num % w
    return (num // w) + (1 if r < rem else 0)


def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
    """Return rank ``r``'s contiguous slice of ``t`` along dim 0 out of ``w`` ranks.

    The slice length is ``rank_chunk(t.shape[0], r, w)``. Slices for
    ``r = 0 .. w-1`` are consecutive and together cover every row of ``t``.
    The result is made contiguous.
    """
    num = t.shape[0]
    chunk = rank_chunk(num, r, w)
    # Every earlier rank owns ``num // w`` rows, and the first ``num % w`` of
    # them own one extra row each. So the offset is r * (num // w) plus one
    # extra row for each of the first min(r, num % w) ranks.
    start = r * (num // w) + min(r, num % w)
    return t[start : start + chunk].contiguous()


def pplx_cutlass_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    a1_scale: torch.Tensor,
    out_dtype,
    per_act_token: bool,
    per_out_ch: bool,
    group_name: str | None,
):
    """Run the CUTLASS batched FP8 MoE kernel behind a PPLX all-to-all on one rank.

    Each rank takes its ``chunk_by_rank`` slice of the tokens, routing tensors,
    expert weights and scales. It builds a ``PplxPrepareAndFinalize`` /
    ``CutlassBatchedExpertsFp8`` modular kernel and runs it.

    Args:
        pgi: Process-group info for this worker (rank, world size, device).
        dp_size: Data-parallel size; ``pgi.world_size // dp_size`` dispatchers.
        a: Full activation matrix ``(num_tokens, hidden_dim)``.
        w1, w2: Full expert weight tensors; dim 0 indexes experts.
        w1_scale, w2_scale: Scales for ``w1`` / ``w2``; dim 0 indexes experts.
        topk_weights, topk_ids: Full routing tensors ``(num_tokens, topk)``.
        a1_scale: Activation scales. With ``per_act_token`` this is per token
            and chunked by rank; otherwise row ``rank`` is used.
        out_dtype: Currently unused.
        per_act_token: Use per-token activation quantization.
        per_out_ch: Use per-output-channel weight quantization.
        group_name: CPU process-group name for the intranode all-to-all;
            ``None`` selects the internode all-to-all.

    Returns:
        The first ``rank_chunk(num_tokens, rank, world_size)`` rows of the
        kernel output.
    """
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize,
    )

    init_workspace_manager(torch.cuda.current_device())

    assert torch.cuda.current_device() == pgi.local_rank

    num_tokens, hidden_dim = a.shape
    intermediate_dim = w2.shape[2]
    num_experts = w1.shape[0]
    block_size = hidden_dim  # TODO support more cases
    device = pgi.device
    rank = pgi.rank
    world_size = pgi.world_size
    rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
    # Rank 0 always gets the largest chunk, so it bounds every rank's token count.
    max_num_tokens = rank_chunk(num_tokens, 0, world_size)
    topk = topk_ids.shape[1]

    if block_size == hidden_dim:
        scale_elems = 4  # hack to circumvent pplx data format requirements
    else:
        scale_elems = (hidden_dim + block_size - 1) // block_size

    args = dict(
        max_num_tokens=max_num_tokens,
        num_experts=num_experts,
        experts_per_token=topk,
        rank=rank,
        world_size=world_size,
        dp_size=dp_size,
        hidden_dim=hidden_dim,
        # NOTE(review): the original comment says "because a.dtype.itemsize == 1".
        # The test driver passes torch.half `a`, so this presumably refers to
        # the FP8-quantized activations that are dispatched. Confirm.
        hidden_dim_bytes=hidden_dim,
        hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
    )

    if group_name is None:
        ata = AllToAll.internode(**args)
    else:
        args["group_name"] = group_name
        ata = AllToAll.intranode(**args)

    # Everything after the all-to-all is created runs under try/finally so the
    # all-to-all is always destroyed, even if setup or the kernel raises.
    try:
        w1 = w1.to(device)
        w2 = w2.to(device)
        w1_scale = w1_scale.to(device)
        w2_scale = w2_scale.to(device)
        a1_scale = a1_scale.to(device)

        assert num_experts % world_size == 0
        num_local_experts = cdiv(num_experts, world_size)
        num_dispatchers = pgi.world_size // dp_size

        prepare_finalize = PplxPrepareAndFinalize(
            ata,
            max_num_tokens=max_num_tokens,
            num_local_experts=num_local_experts,
            num_dispatchers=num_dispatchers,
        )

        def make_moe_config() -> FusedMoEConfig:
            # NOTE(review): the test driver builds torch.half activations;
            # confirm that in_dtype=torch.bfloat16 is intended here.
            return FusedMoEConfig(
                num_experts=num_experts,
                experts_per_token=topk,
                hidden_dim=hidden_dim,
                intermediate_size_per_partition=intermediate_dim,
                num_local_experts=num_local_experts,
                moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
                activation="silu",
                in_dtype=torch.bfloat16,
                device="cuda",
                routing_method=RoutingMethodType.Llama4,
            )

        experts = CutlassBatchedExpertsFp8(
            moe_config=make_moe_config(),
            quant_config=fp8_w8a8_moe_quant_config(
                per_act_token_quant=per_act_token,
                per_out_ch_quant=per_out_ch,
                w1_scale=chunk_by_rank(w1_scale, rank, world_size),
                w2_scale=chunk_by_rank(w2_scale, rank, world_size),
                a1_scale=chunk_by_rank(a1_scale, rank, world_size)
                if per_act_token
                else a1_scale[rank],
            ),
            max_num_tokens=max_num_tokens,
            num_dispatchers=num_dispatchers,
        )

        fused_cutlass_experts = FusedMoEModularKernel(
            prepare_finalize,
            experts,
            inplace=False,
        )

        a_chunk = chunk_by_rank(a, rank, world_size).to(device)
        chunk_topk_weight = chunk_by_rank(topk_weights, rank, world_size).to(device)
        chunk_topk_ids = (
            chunk_by_rank(topk_ids, rank, world_size).to(torch.uint32).to(device)
        )

        out = fused_cutlass_experts(
            a_chunk,
            chunk_by_rank(w1, rank, world_size),
            chunk_by_rank(w2, rank, world_size),
            chunk_topk_weight,
            chunk_topk_ids,
            global_num_experts=num_experts,
            expert_map=None,  # TODO
        )

        # Wait for the kernel to finish before the all-to-all is torn down.
        torch.cuda.synchronize()
    finally:
        ata.destroy()

    return out[:rank_num_tokens]


# Module-wide default config. Both the test driver and each worker enter it
# with ``set_current_vllm_config``.
vllm_config: VllmConfig = VllmConfig()


def _pplx_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    a1_scale: torch.Tensor,
    out_dtype,
    a_full: torch.Tensor,
    w1_full: torch.Tensor,
    w2_full: torch.Tensor,
    per_act_token: bool,
    per_out_ch: bool,
    use_internode: bool,
):
    """Per-rank worker: compare ``pplx_cutlass_moe`` with the torch reference.

    With ``use_internode``, NVSHMEM is initialized from a unique id that rank 0
    broadcasts, and it is finalized on exit. Otherwise a gloo CPU group is
    created and its name selects the intranode all-to-all.

    The reference ``torch_experts`` runs on the full (dequantized) inputs
    ``a_full`` / ``w1_full`` / ``w2_full``. It is then sliced to this rank's
    tokens and compared with the PPLX output (``atol=0.05``).
    """
    try:
        if use_internode:
            uid = (
                nvshmem_get_unique_id()
                if pgi.rank == 0
                else nvshmem_alloc_empty_unique_id()
            )
            torch.distributed.broadcast(uid, src=0)
            nvshmem_init(uid, pgi.rank, pgi.world_size)
            # No CPU group on this path. pplx_cutlass_moe treats a None group
            # name as "use the internode all-to-all". Without this assignment,
            # the call below raised NameError.
            group_name = None
        else:
            group_ranks = list(range(pgi.world_size))
            cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
            group_name = cpu_group.group_name

        with set_current_vllm_config(vllm_config):
            torch_output = torch_experts(
                a_full, w1_full, w2_full, topk_weights, topk_ids
            )
            pplx_output = pplx_cutlass_moe(
                pgi,
                dp_size,
                a,
                w1,
                w2,
                w1_scale,
                w2_scale,
                topk_weights,
                topk_ids,
                a1_scale,
                out_dtype,
                per_act_token,
                per_out_ch,
                group_name,
            )

            torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(
                pplx_output.device
            )

        # Uncomment if more debugging is needed
        # print("PPLX OUT:", pplx_output)
        # print("TORCH OUT:", torch_output)

        torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
    finally:
        if use_internode:
            nvshmem_finalize()


@pytest.mark.parametrize("m", [2, 224])
@pytest.mark.parametrize("n", [3072])
@pytest.mark.parametrize("k", [1536])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])  # , [4, 2]])
@pytest.mark.parametrize("use_internode", [False])
@multi_gpu_test(num_gpus=2)
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
@requires_pplx
def test_cutlass_moe_pplx(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    world_dp_size: tuple[int, int],
    use_internode: bool,
):
    """Check the PPLX + CUTLASS FP8 MoE path against a torch reference.

    Each expert's weights are quantized to FP8 with ``ops.scaled_fp8_quant``.
    When ``per_out_ch`` is set there is one scale per weight row; otherwise
    there is one per expert. The dequantized copies feed the torch reference.
    ``parallel_launch`` then runs ``_pplx_moe`` on ``world_size`` workers.
    """
    set_random_seed(7)

    with set_current_vllm_config(vllm_config):
        dtype = torch.half

        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10.0
        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10.0
        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10.0

        n_b_scales = 2 * n if per_out_ch else 1
        k_b_scales = k if per_out_ch else 1

        w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn)
        w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
        w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
        w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)

        for expert in range(e):
            w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
                w1[expert], use_per_token_if_dynamic=per_out_ch
            )
            w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
                w2[expert], use_per_token_if_dynamic=per_out_ch
            )

        # Dequantized weights for the reference, so that both paths see the
        # same quantization error.
        w1_d = torch.empty_like(w1)
        w2_d = torch.empty_like(w2)
        for expert in range(e):
            w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half()
            w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

        world_size, dp_size = world_dp_size

        a_scale1 = (
            torch.randn(
                (m if per_act_token else 1, 1), device="cuda", dtype=torch.float32
            )
            / 10.0
        )
        # For per-tensor activation scales, give every rank its own row;
        # pplx_cutlass_moe indexes a1_scale[rank].
        if not per_act_token:
            a_scale1 = a_scale1.repeat(world_size, 1)

        parallel_launch(
            world_size,
            _pplx_moe,
            dp_size,
            a,
            w1_q,
            w2_q,
            w1_scale,
            w2_scale,
            topk_weights,
            topk_ids,
            a_scale1,
            dtype,
            a,
            w1_d,
            w2_d,
            per_act_token,
            per_out_ch,
            use_internode,
        )