default.py 16 KB
Newer Older
Mercykid-bash's avatar
Mercykid-bash committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Expert parallelism load balancer (EPLB) for vLLM.

This module implements the core rearrangement algorithm.

The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).

Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
on how the EPLB algorithm works.
"""

import numpy as np
import torch

from .abstract import AbstractEplbPolicy


class DefaultEplbPolicy(AbstractEplbPolicy):
    """Default EPLB policy: greedy balanced packing plus greedy replication.

    All heavy lifting is done on CPU with NumPy; `rebalance_experts` converts
    the torch inputs to NumPy and the results back to torch on the input's
    device.
    """

    @classmethod
    def balanced_packing(
        cls, weight: np.ndarray, num_packs: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Pack n weighted objects into m packs, such that each pack contains
        exactly n/m objects and the weights of all packs are as balanced as
        possible.

        The heuristic is greedy: per row, objects are visited in descending
        weight order and each is placed into the currently lightest pack that
        still has room.

        Parameters:
            weight: [X, n], the weight of each item
            num_packs: number of packs; must divide n

        Returns:
            pack_index: [X, n], the pack index of each item
            rank_in_pack: [X, n], the rank of the item in the pack
        """
        num_layers, num_groups = weight.shape
        assert num_groups % num_packs == 0
        groups_per_pack = num_groups // num_packs

        if groups_per_pack == 1:
            # One item per pack: item i goes to pack i at rank 0.
            pack_index = np.tile(np.arange(num_groups, dtype=np.int64), (num_layers, 1))
            rank_in_pack = np.zeros_like(pack_index, dtype=np.int64)
            return pack_index, rank_in_pack

        # Per-row item order, heaviest first.
        indices = np.argsort(-weight, axis=-1)

        pack_index = np.full((num_layers, num_groups), -1, dtype=np.int64)
        rank_in_pack = np.full((num_layers, num_groups), -1, dtype=np.int64)

        # Running load and item count of every pack, per row.
        pack_weights = np.zeros((num_layers, num_packs), dtype=np.float64)
        pack_items = np.zeros((num_layers, num_packs), dtype=np.int64)

        # Run the packing algorithm
        for layer_idx in range(num_layers):
            weights_row = pack_weights[layer_idx]
            items_row = pack_items[layer_idx]

            for group in indices[layer_idx]:
                # Pick the lightest pack; full packs are masked out by inf.
                pack = int(np.argmin(weights_row))

                pack_index[layer_idx, group] = pack
                rank_in_pack[layer_idx, group] = items_row[pack]
                weights_row[pack] += weight[layer_idx, group]
                items_row[pack] += 1
                if items_row[pack] == groups_per_pack:
                    # Mark as unavailable for future selections.
                    weights_row[pack] = np.inf

        return pack_index, rank_in_pack

    @classmethod
    def replicate_experts(
        cls, weight: np.ndarray, num_phy: int
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Replicate `num_log` experts to `num_phy` replicas, such that the maximum
        load of all replicas is minimized.

        Slots [0, num_log) hold one copy of every logical expert; each extra
        slot is greedily given to the expert with the highest per-replica load
        (weight / current replica count).

        Parameters:
            weight: [X, num_log]
            num_phy: total number of experts after replication

        Returns:
            phy2log: [X, num_phy], logical expert id of each physical expert
            replica_idx: [X, num_phy], for each physical expert, which replica
                (0-based) of its logical expert it is
            logcnt: [X, num_log], number of replicas for each logical expert
        """
        n, num_log = weight.shape
        num_redundant = num_phy - num_log
        assert num_redundant >= 0
        # Columns >= num_log are overwritten in the loop below.
        phy2log = np.tile(np.arange(num_phy, dtype=np.int64), (n, 1))
        replica_idx = np.zeros((n, num_phy), dtype=np.int64)
        logcnt = np.ones((n, num_log), dtype=np.int64)
        arangen = np.arange(n, dtype=np.int64)
        for i in range(num_log, num_phy):
            redundant_indices = np.argmax(weight / logcnt, axis=-1)
            phy2log[:, i] = redundant_indices
            replica_idx[:, i] = logcnt[arangen, redundant_indices]
            logcnt[arangen, redundant_indices] += 1
        return phy2log, replica_idx, logcnt

    @classmethod
    def rebalance_experts_hierarchical(
        cls,
        weight: np.ndarray,
        num_physical_experts: int,
        num_groups: int,
        num_nodes: int,
        num_gpus: int,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Hierarchical rebalancing: pack expert groups onto nodes, replicate
        experts within each node, then pack physical experts onto GPUs.

        Parameters:
            weight: [num_moe_layers, num_logical_experts]
            num_physical_experts: number of physical experts after replication
            num_groups: number of expert groups
            num_nodes: number of server nodes, where the intra-node network
                (e.g, NVLink) is faster
            num_gpus: number of GPUs, must be a multiple of `num_nodes`

        Returns:
            phy2log: [layers, num_physical_experts], the logical expert
                index of each physical expert
            pphy_replicas_idx: [layers, num_physical_experts], for each
                physical expert, which replica of its logical expert it is
            logcnt: [layers, num_logical_experts], number of
                physical replicas for each logical expert
        """
        num_layers, num_logical_experts = weight.shape
        assert num_logical_experts % num_groups == 0
        group_size = num_logical_experts // num_groups
        assert num_groups % num_nodes == 0
        groups_per_node = num_groups // num_nodes
        assert num_gpus % num_nodes == 0
        assert num_physical_experts % num_gpus == 0
        phy_experts_per_gpu = num_physical_experts // num_gpus

        def inverse(perm: np.ndarray) -> np.ndarray:
            # Row-wise inverse of a batch of permutations: inv[r, perm[r, c]] = c.
            inv = np.empty_like(perm)
            row_idx = np.arange(perm.shape[0])[:, None]
            col_idx = np.arange(perm.shape[1], dtype=np.int64)
            inv[row_idx, perm] = col_idx
            return inv

        # Step 1: pack groups to nodes
        tokens_per_group = weight.reshape(num_layers, num_groups, group_size).sum(
            axis=-1
        )
        group_pack_index, group_rank_in_pack = cls.balanced_packing(
            tokens_per_group, num_nodes
        )
        # Map each logical expert into a node-local ordering based on packed groups.
        log2mlog = (
            (
                (group_pack_index * groups_per_node + group_rank_in_pack)[..., None]
                * group_size
            )
            + np.arange(group_size, dtype=np.int64)
        ).reshape(num_layers, num_logical_experts)
        mlog2log = inverse(log2mlog)

        # Step 2: construct redundant experts within nodes
        # Reorder weights into the node-local layout so replication is done per
        # node; rows become (layer, node) pairs.
        tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=1).reshape(
            -1, num_logical_experts // num_nodes
        )
        phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts(
            tokens_per_mlog, num_physical_experts // num_nodes
        )

        # Step 3: pack physical_experts to GPUs
        # Effective per-physical load = logical load divided by replica count.
        tokens_per_phy = np.take_along_axis(tokens_per_mlog / mlogcnt, phy2mlog, axis=1)
        pack_index, rank_in_pack = cls.balanced_packing(
            tokens_per_phy, num_gpus // num_nodes
        )
        phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
        pphy2phy = inverse(phy2pphy)

        # Reorder node-local logical indices into the post-packing physical order.
        pphy2mlog = np.take_along_axis(phy2mlog, pphy2phy, axis=1)
        # Shift each node's node-local ids by that node's offset in mlog space.
        pphy2mlog = (
            pphy2mlog.reshape(num_layers, num_nodes, -1)
            + np.arange(
                0,
                num_logical_experts,
                num_logical_experts // num_nodes,
                dtype=np.int64,
            )[None, :, None]
        ).reshape(num_layers, -1)
        # Map node-local logical indices back to global logical expert ids.
        pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=1)
        # Reorder replica ranks to the post-packing physical ordering.
        pphy_replicas_idx = np.take_along_axis(replicas_idx, pphy2phy, axis=1).reshape(
            num_layers, -1
        )
        # Convert replica counts back to the original logical ordering.
        logcnt = np.take_along_axis(mlogcnt.reshape(num_layers, -1), log2mlog, axis=1)
        return pphy2log, pphy_replicas_idx, logcnt

    @classmethod
    def preserve_intragpu_slots(
        cls,
        phy2log: np.ndarray,
        phy_replicas_idx: np.ndarray,
        num_ranks: int,
        old_phy2log: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Reorder the new mapping per GPU so that experts that remain on the same GPU
        keep their previous slot positions when possible. Incoming experts to that GPU
        fill any remaining available slots.

        The reordering is only applied when `num_ranks` evenly divides the
        number of physical slots and `old_phy2log` has exactly the same shape
        (layers x physical slots) as `phy2log`; otherwise the new mapping is
        returned unchanged. It is assumed the old mapping used the same
        `num_ranks` — not checkable here; confirm at the call site.

        Parameters:
            phy2log: [layers, num_phy], new logical expert id per physical slot
            phy_replicas_idx: [layers, num_phy], replica index per physical slot
            num_ranks: number of GPUs the slots are split across
            old_phy2log: previous logical expert id per physical slot

        Returns:
            (phy2log, phy_replicas_idx) reordered within each GPU's slot range.
        """
        num_phy_experts = phy2log.shape[1]
        if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
            return phy2log, phy_replicas_idx
        if old_phy2log.shape != phy2log.shape:
            # The slot layout changed (e.g. a different number of physical
            # experts or layers); per-GPU slot positions are not comparable.
            return phy2log, phy_replicas_idx

        slots_per_gpu = num_phy_experts // num_ranks
        num_layers = phy2log.shape[0]

        post_phy2log = phy2log.copy()
        post_phy_replicas_idx = phy_replicas_idx.copy()

        for gpu_idx in range(num_ranks):
            start = gpu_idx * slots_per_gpu
            end = start + slots_per_gpu
            # Experts across all layers for this GPU
            old_local = old_phy2log[:, start:end]  # [layers, slots]
            new_local = phy2log[:, start:end]  # [layers, slots]
            new_ridx = phy_replicas_idx[:, start:end]  # [layers, slots]

            used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
            preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)

            # First pass: preserve same-logical experts in their previous slots
            for slot_idx in range(slots_per_gpu):
                # matches: [layers, slots], True where new local experts have
                # the same logical value as the old from 'slot_idx' and not checked yet
                matches = (new_local == old_local[:, slot_idx][:, None]) & (
                    ~used_new_indices
                )
                has_any = matches.any(axis=1)
                if np.any(has_any):
                    first_idx = np.argmax(matches, axis=1)
                    layer_indices = np.nonzero(has_any)[0]
                    matched_new_positions = first_idx[layer_indices]
                    post_phy2log[layer_indices, start + slot_idx] = new_local[
                        layer_indices, matched_new_positions
                    ]
                    post_phy_replicas_idx[layer_indices, start + slot_idx] = new_ridx[
                        layer_indices, matched_new_positions
                    ]
                    used_new_indices[layer_indices, matched_new_positions] = True
                    preserved_positions[layer_indices, slot_idx] = True

            # Second pass: fill remaining slots with remaining new experts
            remaining_mask = ~used_new_indices  # [layers, slots]
            fill_mask = ~preserved_positions  # [layers, slots]
            if remaining_mask.any() and fill_mask.any():
                idx_base = np.tile(np.arange(slots_per_gpu), (num_layers, 1))
                # Sentinel value for unavailable positions.
                large = slots_per_gpu + 1
                # Priorities: keep original index for available spots, set sentinel
                # for unavailable; lower is earlier.
                remaining_priority = np.where(remaining_mask, idx_base, large)
                fill_priority = np.where(fill_mask, idx_base, large)
                # Sort to get ordered indices of available src/dst positions per layer.
                remaining_indices = np.argsort(remaining_priority, axis=1)
                fill_indices = np.argsort(fill_priority, axis=1)
                # Fill count per layer (cannot exceed either side).
                remaining_counts = remaining_mask.sum(axis=1)
                fill_counts = fill_mask.sum(axis=1)
                take_counts = np.minimum(remaining_counts, fill_counts)
                # Assign remaining new experts to remaining slots per layer.
                for layer_idx in range(num_layers):
                    k = int(take_counts[layer_idx])
                    if k <= 0:
                        continue
                    src_pos = remaining_indices[layer_idx, :k]
                    dst_pos = fill_indices[layer_idx, :k]
                    post_phy2log[layer_idx, start + dst_pos] = new_local[
                        layer_idx, src_pos
                    ]
                    post_phy_replicas_idx[layer_idx, start + dst_pos] = new_ridx[
                        layer_idx, src_pos
                    ]

        return post_phy2log, post_phy_replicas_idx

    @classmethod
    def rebalance_experts(
        cls,
        weight: torch.Tensor,
        num_replicas: int,
        num_groups: int,
        num_nodes: int,
        num_ranks: int,
        old_global_expert_indices: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Entry point for expert-parallelism load balancer.

        Parameters:
            weight: [layers, num_logical_experts], the load statistics for all
                logical experts
            num_replicas: number of physical experts, must be a multiple of
                `num_ranks`
            num_groups: number of expert groups
            num_nodes: number of server nodes, where the intra-node network
                (e.g, NVLink) is faster
            num_ranks: number of ranks, must be a multiple of `num_nodes`
            old_global_expert_indices: the previous physical-to-logical
                mapping, compared slot-by-slot against the new phy2log. It is
                used to avoid unnecessary weight copying for experts moving
                within one rank, and is ignored if its shape differs from the
                new mapping's [layers, num_replicas].

        Returns:
            phy2log: [layers, num_replicas], the expert
                index of each replica
            log2phy: [layers, num_logical_experts, num_replicas -
                num_logical_experts + 1], the physical indices of each
                logical expert's replicas, padded with -1
            logcnt: [layers, num_logical_experts], number of
                physical replicas for each logical expert
        """
        device = weight.device
        num_layers, num_logical_experts = weight.shape
        weight_np = weight.float().cpu().numpy()
        old_phy2log_np = (
            old_global_expert_indices.cpu().numpy()
            if old_global_expert_indices is not None
            else None
        )

        if num_groups % num_nodes == 0:
            # use hierarchical load-balance policy
            phy2log_np, phy_replicas_idx_np, logcnt_np = (
                cls.rebalance_experts_hierarchical(
                    weight_np, num_replicas, num_groups, num_nodes, num_ranks
                )
            )
        else:
            # use global load-balance policy
            phy2log_np, phy_replicas_idx_np, logcnt_np = (
                cls.rebalance_experts_hierarchical(
                    weight_np, num_replicas, 1, 1, num_ranks
                )
            )

        # Optional postprocessing to preserve slots for experts moving
        # within the same GPU. preserve_intragpu_slots itself skips the step
        # when the slot layout is incompatible. Helps to avoid unnecessary
        # weight copying when experts move within the same GPU.
        if old_global_expert_indices is not None:
            phy2log_np, phy_replicas_idx_np = cls.preserve_intragpu_slots(
                phy2log_np, phy_replicas_idx_np, num_ranks, old_phy2log_np
            )

        # A logical expert can hold at most every redundant slot plus its own.
        num_redundant_experts = num_replicas - num_logical_experts
        maxlogcnt = num_redundant_experts + 1
        log2phy_np = np.full(
            (num_layers, num_logical_experts, maxlogcnt), -1, dtype=np.int64
        )
        # Scatter: log2phy[layer, phy2log[layer, p], replica_idx[layer, p]] = p.
        layer_indices = np.arange(num_layers)[:, None]
        replica_indices = np.tile(
            np.arange(num_replicas, dtype=np.int64), (num_layers, 1)
        )
        log2phy_np[layer_indices, phy2log_np, phy_replicas_idx_np] = replica_indices

        phy2log = torch.from_numpy(phy2log_np).to(device)
        log2phy = torch.from_numpy(log2phy_np).to(device)
        logcnt = torch.from_numpy(logcnt_np).to(device)
        return phy2log, log2phy, logcnt