# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import pytest
import torch

from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassBatchedExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.platforms import current_platform
from vllm.utils import cdiv

from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch

# pplx_kernels is an optional dependency: record whether its all-to-all and
# nvshmem entry points are importable so dependent tests can be skipped.
try:
    from pplx_kernels import AllToAll
    from pplx_kernels.nvshmem import (
        nvshmem_alloc_empty_unique_id,
        nvshmem_finalize,
        nvshmem_get_unique_id,
        nvshmem_init,
    )
except ImportError:
    has_pplx = False
else:
    has_pplx = True

# Parametrization grids: total expert counts and router top-k values.
NUM_EXPERTS = [40, 64]
TOP_KS = [6, 8]

# Skip marker for tests that need the optional pplx_kernels package.
requires_pplx = pytest.mark.skipif(not has_pplx, reason="Requires PPLX kernels")


def rank_chunk(num, r, w):
    """Return how many of ``num`` items rank ``r`` owns out of ``w`` ranks.

    Items are split as evenly as possible; the first ``num % w`` ranks each
    receive one extra item.
    """
    base, leftover = divmod(num, w)
    return base + 1 if r < leftover else base


def chunk_by_rank(t, r, w):
    """Return rank ``r``'s contiguous slice of ``t`` along dim 0 (``w`` ranks).

    Rows are split as evenly as possible, with the first ``t.shape[0] % w``
    ranks each taking one extra row. The slice is returned as a contiguous
    tensor.
    """
    base, leftover = divmod(t.shape[0], w)
    length = base + (1 if r < leftover else 0)
    # Each rank before r owns `base` rows, and the first min(r, leftover) of
    # those ranks own one extra row.
    start = r * base + min(r, leftover)
    return t[start : start + length].contiguous()


