default.py 13.4 KB
Newer Older
Mercykid-bash's avatar
Mercykid-bash committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Expert parallelism load balancer (EPLB) for vLLM.

This module implements the core rearrangement algorithm.

The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).

See [#12](https://github.com/deepseek-ai/EPLB/issues/12) for an example of
how the EPLB algorithm works.
"""

import numpy as np
import torch

from .abstract import AbstractEplbPolicy


class DefaultEplbPolicy(AbstractEplbPolicy):
    """
    Default EPLB policy, adapted from DeepSeek EPLB.

    All of the rearrangement math runs in NumPy. `rebalance_experts` is the
    torch-facing entry point: it converts its tensor inputs to NumPy and
    returns the resulting physical-to-logical mapping as a torch tensor.
    """

    @classmethod
    def balanced_packing(
        cls, weight: np.ndarray, num_packs: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Pack n weighted objects to m packs, such that each bin contains exactly
        n/m objects and the weights of all packs are as balanced as possible.

        The packing is greedy. For each row, items are visited in descending
        weight order, and each item is placed into the currently lightest pack
        that still has room.

        Parameters:
            weight: [X, n], the weight of each item
            num_packs: number of packs

        Returns:
            pack_index: [X, n], the pack index of each item
            rank_in_pack: [X, n], the rank of the item in the pack
        """
        num_layers, num_groups = weight.shape
        assert num_groups % num_packs == 0
        groups_per_pack = num_groups // num_packs

        if groups_per_pack == 1:
            # One item per pack: item i goes to pack i at rank 0.
            pack_index = np.tile(np.arange(num_groups, dtype=np.int64), (num_layers, 1))
            rank_in_pack = np.zeros_like(pack_index, dtype=np.int64)
            return pack_index, rank_in_pack

        # Sort and get indices in descending order.
        indices = np.argsort(-weight, axis=-1)

        pack_index = np.full((num_layers, num_groups), -1, dtype=np.int64)
        rank_in_pack = np.full((num_layers, num_groups), -1, dtype=np.int64)

        # Running load and item count of every pack, per row.
        pack_weights = np.zeros((num_layers, num_packs), dtype=np.float64)
        pack_items = np.zeros((num_layers, num_packs), dtype=np.int64)

        # Run the packing algorithm.
        for layer_idx in range(num_layers):
            # Row views: in-place writes below update pack_weights/pack_items.
            weights_row = pack_weights[layer_idx]
            items_row = pack_items[layer_idx]

            for group in indices[layer_idx]:
                # Pick the lightest pack; full packs are masked out by inf.
                pack = int(np.argmin(weights_row))

                pack_index[layer_idx, group] = pack
                rank_in_pack[layer_idx, group] = items_row[pack]
                weights_row[pack] += weight[layer_idx, group]
                items_row[pack] += 1
                if items_row[pack] == groups_per_pack:
                    # Mark as unavailable for future selections.
                    weights_row[pack] = np.inf

        return pack_index, rank_in_pack

    @classmethod
    def replicate_experts(
        cls, weight: np.ndarray, num_phy: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Replicate `num_log` experts to `num_phy` replicas, such that the maximum
        load of all replicas is minimized.

        The replication is greedy. Every logical expert starts with one
        replica. Each extra replica goes to the expert with the highest
        per-replica load (`weight / logcnt`) at that moment.

        Parameters:
            weight: [X, num_log]
            num_phy: total number of experts after replication

        Returns:
            phy2log: [X, num_phy], logical expert id of each physical expert
            logcnt: [X, num_log], number of replicas for each logical expert
        """
        n, num_log = weight.shape
        num_redundant = num_phy - num_log
        assert num_redundant >= 0
        # Columns [0, num_log) are the identity mapping; columns
        # [num_log, num_phy) are overwritten below with redundant replicas.
        phy2log = np.tile(np.arange(num_phy, dtype=np.int64), (n, 1))
        logcnt = np.ones((n, num_log), dtype=np.int64)
        arangen = np.arange(n, dtype=np.int64)
        for i in range(num_log, num_phy):
            redundant_indices = np.argmax(weight / logcnt, axis=-1)
            phy2log[:, i] = redundant_indices
            logcnt[arangen, redundant_indices] += 1
        return phy2log, logcnt

    @classmethod
    def rebalance_experts_hierarchical(
        cls,
        weight: np.ndarray,
        num_physical_experts: int,
        num_groups: int,
        num_nodes: int,
        num_gpus: int,
    ) -> np.ndarray:
        """
        Rebalance in three steps:
        1. Pack expert groups onto nodes.
        2. Replicate experts within each node.
        3. Pack the physical experts onto the GPUs of each node.

        Parameters:
            weight: [num_moe_layers, num_logical_experts]
            num_physical_experts: number of physical experts after replication
            num_groups: number of expert groups
            num_nodes: number of server nodes, where the intra-node network
                (e.g, NVLink) is faster
            num_gpus: number of GPUs, must be a multiple of `num_nodes`

        Returns:
            phy2log: [layers, num_physical_experts], the logical expert
                index of each physical expert (replica)
        """
        num_layers, num_logical_experts = weight.shape
        assert num_logical_experts % num_groups == 0
        group_size = num_logical_experts // num_groups
        assert num_groups % num_nodes == 0
        groups_per_node = num_groups // num_nodes
        assert num_gpus % num_nodes == 0
        assert num_physical_experts % num_gpus == 0
        phy_experts_per_gpu = num_physical_experts // num_gpus

        def inverse(perm: np.ndarray) -> np.ndarray:
            # Row-wise inverse of a batch of permutations:
            # inv[r, perm[r, c]] = c for every row r and column c.
            inv = np.empty_like(perm)
            row_idx = np.arange(perm.shape[0])[:, None]
            col_idx = np.arange(perm.shape[1], dtype=np.int64)
            inv[row_idx, perm] = col_idx
            return inv

        # Step 1: pack groups to nodes.
        tokens_per_group = weight.reshape(num_layers, num_groups, group_size).sum(
            axis=-1
        )
        group_pack_index, group_rank_in_pack = cls.balanced_packing(
            tokens_per_group, num_nodes
        )
        # Map each logical expert into a node-local ordering ("mlog") based on
        # the packed groups: node-major, then group rank, then offset in group.
        log2mlog = (
            (
                (group_pack_index * groups_per_node + group_rank_in_pack)[..., None]
                * group_size
            )
            + np.arange(group_size, dtype=np.int64)
        ).reshape(num_layers, num_logical_experts)
        mlog2log = inverse(log2mlog)

        # Step 2: construct redundant experts within nodes.
        # Reorder weights into the node-local layout so replication is done per
        # node: each row is now one (layer, node) pair.
        tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=1).reshape(
            -1, num_logical_experts // num_nodes
        )
        phy2mlog, mlogcnt = cls.replicate_experts(
            tokens_per_mlog, num_physical_experts // num_nodes
        )

        # Step 3: pack physical_experts to GPUs.
        # Effective per-physical load = logical load divided by replica count.
        tokens_per_phy = np.take_along_axis(tokens_per_mlog / mlogcnt, phy2mlog, axis=1)
        pack_index, rank_in_pack = cls.balanced_packing(
            tokens_per_phy, num_gpus // num_nodes
        )
        phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
        pphy2phy = inverse(phy2pphy)

        # Reorder node-local logical indices into the post-packing physical order.
        pphy2mlog = np.take_along_axis(phy2mlog, pphy2phy, axis=1)
        # Shift each node's node-local ids by that node's offset so they index
        # the full per-layer node-local ordering.
        pphy2mlog = (
            pphy2mlog.reshape(num_layers, num_nodes, -1)
            + np.arange(
                0,
                num_logical_experts,
                num_logical_experts // num_nodes,
                dtype=np.int64,
            )[None, :, None]
        ).reshape(num_layers, -1)
        # Map node-local logical indices back to global logical expert ids.
        pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=1)
        return pphy2log

    @classmethod
    def preserve_intragpu_slots(
        cls,
        phy2log: np.ndarray,
        num_ranks: int,
        old_phy2log: np.ndarray,
    ) -> np.ndarray:
        """
        Reorder the new mapping per GPU so that experts that remain on the same
        GPU keep their previous slot positions when possible. Experts that are
        new to that GPU fill the remaining slots.

        The reordering is applied only when the old and new mappings have the
        same shape and the physical experts divide evenly across `num_ranks`.
        The caller must pass the same `num_ranks` that produced `old_phy2log`.
        Otherwise `phy2log` is returned unchanged.

        Parameters:
            phy2log: [layers, num_phy_experts], new logical id per physical slot
            num_ranks: number of GPUs the physical slots are split across
            old_phy2log: previous mapping, expected to have the same shape

        Returns:
            A mapping with the same per-GPU multiset of logical experts as
            `phy2log`, with slots permuted within each GPU.
        """
        num_phy_experts = phy2log.shape[1]
        if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
            return phy2log
        # A differently-shaped old mapping (layers or physical experts changed)
        # cannot be aligned slot-by-slot; indexing it below would raise.
        if old_phy2log.shape != phy2log.shape:
            return phy2log

        slots_per_gpu = num_phy_experts // num_ranks
        num_layers = phy2log.shape[0]

        post_phy2log = phy2log.copy()

        for gpu_idx in range(num_ranks):
            start = gpu_idx * slots_per_gpu
            end = start + slots_per_gpu
            # Experts across all layers for this GPU.
            old_local = old_phy2log[:, start:end]  # [layers, slots]
            new_local = phy2log[:, start:end]  # [layers, slots]

            used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
            preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)

            # First pass: preserve same-logical experts in their previous slots.
            for slot_idx in range(slots_per_gpu):
                # matches: [layers, slots], True where new local experts have
                # the same logical value as the old one at 'slot_idx' and have
                # not been consumed yet.
                matches = (new_local == old_local[:, slot_idx][:, None]) & (
                    ~used_new_indices
                )
                has_any = matches.any(axis=1)
                if np.any(has_any):
                    first_idx = np.argmax(matches, axis=1)
                    layer_indices = np.nonzero(has_any)[0]
                    matched_new_positions = first_idx[layer_indices]
                    post_phy2log[layer_indices, start + slot_idx] = new_local[
                        layer_indices, matched_new_positions
                    ]
                    used_new_indices[layer_indices, matched_new_positions] = True
                    preserved_positions[layer_indices, slot_idx] = True

            # Second pass: fill the non-preserved slots, in ascending slot order,
            # with the not-yet-placed new experts, in their original order.
            for layer_idx in range(num_layers):
                src_pos = np.flatnonzero(~used_new_indices[layer_idx])
                dst_pos = np.flatnonzero(~preserved_positions[layer_idx])
                # Each preserved slot consumes exactly one new expert, so the
                # two counts match; min() keeps this defensive.
                k = min(src_pos.size, dst_pos.size)
                if k == 0:
                    continue
                post_phy2log[layer_idx, start + dst_pos[:k]] = new_local[
                    layer_idx, src_pos[:k]
                ]

        return post_phy2log

    @classmethod
    def rebalance_experts(
        cls,
        weight: torch.Tensor,
        num_replicas: int,
        num_groups: int,
        num_nodes: int,
        num_ranks: int,
        old_global_expert_indices: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Entry point for expert-parallelism load balancer.

        Parameters:
            weight: [layers, num_logical_experts], the load statistics for all
                logical experts
            num_replicas: number of physical experts, must be a multiple of
                `num_ranks`
            num_groups: number of expert groups
            num_nodes: number of server nodes, where the intra-node network
                (e.g, NVLink) is faster
            num_ranks: number of ranks, must be a multiple of `num_nodes`
            old_global_expert_indices: [layers, num_replicas], the previous
                physical-to-logical mapping. Used to avoid unnecessary weight
                copying for experts moving within one rank. Ignored if its
                shape differs from the new mapping.

        Returns:
            phy2log: [layers, num_replicas], the logical expert
                index of each replica (CPU tensor)
        """
        weight_np = weight.float().cpu().numpy()
        old_phy2log_np = (
            old_global_expert_indices.cpu().numpy()
            if old_global_expert_indices is not None
            else None
        )

        if num_groups % num_nodes == 0:
            # Use the hierarchical load-balance policy.
            phy2log_np = cls.rebalance_experts_hierarchical(
                weight_np, num_replicas, num_groups, num_nodes, num_ranks
            )
        else:
            # Use the global load-balance policy: one group on one node.
            phy2log_np = cls.rebalance_experts_hierarchical(
                weight_np, num_replicas, 1, 1, num_ranks
            )

        # Optional postprocessing: keep experts that stay on the same GPU in
        # their previous slots. This avoids unnecessary weight copying.
        # preserve_intragpu_slots itself skips incompatible old mappings.
        if old_phy2log_np is not None:
            phy2log_np = cls.preserve_intragpu_slots(
                phy2log_np, num_ranks, old_phy2log_np
            )

        return torch.from_numpy(phy2log_np)