# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MoE permute/unpermute kernels.

Run `pytest tests/kernels/test_moe_permute_unpermute.py`.
"""

import numpy as np
import pytest
import torch

from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
    moe_permute,
    moe_permute_unpermute_supported,
    moe_unpermute,
)
from vllm.platforms import current_platform

# Parameter grids for test_moe_permute_unpermute.
NUM_EXPERTS = [16, 64, 256]  # global expert counts
TOP_KS = [2, 6, 8]  # experts selected per token
EP_SIZE = [1, 4, 16]  # expert-parallel world sizes

# Seed the RNGs once at import time.
current_platform.seed_everything(0)


27
def torch_permute(
28
29
30
31
32
33
34
    hidden_states: torch.Tensor,
    topk_ids: torch.Tensor,
    #   token_expert_indices: torch.Tensor,
    topk: int,
    n_expert: int,
    n_local_expert: int,
    start_expert: int,
35
36
    expert_map: torch.Tensor | None = None,
    align_block_size: int | None = None,
37
38
    fill_invalid_expert: int = -1,
) -> list[torch.Tensor]:
39
40
    n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
    if expert_map is not None:
41
42
43
44
45
46
47
48
        is_local_expert = expert_map[topk_ids] != -1
        not_local_expert = expert_map[topk_ids] == -1
        topk_ids = is_local_expert * (topk_ids - start_expert) + not_local_expert * (
            topk_ids + n_expert
        )
    token_expert_indices = torch.arange(
        0, n_token * topk, dtype=torch.int32, device=hidden_states.device
    ).reshape((n_token, topk))
49

50
    sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), stable=True)
51
52
    dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices]

53
54
55
    expert_first_token_offset = torch.zeros(
        n_local_expert + 1, dtype=torch.int64, device="cuda"
    )
56
57
58
59
60
61
62
63
64
65
66
    idx = 0
    for i in range(0, n_local_expert):
        cnt = 0
        while idx < sorted_topk_ids.numel() and sorted_topk_ids[idx] == i:
            cnt += 1
            idx += 1
        expert_first_token_offset[i + 1] = expert_first_token_offset[i] + cnt

    _, src2dst_idx = torch.sort(dst_row_id2src_row_id_map)
    valid_row_idx = []
    if align_block_size is None:
67
        permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...]
68
        permuted_row_size = permuted_hidden_states.shape[0]
69
70
71
        m_indices = torch.empty(
            permuted_row_size, device="cuda", dtype=torch.int32
        ).fill_(fill_invalid_expert)
72
73
74
75
76
        for i in range(1, n_local_expert + 1):
            first_token_offset = expert_first_token_offset[i - 1]
            last_token_offset = expert_first_token_offset[i]
            m_indices[first_token_offset:last_token_offset] = i - 1
        src_row_id2dst_row_id_map = torch.arange(
77
78
            0, n_token * topk, device="cuda", dtype=torch.int32
        )[src2dst_idx].reshape((n_token, topk))
79
        valid_row_idx += [i for i in range(expert_first_token_offset[-1])]
80
        dst_row_id2src_row_id_map[expert_first_token_offset[-1] :] = n_token * topk
81
        return [
82
83
84
85
86
87
            permuted_hidden_states,
            expert_first_token_offset,
            src_row_id2dst_row_id_map,
            dst_row_id2src_row_id_map,
            m_indices,
            valid_row_idx,
88
89
        ]
    else:
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
        permuted_row_size = (
            (topk * n_token + n_expert * (align_block_size - 1) + align_block_size - 1)
            // align_block_size
            * align_block_size
        )
        permuted_idx = torch.full(
            (permuted_row_size,),
            n_token * topk,
            dtype=torch.int32,
            device=hidden_states.device,
        )
        permuted_hidden_states = torch.empty(
            (permuted_row_size, n_hidden), device="cuda", dtype=hidden_states.dtype
        )
        align_src_row_id2dst_row_id = torch.empty(
            n_token * topk, device="cuda", dtype=torch.int32
        )
        align_expert_first_token_offset = torch.zeros_like(expert_first_token_offset)
        m_indices = torch.empty(
            permuted_row_size, device="cuda", dtype=torch.int32
        ).fill_(fill_invalid_expert)
111
112
113
114
115
116
        # get align_permuted_hidden_states,
        # valid row_idx and align_expert_first_token_offset
        for i in range(1, n_local_expert + 1):
            first_token_offset = expert_first_token_offset[i - 1]
            last_token_offset = expert_first_token_offset[i]
            n_token_in_expert = last_token_offset - first_token_offset
117
118
119
120
121
122
            align_expert_first_token_offset[i] = (
                align_expert_first_token_offset[i - 1]
                + (n_token_in_expert + align_block_size - 1)
                // align_block_size
                * align_block_size
            )
123
124
125
            align_first_token_offset = align_expert_first_token_offset[i - 1]
            align_last_token_offset = align_expert_first_token_offset[i]
            dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[
126
127
                first_token_offset : first_token_offset + n_token_in_expert
            ]
128
            # store token in current expert with align_first_token_offset
129
130
131
132
133
134
135
            permuted_hidden_states[
                align_first_token_offset : align_first_token_offset + n_token_in_expert,
                ...,
            ] = hidden_states[dst_row_id2src_row_id_in_expert // topk, ...]
            permuted_idx[
                align_first_token_offset : align_first_token_offset + n_token_in_expert
            ] = dst_row_id2src_row_id_in_expert
136
137
138
            # set current expert m_indices
            m_indices[align_first_token_offset:align_last_token_offset] = i - 1
            valid_row_idx += [
139
140
141
142
143
                i
                for i in range(
                    align_first_token_offset,
                    align_first_token_offset + n_token_in_expert,
                )
144
145
146
147
            ]
        # get align_src_row_id2dst_row_id
        for i in range(n_token * topk):
            eid = sorted_topk_ids[i]
148
            if eid >= n_local_expert:
149
                # check token not in local expert
150
                align_src_row_id2dst_row_id[i] = align_expert_first_token_offset[-1]
151
152
153
154
                continue
            first_token_offset = expert_first_token_offset[eid]
            align_first_token_offset = align_expert_first_token_offset[eid]
            token_offset = i - first_token_offset
155
156
157
158
            align_src_row_id2dst_row_id[i] = align_first_token_offset + token_offset
        align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[src2dst_idx].reshape(
            (n_token, topk)
        )
159
        return [
160
161
162
163
164
165
            permuted_hidden_states,
            align_expert_first_token_offset,
            align_src_row_id2dst_row_id,
            permuted_idx,
            m_indices,
            valid_row_idx,
166
167
168
        ]


def torch_unpermute(
    permuted_hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indices: torch.Tensor,
    src_row_id2dst_row_id_map: torch.Tensor,
    valid_row_idx: torch.Tensor,
    topk: int,
    n_expert: int,
) -> torch.Tensor:
    """Pure-PyTorch reference implementation of ``moe_unpermute``.

    Rows of ``permuted_hidden_states`` not listed in ``valid_row_idx`` are
    treated as zero. Each token's ``topk`` rows are gathered through
    ``src_row_id2dst_row_id_map`` and combined as a ``topk_weights``-weighted
    sum.

    ``topk_ids``, ``token_expert_indices`` and ``n_expert`` are unused. They
    are kept so existing call sites keep working.

    The input tensor is NOT modified (the previous version zeroed invalid
    rows in place).

    Returns:
        ``(n_token, n_hidden)`` tensor with the dtype of
        ``permuted_hidden_states``.
    """
    n_hidden = permuted_hidden_states.shape[1]
    device = permuted_hidden_states.device

    # Mask of permuted rows that hold real tokens. as_tensor also accepts an
    # empty Python list or an existing index tensor.
    valid = torch.zeros(permuted_hidden_states.shape[0], dtype=torch.bool, device=device)
    valid[torch.as_tensor(valid_row_idx, dtype=torch.long, device=device)] = True
    masked = permuted_hidden_states.masked_fill(~valid[:, None], 0)

    gathered = masked[src_row_id2dst_row_id_map.flatten().long(), ...]
    gathered = gathered.view(-1, topk, n_hidden)
    # The product may promote (e.g. bf16 * fp32); cast back to the input dtype.
    return (gathered * topk_weights.unsqueeze(2)).sum(1).to(gathered.dtype)


@pytest.mark.parametrize("n_token", [1, 33, 1024, 5000])
@pytest.mark.parametrize("n_hidden", [2048, 7168])
@pytest.mark.parametrize("n_expert", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("align_block_size", [None, 128])
def test_moe_permute_unpermute(
    n_token: int,
    n_hidden: int,
    topk: int,
    n_expert: int,
    ep_size: int,
    dtype: torch.dtype,
    align_block_size: int | None,
):
    """Compare ``moe_permute``/``moe_unpermute`` with the PyTorch references.

    The permutation metadata and the valid permuted rows must match exactly.
    The unpermuted output is compared with ``atol=2e-2``.
    """
    if not moe_permute_unpermute_supported():
        pytest.skip("moe_permute_unpermute is not supported on this platform.")

    fill_invalid_expert = 0
    # A private fixed-seed generator makes the tested EP rank independent of
    # whatever global NumPy RNG state earlier tests left behind.
    ep_rank = int(np.random.default_rng(0).integers(0, ep_size))
    expert_map = None
    n_local_expert = n_expert
    if ep_size != 1:
        n_local_expert, expert_map = determine_expert_map(ep_size, ep_rank, n_expert)
        expert_map = expert_map.cuda()
    start_expert = n_local_expert * ep_rank

    current_platform.seed_everything(0)
    hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype)
    gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
    topk_weights, topk_ids, token_expert_indices = fused_topk(
        hidden_states, gating_output, topk, False
    )

    (
        gold_permuted_hidden_states,
        gold_expert_first_token_offset,
        gold_inv_permuted_idx,
        gold_permuted_idx,
        gold_m_indices,
        valid_row_idx,
    ) = torch_permute(
        hidden_states,
        topk_ids,
        topk,
        n_expert,
        n_local_expert,
        start_expert,
        expert_map=expert_map,
        align_block_size=align_block_size,
        fill_invalid_expert=fill_invalid_expert,
    )

    (
        permuted_hidden_states,
        _,
        expert_first_token_offset,
        inv_permuted_idx,
        m_indices,
    ) = moe_permute(
        hidden_states=hidden_states,
        a1q_scale=None,
        topk_ids=topk_ids,
        n_expert=n_expert,
        n_local_expert=n_local_expert,
        expert_map=expert_map,
        align_block_size=align_block_size,
        fill_invalid_expert=fill_invalid_expert,
    )

    # check expert_first_token_offset
    torch.testing.assert_close(
        gold_expert_first_token_offset, expert_first_token_offset, atol=0, rtol=0
    )
    # check src_row_id2dst_row_id_map
    torch.testing.assert_close(
        gold_inv_permuted_idx.flatten(), inv_permuted_idx, atol=0, rtol=0
    )
    # check m_indices
    # current kernel usage assumes deepgemm requires align_block_size
    # when it's not provided then we don't compute m_indices (for cutlass)
    if align_block_size is not None:
        torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)

    # check permuted_hidden_states, only valid token rows
    torch.testing.assert_close(
        gold_permuted_hidden_states[valid_row_idx],
        permuted_hidden_states[valid_row_idx],
        atol=0,
        rtol=0,
    )

    # add a random tensor to simulate group gemm
    result0 = 0.5 * permuted_hidden_states + torch.randn_like(permuted_hidden_states)
    result4 = torch.empty_like(hidden_states)
    moe_unpermute(
        result4, result0, topk_weights, inv_permuted_idx, expert_first_token_offset
    )

    gold4 = torch_unpermute(
        result0,
        topk_weights,
        topk_ids,
        token_expert_indices,
        inv_permuted_idx,
        valid_row_idx,
        topk,
        n_local_expert,
    )
    # check unpermuted hidden
    torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0)