def pplx_cutlass_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    a1_scale: torch.Tensor,
    out_dtype,
    per_act_token: bool,
    per_out_ch: bool,
    group_name: Optional[str],
):
    """Run the CUTLASS batched FP8 MoE behind a pplx all-to-all on this rank.

    Every rank receives the full activation, routing and weight tensors and
    keeps only its own slice (via ``chunk_by_rank``). Tokens are exchanged
    through a pplx ``AllToAll`` (``internode`` when ``group_name`` is None,
    otherwise ``intranode`` over the named group), and the experts run inside
    a ``FusedMoEModularKernel``.

    Args:
        pgi: Process-group info for this worker (rank, world size, device).
        dp_size: Data-parallel size; ``world_size // dp_size`` dispatchers.
        a: Full activations, unpacked as ``(num_tokens, hidden_dim)``.
        w1, w2: Full expert weights with experts on dim 0.
        w1_scale, w2_scale: Weight scales, chunked along dim 0 per rank.
        topk_weights, topk_ids: Full routing results; ``topk_ids.shape[1]``
            is the router top-k.
        a1_scale: Activation scales. Chunked per rank when
            ``per_act_token`` is set, otherwise indexed by rank.
        out_dtype: Output dtype handed to ``CutlassBatchedExpertsFp8``.
        per_act_token: Whether activation scales are per token.
        per_out_ch: Whether weight scales are per output channel.
        group_name: Process-group name for the intranode all-to-all, or
            None to use the internode all-to-all.

    Returns:
        The first ``rank_num_tokens`` rows of this rank's MoE output.
    """
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize,
    )

    assert torch.cuda.current_device() == pgi.local_rank

    num_tokens, hidden_dim = a.shape
    intermediate_dim = w2.shape[2]
    num_experts = w1.shape[0]
    block_size = hidden_dim  # TODO support more cases
    device = pgi.device
    rank = pgi.rank
    world_size = pgi.world_size
    rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
    # Rank 0 always owns the largest chunk, so it bounds every rank's size.
    max_num_tokens = rank_chunk(num_tokens, 0, world_size)
    topk = topk_ids.shape[1]

    if block_size == hidden_dim:
        scale_elems = 4  # hack to circumvent pplx data format requirements
    else:
        scale_elems = (hidden_dim + block_size - 1) // block_size

    args = dict(
        max_num_tokens=max_num_tokens,
        num_experts=num_experts,
        experts_per_token=topk,
        rank=rank,
        world_size=world_size,
        dp_size=dp_size,
        hidden_dim=hidden_dim,
        hidden_dim_bytes=hidden_dim,  # because a.dtype.itemsize == 1
        hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
    )

    if group_name is None:
        ata = AllToAll.internode(**args)
    else:
        args["group_name"] = group_name
        ata = AllToAll.intranode(**args)

    # Everything below can raise (device copies, the divisibility check, the
    # kernel itself). Release the all-to-all buffers on every exit path, not
    # only on success.
    try:
        w1 = w1.to(device)
        w2 = w2.to(device)
        w1_scale = w1_scale.to(device)
        w2_scale = w2_scale.to(device)
        a1_scale = a1_scale.to(device)

        assert num_experts % world_size == 0
        num_local_experts = cdiv(num_experts, world_size)
        num_dispatchers = pgi.world_size // dp_size

        prepare_finalize = PplxPrepareAndFinalize(
            ata,
            max_num_tokens=max_num_tokens,
            num_local_experts=num_local_experts,
            num_dispatchers=num_dispatchers,
        )

        # Per-expert GEMM stride tensors, placed on the same device as the
        # weights they describe.
        ab_strides1 = torch.full(
            (num_local_experts,), hidden_dim, device=device, dtype=torch.int64
        )
        ab_strides2 = torch.full(
            (num_local_experts,), intermediate_dim, device=device, dtype=torch.int64
        )
        c_strides1 = torch.full(
            (num_local_experts,), 2 * intermediate_dim, device=device, dtype=torch.int64
        )
        c_strides2 = torch.full(
            (num_local_experts,), hidden_dim, device=device, dtype=torch.int64
        )

        experts = CutlassBatchedExpertsFp8(
            num_local_experts,
            num_dispatchers,
            out_dtype,
            ab_strides1,
            ab_strides2,
            c_strides1,
            c_strides2,
            fp8_w8a8_moe_quant_config(
                per_act_token_quant=per_act_token,
                per_out_ch_quant=per_out_ch,
                w1_scale=chunk_by_rank(w1_scale, rank, world_size),
                w2_scale=chunk_by_rank(w2_scale, rank, world_size),
                a1_scale=chunk_by_rank(a1_scale, rank, world_size)
                if per_act_token
                else a1_scale[rank],
            ),
        )

        fused_cutlass_experts = FusedMoEModularKernel(
            prepare_finalize,
            experts,
        )

        a_chunk = chunk_by_rank(a, rank, world_size).to(device)
        chunk_topk_weight = chunk_by_rank(topk_weights, rank, world_size).to(device)
        chunk_topk_ids = (
            chunk_by_rank(topk_ids, rank, world_size).to(torch.uint32).to(device)
        )

        out = fused_cutlass_experts(
            a_chunk,
            chunk_by_rank(w1, rank, world_size),
            chunk_by_rank(w2, rank, world_size),
            chunk_topk_weight,
            chunk_topk_ids,
            global_num_experts=num_experts,
            expert_map=None,  # TODO
        )
    finally:
        # Drain in-flight GPU work before the pplx buffers are torn down.
        torch.cuda.synchronize()
        ata.destroy()

    return out[:rank_num_tokens]


# Module-level config shared by test_cutlass_moe_pplx and the _pplx_moe
# worker; both enter it via set_current_vllm_config(). The scheduler limits
# below are fixed test settings (their rationale is not visible here).
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


