# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
from typing import Optional

6
7
8
import pytest
import torch

9
from tests.kernels.utils import torch_experts
10
11
12
13
14
15
16
17
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.platforms import current_platform

bnellnm's avatar
bnellnm committed
18
from .parallel_utils import ProcessGroupInfo, parallel_launch
19

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# pplx_kernels is an optional dependency: record whether it imported so the
# tests that need it can be skipped instead of failing at collection time.
try:
    from pplx_kernels import AllToAll
    from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                      nvshmem_finalize, nvshmem_get_unique_id,
                                      nvshmem_init)
    has_pplx = True
except ImportError:
    has_pplx = False

# Skip marker for tests that require the PPLX kernels to be importable.
requires_pplx = pytest.mark.skipif(
    not has_pplx,
    reason="Requires PPLX kernels",
)

# Parametrization values: total expert counts and experts-per-token (top-k).
NUM_EXPERTS = [40, 64]
TOP_KS = [6, 8]


def rank_chunk(num, r, w):
    """Return how many of ``num`` items rank ``r`` owns out of ``w`` ranks.

    Items are split as evenly as possible; the first ``num % w`` ranks each
    take one extra item.
    """
    base, extra = divmod(num, w)
    if r < extra:
        return base + 1
    return base


def chunk_by_rank(t, r, w):
    """Return rank ``r``'s contiguous slice of ``t`` along dimension 0.

    The slice length comes from ``rank_chunk``: the first ``num % w`` ranks
    own one more row than the remaining ranks.
    """
    total = t.shape[0]
    size = rank_chunk(total, r, w)
    extra = total % w
    if extra != 0 and r >= extra:
        # This rank has a short chunk: skip past the `extra` long chunks,
        # then past the short chunks of the ranks before this one.
        begin = (total // w + 1) * extra + (r - extra) * size
    else:
        # Every chunk up to and including this rank's has the same size.
        begin = r * size
    return t[begin:begin + size].contiguous()


def pplx_cutlass_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    a1_scale: torch.Tensor,
    out_dtype,
    per_act_token: bool,
    per_out_ch: bool,
    group_name: Optional[str],
):
    """Run this rank's share of the MoE through the PPLX + CUTLASS FP8 path.

    Builds a PPLX ``AllToAll`` (internode when ``group_name`` is ``None``,
    intranode otherwise), wraps it and ``CutlassExpertsFp8`` in a
    ``FusedMoEModularKernel``, and feeds it the per-rank slices of the
    activations, routing data, expert weights and scales.

    Args:
        pgi: Process-group info for the calling rank (rank, local rank,
            world size, device).
        dp_size: Data-parallel size forwarded to the all-to-all and to
            ``PplxPrepareAndFinalize``.
        a: Full activation matrix of shape ``(num_tokens, hidden_dim)``.
        w1, w2: Expert weights; dimension 0 indexes the experts.
        w1_scale, w2_scale: Weight scales, sliced per rank like the weights.
        topk_weights, topk_ids: Routing weights / expert ids, one row per
            token; ``topk_ids.shape[1]`` is the top-k.
        a1_scale: Activation scales. Sliced per rank when ``per_act_token``
            is true, otherwise indexed by ``rank``.
        out_dtype: Output dtype forwarded to ``CutlassExpertsFp8``.
        per_act_token: Forwarded to ``CutlassExpertsFp8``; also selects how
            ``a1_scale`` is sliced.
        per_out_ch: Forwarded to ``CutlassExpertsFp8``.
        group_name: Process-group name for the intranode all-to-all, or
            ``None`` to use the internode variant.

    Returns:
        The kernel output truncated to the number of tokens owned by this
        rank.
    """
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize)
    assert torch.cuda.current_device() == pgi.local_rank

    num_tokens, hidden_dim = a.shape
    num_experts = w1.shape[0]
    block_size = hidden_dim  # TODO support more cases
    device = pgi.device
    rank = pgi.rank
    world_size = pgi.world_size
    rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
    # Rank 0 always owns the largest chunk, so its size bounds every rank.
    max_num_tokens = rank_chunk(num_tokens, 0, world_size)
    topk = topk_ids.shape[1]

    if block_size == hidden_dim:
        scale_elems = 4  # hack to circumvent pplx data format requirements
    else:
        scale_elems = (hidden_dim + block_size - 1) // block_size

    args = dict(
        max_num_tokens=max_num_tokens,
        num_experts=num_experts,
        experts_per_token=topk,
        rank=rank,
        world_size=world_size,
        dp_size=dp_size,
        hidden_dim=hidden_dim,
        # One byte per element ("because a.dtype.itemsize == 1" in the
        # original). NOTE(review): presumably refers to the dtype that is
        # actually dispatched rather than the dtype of `a` -- confirm.
        hidden_dim_bytes=hidden_dim,
        hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
    )

    if group_name is None:
        ata = AllToAll.internode(**args)
    else:
        args["group_name"] = group_name
        ata = AllToAll.intranode(**args)

    # From here on the all-to-all must be destroyed even if the kernel raises.
    try:
        w1 = w1.to(device)
        w2 = w2.to(device)
        w1_scale = w1_scale.to(device)
        w2_scale = w2_scale.to(device)
        a1_scale = a1_scale.to(device)

        prepare_finalize = PplxPrepareAndFinalize(
            ata,
            max_num_tokens,
            world_size,
            rank,
            dp_size,
        )

        # Experts per rank, rounded up.
        experts = CutlassExpertsFp8(
            (num_experts + world_size - 1) // world_size,
            out_dtype,
            per_act_token,
            per_out_ch,
            use_batched_format=True)

        fused_cutlass_experts = FusedMoEModularKernel(
            prepare_finalize,
            experts,
        )

        a_chunk = chunk_by_rank(a, rank, world_size).to(device)
        chunk_topk_weight = chunk_by_rank(topk_weights, rank,
                                          world_size).to(device)
        chunk_topk_ids = chunk_by_rank(topk_ids, rank,
                                       world_size).to(torch.uint32).to(device)

        out = fused_cutlass_experts(
            a_chunk,
            chunk_by_rank(w1, rank, world_size),
            chunk_by_rank(w2, rank, world_size),
            chunk_topk_weight,
            chunk_topk_ids,
            global_num_experts=num_experts,
            expert_map=None,  #TODO
            w1_scale=chunk_by_rank(w1_scale, rank, world_size),
            w2_scale=chunk_by_rank(w2_scale, rank, world_size),
            a1_scale=chunk_by_rank(a1_scale, rank, world_size)
            if per_act_token else a1_scale[rank])

        torch.cuda.synchronize()
    finally:
        ata.destroy()

    return out[:rank_num_tokens]


