default.py 15.9 KB
Newer Older
Mercykid-bash's avatar
Mercykid-bash committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Expert parallelism load balancer (EPLB) for vLLM.

This module implements the core rearrangement algorithm.

The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).

Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
on how the EPLB algorithm works.
"""

import numpy as np
import torch

from .abstract import AbstractEplbPolicy


class DefaultEplbPolicy(AbstractEplbPolicy):
    @classmethod
    def balanced_packing(
        cls, weight: torch.Tensor, num_packs: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Pack n weighted objects to m packs, such that each bin contains exactly
        n/m objects and the weights of all packs are as balanced as possible.

        Parameters:
            weight: [X, n], the weight of each item
            num_packs: number of packs

        Returns:
            pack_index: [X, n], the pack index of each item
            rank_in_pack: [X, n], the rank of the item in the pack
        """
        num_layers, num_groups = weight.shape
        assert num_groups % num_packs == 0
        groups_per_pack = num_groups // num_packs

        device = weight.device

        if groups_per_pack == 1:
            pack_index = torch.arange(
                weight.size(-1), dtype=torch.int64, device=device
            ).expand(weight.shape)
            rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
            return pack_index, rank_in_pack

        weight_np = weight.cpu().numpy()

        # Sort and get indices in decending order
        indices_np = np.argsort(-weight_np, axis=-1)

        pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
        rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)

        # Run the packing algorithm
        for i in range(num_layers):
            pack_weights = [0.0] * num_packs
            pack_items = [0] * num_packs

            for group in indices_np[i]:
                # Find a pack with capacity that has the lowest weight
                pack = min(
                    (j for j in range(num_packs) if pack_items[j] < groups_per_pack),
                    key=pack_weights.__getitem__,
                )

                assert pack_items[pack] < groups_per_pack
                pack_index_np[i, group] = pack
                rank_in_pack_np[i, group] = pack_items[pack]
                pack_weights[pack] += weight_np[i, group]
                pack_items[pack] += 1

        pack_index = torch.from_numpy(pack_index_np).to(device)
        rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)

        return pack_index, rank_in_pack

    @classmethod
    def replicate_experts(
        cls, weight: torch.Tensor, num_phy: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Replicate `num_log` experts to `num_phy` replicas, such that the maximum
        load of all replicas is minimized.

        Parameters:
            weight: [X, num_log]
            num_phy: total number of experts after replication

        Returns:
            phy2log: [X, num_phy], logical expert id of each physical expert
96
            replica_idx: [X, num_phy], the index of the replica for each logical expert
Mercykid-bash's avatar
Mercykid-bash committed
97
98
99
100
101
102
103
            logcnt: [X, num_log], number of replicas for each logical expert
        """
        n, num_log = weight.shape
        num_redundant = num_phy - num_log
        assert num_redundant >= 0
        device = weight.device
        phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
104
        replica_idx = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
Mercykid-bash's avatar
Mercykid-bash committed
105
106
107
108
109
        logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
        arangen = torch.arange(n, dtype=torch.int64, device=device)
        for i in range(num_log, num_phy):
            redundant_indices = (weight / logcnt).max(dim=-1).indices
            phy2log[:, i] = redundant_indices
110
            replica_idx[:, i] = logcnt[arangen, redundant_indices]
Mercykid-bash's avatar
Mercykid-bash committed
111
            logcnt[arangen, redundant_indices] += 1
112
        return phy2log, replica_idx, logcnt
Mercykid-bash's avatar
Mercykid-bash committed
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134

    @classmethod
    def rebalance_experts_hierarchical(
        cls,
        weight: torch.Tensor,
        num_physical_experts: int,
        num_groups: int,
        num_nodes: int,
        num_gpus: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Parameters:
            weight: [num_moe_layers, num_logical_experts]
            num_physical_experts: number of physical experts after replication
            num_groups: number of expert groups
            num_nodes: number of server nodes, where the intra-node network
                (e.g, NVLink) is faster
            num_gpus: number of GPUs, must be a multiple of `num_nodes`

        Returns:
            phy2log: [layers, num_replicas], the expert
                index of each replica
135
            pphy_replicas_idx: [layers, num_logical_experts, X],
Mercykid-bash's avatar
Mercykid-bash committed
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
                the replica indices for each expert
            logcnt: [layers, num_logical_experts], number of
                physical replicas for each logical expert
        """
        num_layers, num_logical_experts = weight.shape
        assert num_logical_experts % num_groups == 0
        group_size = num_logical_experts // num_groups
        assert num_groups % num_nodes == 0
        groups_per_node = num_groups // num_nodes
        assert num_gpus % num_nodes == 0
        assert num_physical_experts % num_gpus == 0
        phy_experts_per_gpu = num_physical_experts // num_gpus

        def inverse(perm: torch.Tensor) -> torch.Tensor:
            inv = torch.empty_like(perm)
            inv.scatter_(
                1,
                perm,
                torch.arange(
                    perm.size(1), dtype=torch.int64, device=perm.device
                ).expand(perm.shape),
            )
            return inv

        # Step 1: pack groups to nodes
        tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
        group_pack_index, group_rank_in_pack = cls.balanced_packing(
            tokens_per_group, num_nodes
        )
        log2mlog = (
            (
                (group_pack_index * groups_per_node + group_rank_in_pack) * group_size
            ).unsqueeze(-1)
            + torch.arange(
                group_size, dtype=torch.int64, device=group_pack_index.device
            )
        ).flatten(-2)
        mlog2log = inverse(log2mlog)

        # Step 2: construct redundant experts within nodes
        # [num_layers * num_nodes, num_logical_experts // num_nodes]
        tokens_per_mlog = weight.gather(-1, mlog2log).view(
            -1, num_logical_experts // num_nodes
        )
180
        phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts(
Mercykid-bash's avatar
Mercykid-bash committed
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
            tokens_per_mlog, num_physical_experts // num_nodes
        )

        # Step 3: pack physical_experts to GPUs
        # [num_layers * num_nodes, num_physical_experts // num_nodes]
        tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
        pack_index, rank_in_pack = cls.balanced_packing(
            tokens_per_phy, num_gpus // num_nodes
        )
        phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
        pphy2phy = inverse(phy2pphy)

        pphy2mlog = phy2mlog.gather(
            -1, pphy2phy
        )  # [num_layers * num_nodes, num_log_per_nodes]
        pphy2mlog = (
            pphy2mlog.view(num_layers, num_nodes, -1)
            + torch.arange(
                0,
                num_logical_experts,
                num_logical_experts // num_nodes,
                device=group_pack_index.device,
            ).view(1, -1, 1)
        ).flatten(-2)
        pphy2log = mlog2log.gather(-1, pphy2mlog)
206
        pphy_replicas_idx = replicas_idx.gather(-1, pphy2phy).view(num_layers, -1)
Mercykid-bash's avatar
Mercykid-bash committed
207
        logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
        return pphy2log, pphy_replicas_idx, logcnt

    @classmethod
    def preserve_intragpu_slots(
        cls,
        phy2log: torch.Tensor,
        phy_replicas_idx: torch.Tensor,
        num_ranks: int,
        old_global_expert_indices: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Reorder the new mapping per GPU so that experts that remain on the same GPU
        keep their previous slot positions when possible. Incoming experts to that GPU
        fill any remaining available slots. This is applied only when the number of GPUs
        is unchanged and the slots per GPU remain the same between
        the old and new mappings.
        """
        device = phy2log.device
        num_phy_experts = phy2log.shape[1]
        if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
            return phy2log, phy_replicas_idx

        # Move to CPU and convert to NumPy for processing
        new_phy2log_np = phy2log.cpu().numpy()
        replicas_idx_np = phy_replicas_idx.cpu().numpy()
        old_phy2log_np = old_global_expert_indices.cpu().numpy()

        slots_per_gpu = num_phy_experts // num_ranks
        num_layers = new_phy2log_np.shape[0]

        post_phy2log_np = new_phy2log_np.copy()
        post_phy_replicas_idx_np = replicas_idx_np.copy()

        for gpu_idx in range(num_ranks):
            start = gpu_idx * slots_per_gpu
            end = start + slots_per_gpu
            # Experts across all layers for this GPU
            old_local = old_phy2log_np[:, start:end]  # [layers, slots]
            new_local = new_phy2log_np[:, start:end]  # [layers, slots]
            new_ridx = replicas_idx_np[:, start:end]  # [layers, slots]

            used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
            preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)

            # First pass: preserve same-logical experts in their previous slots
            for slot_idx in range(slots_per_gpu):
                # matches: [layers, slots], True where new local experts have
                # the same logical value as the old from 'slot_idx' and not checked yet
                matches = (new_local == old_local[:, slot_idx][:, None]) & (
                    ~used_new_indices
                )
                has_any = matches.any(axis=1)
                if np.any(has_any):
                    first_idx = np.argmax(matches, axis=1)
                    layer_indices = np.nonzero(has_any)[0]
                    matched_new_positions = first_idx[layer_indices]
                    post_phy2log_np[layer_indices, start + slot_idx] = new_local[
                        layer_indices, matched_new_positions
                    ]
                    post_phy_replicas_idx_np[layer_indices, start + slot_idx] = (
                        new_ridx[layer_indices, matched_new_positions]
                    )
                    used_new_indices[layer_indices, matched_new_positions] = True
                    preserved_positions[layer_indices, slot_idx] = True

            # Second pass: fill remaining slots with remaining new experts
            remaining_mask = ~used_new_indices  # [layers, slots]
            fill_mask = ~preserved_positions  # [layers, slots]
            if remaining_mask.any() and fill_mask.any():
                idx_base = np.tile(np.arange(slots_per_gpu), (num_layers, 1))
                # Sentinel value for unavailable positions.
                large = slots_per_gpu + 1
                # Priorities: keep original index for available spots, set sentinel
                # for unavailable; lower is earlier.
                remaining_priority = np.where(remaining_mask, idx_base, large)
                fill_priority = np.where(fill_mask, idx_base, large)
                # Sort to get ordered indices of available src/dst positions per layer.
                remaining_indices = np.argsort(remaining_priority, axis=1)
                fill_indices = np.argsort(fill_priority, axis=1)
                # Fill count per layer (cannot exceed either side).
                remaining_counts = remaining_mask.sum(axis=1)
                fill_counts = fill_mask.sum(axis=1)
                take_counts = np.minimum(remaining_counts, fill_counts)
                # Assign remaining new experts to remaining slots per layer.
                for layer_idx in range(num_layers):
                    k = int(take_counts[layer_idx])
                    if k <= 0:
                        continue
                    src_pos = remaining_indices[layer_idx, :k]
                    dst_pos = fill_indices[layer_idx, :k]
                    post_phy2log_np[layer_idx, start + dst_pos] = new_local[
                        layer_idx, src_pos
                    ]
                    post_phy_replicas_idx_np[layer_idx, start + dst_pos] = new_ridx[
                        layer_idx, src_pos
                    ]

        # Convert back to torch and move to original device
        post_phy2log = torch.from_numpy(post_phy2log_np).to(device)
        post_phy_replicas_idx = torch.from_numpy(post_phy_replicas_idx_np).to(device)
        return post_phy2log, post_phy_replicas_idx
Mercykid-bash's avatar
Mercykid-bash committed
309
310
311
312
313
314
315
316
317

    @classmethod
    def rebalance_experts(
        cls,
        weight: torch.Tensor,
        num_replicas: int,
        num_groups: int,
        num_nodes: int,
        num_ranks: int,
318
        old_global_expert_indices: torch.Tensor | None = None,
Mercykid-bash's avatar
Mercykid-bash committed
319
320
321
322
323
324
325
326
327
328
329
330
331
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Entry point for expert-parallelism load balancer.

        Parameters:
            weight: [layers, num_logical_experts], the load statistics for all
                logical experts
            num_replicas: number of physical experts, must be a multiple of
                `num_gpus`
            num_groups: number of expert groups
            num_nodes: number of server nodes, where the intra-node network
                (e.g, NVLink) is faster
            num_ranks: number of ranks, must be a multiple of `num_nodes`
332
333
334
            old_global_expert_indices: [layers, num_logical_experts], the old global
                expert indices. Used to avoid unnecessary weight copying
                for experts moving within one rank.
Mercykid-bash's avatar
Mercykid-bash committed
335
336
337
338
339
340
341
342
343
344
345
346
        Returns:
            phy2log: [layers, num_replicas], the expert
                index of each replica
            log2phy: [layers, num_logical_experts, X],
                the replica indices for each expert
            logcnt: [layers, num_logical_experts], number of
                physical replicas for each logical expert
        """
        num_layers, num_logical_experts = weight.shape
        weight = weight.float()
        if num_groups % num_nodes == 0:
            # use hierarchical load-balance policy
347
            phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical(
Mercykid-bash's avatar
Mercykid-bash committed
348
349
350
351
                weight, num_replicas, num_groups, num_nodes, num_ranks
            )
        else:
            # use global load-balance policy
352
            phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical(
Mercykid-bash's avatar
Mercykid-bash committed
353
354
                weight, num_replicas, 1, 1, num_ranks
            )
355
356
357
358
359
360
361
362
363
        # Optional postprocessing to preserve slots for experts moving
        # within the same GPU
        # Only apply when the number of GPUs and slots per GPU remain unchanged.
        # Helps to avoid unnecessary weight copying when experts move
        # within the same GPU.
        if old_global_expert_indices is not None:
            phy2log, phy_replicas_idx = cls.preserve_intragpu_slots(
                phy2log, phy_replicas_idx, num_ranks, old_global_expert_indices
            )
Mercykid-bash's avatar
Mercykid-bash committed
364
365
366
367
368
369
370
371
372
373
        num_redundant_experts = num_replicas - num_logical_experts
        maxlogcnt = num_redundant_experts + 1
        log2phy: torch.Tensor = torch.full(
            (num_layers, num_logical_experts, maxlogcnt),
            -1,
            dtype=torch.int64,
            device=logcnt.device,
        )
        log2phy.view(num_layers, -1).scatter_(
            -1,
374
            phy2log * maxlogcnt + phy_replicas_idx,
Mercykid-bash's avatar
Mercykid-bash committed
375
376
377
378
379
            torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
                num_layers, -1
            ),
        )
        return phy2log, log2phy, logcnt