def _pplx_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    a1_scale: torch.Tensor,
    out_dtype,
    a_full: torch.Tensor,
    w1_full: torch.Tensor,
    w2_full: torch.Tensor,
    per_act_token: bool,
    per_out_ch: bool,
    use_internode: bool,
):
    """Per-process worker: compare pplx_cutlass_moe against torch_experts.

    Launched by ``parallel_launch``. It sets up either an nvshmem session
    (``use_internode``) or a gloo process group, and runs the pplx/CUTLASS
    path on quantized weights. It also runs the torch reference on the
    dequantized ``*_full`` tensors. Finally, it asserts that this rank's
    slice of both outputs agrees within ``atol=0.05``.

    The nvshmem session, when used, is finalized on every exit path.
    """
    try:
        if use_internode:
            uid = (
                nvshmem_get_unique_id()
                if pgi.rank == 0
                else nvshmem_alloc_empty_unique_id()
            )
            torch.distributed.broadcast(uid, src=0)
            nvshmem_init(uid, pgi.rank, pgi.world_size)
            # pplx_cutlass_moe selects AllToAll.internode when group_name is
            # None. Without this assignment, the call below raised
            # UnboundLocalError on the internode path.
            group_name = None
        else:
            group_ranks = list(range(pgi.world_size))
            cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
            group_name = cpu_group.group_name

        with set_current_vllm_config(vllm_config):
            torch_output = torch_experts(
                a_full, w1_full, w2_full, topk_weights, topk_ids
            )
            pplx_output = pplx_cutlass_moe(
                pgi,
                dp_size,
                a,
                w1,
                w2,
                w1_scale,
                w2_scale,
                topk_weights,
                topk_ids,
                a1_scale,
                out_dtype,
                per_act_token,
                per_out_ch,
                group_name,
            )

            # Only this rank's token slice is produced by the pplx path.
            torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(
                pplx_output.device
            )

        # Uncomment if more debugging is needed
        # print("PPLX OUT:", pplx_output)
        # print("TORCH OUT:", torch_output)

        torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
    finally:
        if use_internode:
            nvshmem_finalize()


@pytest.mark.parametrize("m", [2, 224])
@pytest.mark.parametrize("n", [3072])
@pytest.mark.parametrize("k", [1536])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])  # , [4, 2]])
@pytest.mark.parametrize("use_internode", [False])
@multi_gpu_test(num_gpus=2)
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
@requires_pplx
def test_cutlass_moe_pplx(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    world_dp_size: tuple[int, int],
    use_internode: bool,
):
    """Check pplx + CUTLASS FP8 MoE against a torch reference across ranks.

    Builds random fp16 activations and expert weights, quantizes the weights
    to fp8 per expert (dynamic, per-token when ``per_out_ch``), and keeps
    dequantized fp16 copies for the reference. The pplx path and the
    reference are then run in ``world_size`` worker processes via
    ``parallel_launch``.
    """
    current_platform.seed_everything(7)

    with set_current_vllm_config(vllm_config):
        dtype = torch.half
        fp8 = torch.float8_e4m3fn

        # RNG draw order (a, w1, w2, score, a_scale1) is fixed by the seed.
        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10.0
        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10.0
        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10.0

        # One scale per output row when per_out_ch, otherwise one per expert.
        w1_rows = 2 * n if per_out_ch else 1
        w2_rows = k if per_out_ch else 1

        w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=fp8)
        w2_q = torch.empty((e, k, n), device="cuda", dtype=fp8)
        w1_scale = torch.empty((e, w1_rows, 1), device="cuda", dtype=torch.float32)
        w2_scale = torch.empty((e, w2_rows, 1), device="cuda", dtype=torch.float32)
        w1_d = torch.empty_like(w1)
        w2_d = torch.empty_like(w2)

        for idx in range(e):
            w1_q[idx], w1_scale[idx] = ops.scaled_fp8_quant(
                w1[idx], use_per_token_if_dynamic=per_out_ch
            )
            w2_q[idx], w2_scale[idx] = ops.scaled_fp8_quant(
                w2[idx], use_per_token_if_dynamic=per_out_ch
            )
            # Dequantized fp16 copies feed the torch reference implementation.
            w1_d[idx] = (w1_q[idx].float() * w1_scale[idx]).half()
            w2_d[idx] = (w2_q[idx].float() * w2_scale[idx]).half()

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

        world_size, dp_size = world_dp_size

        a_scale_rows = m if per_act_token else 1
        a_scale1 = (
            torch.randn((a_scale_rows, 1), device="cuda", dtype=torch.float32) / 10.0
        )
        if not per_act_token:
            # One static scale, replicated so each rank can index its own row.
            a_scale1 = a_scale1.repeat(world_size, 1)

        parallel_launch(
            world_size,
            _pplx_moe,
            dp_size,
            a,
            w1_q,
            w2_q,
            w1_scale,
            w2_scale,
            topk_weights,
            topk_ids,
            a_scale1,
            dtype,
            a,
            w1_d,
            w2_d,
            per_act_token,
            per_out_ch,
            use_internode,
        )