# Module-level vLLM config shared by the test driver and the per-rank worker;
# both install it via set_current_vllm_config().
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


def _pplx_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    a1_scale: torch.Tensor,
    out_dtype,
    a_full: torch.Tensor,
    w1_full: torch.Tensor,
    w2_full: torch.Tensor,
    per_act_token: bool,
    per_out_ch: bool,
    use_internode: bool,
):
    """Per-rank worker: compare the PPLX/CUTLASS MoE against ``torch_experts``.

    Sets up the communication backend (NVSHMEM when ``use_internode`` is
    true, otherwise a gloo process group whose name is handed to the
    intranode all-to-all), runs the reference on ``a_full``/``w1_full``/
    ``w2_full`` and the PPLX path on the quantized inputs, then asserts that
    this rank's slice of the reference matches the PPLX output.

    Raises:
        AssertionError: If the outputs differ by more than ``atol=0.05``.
    """
    # The internode path has no group name; pplx_cutlass_moe selects
    # AllToAll.internode when it receives None. Initialising here also keeps
    # the name bound on both branches.
    group_name: Optional[str] = None
    if use_internode:
        uid = nvshmem_get_unique_id(
        ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
        torch.distributed.broadcast(uid, src=0)
        nvshmem_init(uid, pgi.rank, pgi.world_size)
    else:
        group_ranks = list(range(pgi.world_size))
        cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
        group_name = cpu_group.group_name

    try:
        with set_current_vllm_config(vllm_config):
            torch_output = torch_experts(a_full, w1_full, w2_full,
                                         topk_weights, topk_ids)
            pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
                                           w2_scale, topk_weights, topk_ids,
                                           a1_scale, out_dtype, per_act_token,
                                           per_out_ch, group_name)

            # Keep only the reference rows that belong to this rank.
            torch_output = chunk_by_rank(torch_output, pgi.rank,
                                         pgi.world_size).to(pplx_output.device)

        # Uncomment if more debugging is needed
        # print("PPLX OUT:", pplx_output)
        # print("TORCH OUT:", torch_output)

        torch.testing.assert_close(pplx_output,
                                   torch_output,
                                   atol=0.05,
                                   rtol=0)
    finally:
        # Finalize NVSHMEM even when the comparison (or the kernel) fails.
        if use_internode:
            nvshmem_finalize()
213
214
215
216
217
218
219
220
221
222


@pytest.mark.parametrize("m", [2, 224])
@pytest.mark.parametrize("n", [3072])
@pytest.mark.parametrize("k", [1536])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])  #, [4, 2]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()),
    reason="Grouped gemm is not supported on this GPU type.")
@requires_pplx
def test_cutlass_moe_pplx(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    world_dp_size: tuple[int, int],
    use_internode: bool,
):
    """Launch the PPLX + CUTLASS FP8 MoE comparison across several ranks.

    Creates random half-precision activations and expert weights, quantizes
    the weights to FP8 per expert, and dequantizes them again so the
    reference sees the same rounded values. Top-k routing and activation
    scales are then generated, and ``_pplx_moe`` is started on
    ``world_size`` ranks through ``parallel_launch``.

    Args:
        m, n, k: Token count, intermediate size and hidden size.
        e: Number of experts.
        topk: Experts selected per token.
        per_act_token: Use one activation scale per token instead of one
            per tensor.
        per_out_ch: Use per-output-channel weight scales.
        world_dp_size: ``(world_size, dp_size)`` pair.
        use_internode: Forwarded to ``_pplx_moe`` to choose the backend.
    """
    current_platform.seed_everything(7)

    with set_current_vllm_config(vllm_config):

        dtype = torch.half

        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10.0
        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10.0
        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10.0

        # One scale per output channel, or a single scale per expert.
        n_b_scales = 2 * n if per_out_ch else 1
        k_b_scales = k if per_out_ch else 1

        w1_q = torch.empty((e, 2 * n, k),
                           device="cuda",
                           dtype=torch.float8_e4m3fn)
        w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
        w1_scale = torch.empty((e, n_b_scales, 1),
                               device="cuda",
                               dtype=torch.float32)
        w2_scale = torch.empty((e, k_b_scales, 1),
                               device="cuda",
                               dtype=torch.float32)

        # Dequantized copies of the weights, used by the torch reference.
        w1_d = torch.empty_like(w1)
        w2_d = torch.empty_like(w2)

        for expert in range(e):
            w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
                w1[expert], use_per_token_if_dynamic=per_out_ch)
            w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
                w2[expert], use_per_token_if_dynamic=per_out_ch)
            w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half()
            w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(a,
                                               score,
                                               topk,
                                               renormalize=False)

        world_size, dp_size = world_dp_size
        a_scale1 = torch.randn(
            (m if per_act_token else 1, 1), device="cuda",
            dtype=torch.float32) / 10.0
        if not per_act_token:
            # One row per rank: the worker indexes the scale by its rank.
            a_scale1 = a_scale1.repeat(world_size, 1)

        parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q,
                        w1_scale, w2_scale, topk_weights, topk_ids, a_scale1,
                        dtype, a, w1_d, w2_d, per_act_token, per_out_ch,
                        use_internode)