rebalance_execute.py 27.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
"""
The actual execution of the rearrangement.

This involves the exchange of expert weights between GPUs.
"""

9
from collections.abc import Sequence
10
from dataclasses import dataclass
11

12
import numpy as np
13
import torch
14
from torch.distributed import ProcessGroup, all_gather
15

16
17
18
19
20
from vllm.distributed.eplb.eplb_communicator import EplbCommunicator
from vllm.distributed.eplb.eplb_utils import CpuGpuEvent
from vllm.logger import init_logger

logger = init_logger(__name__)
21
22


23
@dataclass
class TransferMetadata:
    """Metadata describing a completed EPLB buffer transfer.

    Produced by ``move_to_buffer`` and consumed by ``move_from_buffer``. The
    masks are indexed by local expert row on this rank.
    """

    is_unchanged: np.ndarray
    """Mask of (num_local_experts,) indicating experts unchanged after rebalance."""
    is_received_locally: np.ndarray
    """Mask of (num_local_experts,) indicating experts received from local data."""
    recv_primary_mask: np.ndarray
    """
    Mask of (num_local_experts,) marking the primary rows that receive an
    expert from a remote rank.
    """
    recv_count: int
    """
    Number of distinct experts received from remote ranks for the layer, i.e.
    the length of the valid prefix of recv_expert_ids / recv_dst_rows.
    """
    recv_expert_ids: np.ndarray
    """
    Expert ids (num_local_experts,) of remote primary experts. Only the first
    recv_count entries are valid; the remainder are -1.
    """
    recv_dst_rows: np.ndarray
    """
    Local rows (num_local_experts,) into which each entry of recv_expert_ids is
    received (its primary destination). Only the first recv_count entries are
    valid; the remainder are -1.
    """
39
40


41
42
43
44
45
46
47
48
49
50
51
52
53
@dataclass
class AsyncEplbLayerResult:
    """
    The result of one completed async EPLB layer transfer.
    """

    layer_idx: int
    """Index of the MoE layer that was transferred."""
    new_physical_to_logical_map: torch.Tensor
    """
    New physical→logical mapping for layer_idx, on CPU.
    Shape: (num_physical_experts,)
    """
    transfer_metadata: TransferMetadata
    """Metadata describing what was received during transfer_layer."""
    consumed_event: CpuGpuEvent
    """
    Event used to synchronize access to the intermediate buffer. The async worker calls
    wait() after it finishes transferring weights to the intermediate buffer. The main
    thread calls record() after it finishes transferring weights out of the intermediate
    buffer in _move_to_workspace().
    """


65
66
def _unique_ranks_per_expert(
    indices: np.ndarray,
    query_experts: np.ndarray,
    num_local_experts: int,
) -> dict[int, np.ndarray]:
    """Map each queried expert found in ``indices`` to the ranks holding it.

    A slot at position ``p`` belongs to rank ``p // num_local_experts``. For
    every expert of ``query_experts`` that occurs in ``indices``, the result
    holds the distinct ranks containing it, in order of first appearance.
    Experts that do not occur in ``indices`` are omitted.
    """
    relevant_mask = np.isin(indices, query_experts)
    if not np.any(relevant_mask):
        return {}

    positions = np.nonzero(relevant_mask)[0]
    experts = indices[relevant_mask]
    ranks = positions // num_local_experts

    # lexsort uses its last key as the primary key: group by expert, and keep
    # ascending position order inside each group.
    order = np.lexsort((positions, experts))
    experts = experts[order]
    ranks = ranks[order]

    # Group boundaries are where the (sorted) expert id changes.
    boundaries = np.concatenate(
        [[0], np.flatnonzero(np.diff(experts)) + 1, [len(experts)]]
    )

    grouped: dict[int, np.ndarray] = {}
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        group_ranks = ranks[start:end]
        # Unique ranks, preserving first-appearance order.
        _, first_idx = np.unique(group_ranks, return_index=True)
        grouped[int(experts[start])] = group_ranks[np.sort(first_idx)]
    return grouped


def get_ep_ranks_with_experts_batch(
    expert_ids: np.ndarray,
    num_local_experts: int,
    old_indices: np.ndarray,
    new_indices: np.ndarray,
) -> tuple[dict[int, list[int]], dict[int, list[int]]]:
    """
    Get the ranks of the experts that need to be exchanged.

    Args:
        expert_ids: 1D array of expert indices to query.
        num_local_experts: The number of local experts.
        old_indices: The old indices of the experts.
        new_indices: The new indices of the experts.

    Returns:
        A tuple of two dictionaries mapping expert_id to:
        - ranks_to_send: The ranks that have this expert and need to send.
        - ranks_to_recv: The ranks that need to receive this expert. Ranks
          that already hold the expert in ``old_indices`` are excluded.
        Every queried expert has an entry in both dicts (possibly empty).
    """
    ranks_to_send_map: dict[int, list[int]] = {}
    ranks_to_recv_map: dict[int, list[int]] = {}

    # Fast path: if no experts, return empty dicts
    if expert_ids.size == 0:
        return ranks_to_send_map, ranks_to_recv_map

    unique_experts = np.unique(expert_ids)

    old_ranks = _unique_ranks_per_expert(
        old_indices, unique_experts, num_local_experts
    )
    new_ranks = _unique_ranks_per_expert(
        new_indices, unique_experts, num_local_experts
    )

    # Send ranks: every rank holding the expert under the old mapping.
    for expert, ranks in old_ranks.items():
        ranks_to_send_map[expert] = ranks.tolist()

    # Recv ranks: ranks holding the expert under the new mapping that do not
    # already have a local copy under the old mapping.
    for expert, ranks in new_ranks.items():
        senders = set(ranks_to_send_map.get(expert, []))
        ranks_to_recv_map[expert] = [
            int(r) for r in ranks.tolist() if r not in senders
        ]

    # Experts that only appear in old (send only) or new (recv only), or in
    # neither, still get an (empty) entry in both maps.
    for expert in unique_experts.tolist():
        ranks_to_send_map.setdefault(expert, [])
        ranks_to_recv_map.setdefault(expert, [])

    return ranks_to_send_map, ranks_to_recv_map
170
171


172
def move_to_buffer(
    num_local_experts: int,
    old_indices: np.ndarray,
    new_indices: np.ndarray,
    expert_weights: Sequence[torch.Tensor],
    expert_weights_buffers: Sequence[torch.Tensor],
    cuda_stream: torch.cuda.Stream | None,
    ep_rank: int,
    communicator: EplbCommunicator,
) -> TransferMetadata:
    """
    Stage this rank's expert weights for one layer of an EPLB rebalance.

    Rows whose new expert is already held on this rank are copied into
    ``expert_weights_buffers`` from local data. For experts that must move
    between ranks, P2P sends (reading ``expert_weights``) and receives
    (writing ``expert_weights_buffers``) are posted on ``communicator`` and then
    executed. ``expert_weights`` is only read here; ``move_from_buffer`` copies
    the staged rows back into it.

    Args:
        num_local_experts: Number of physical expert slots per EP rank.
        old_indices: (num_physical_experts,) physical-slot -> logical-expert
            mapping before the rebalance; -1 marks an empty slot.
        new_indices: (num_physical_experts,) physical-slot -> logical-expert
            mapping after the rebalance; -1 marks an empty slot.
        expert_weights: Expert weight tensors for the layer, each indexed by
            local expert row along dim 0.
        expert_weights_buffers: Intermediate buffers (one per weight tensor),
            indexed by local expert row like ``expert_weights``.
        cuda_stream: CUDA stream for the local buffer copies (None keeps the
            current stream).
        ep_rank: Rank of this process in expert parallel group.
        communicator: EplbCommunicator instance for P2P communication.

    Returns:
        TransferMetadata: Metadata needed for completing remote weight transfers.
    """
    assert old_indices.shape == new_indices.shape
    recv_primary_mask = np.zeros((num_local_experts,), dtype=np.bool_)
    send_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
    send_src_rows = np.full((num_local_experts,), -1, dtype=np.int32)
    recv_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
    recv_dst_rows = np.full((num_local_experts,), -1, dtype=np.int32)

    # This rank owns the contiguous slots [base, base + num_local_experts).
    base = ep_rank * num_local_experts
    local_rows = np.arange(num_local_experts, dtype=np.int32)
    local_global = base + local_rows

    old_local_expert_ids = old_indices[local_global]
    new_local_expert_ids = new_indices[local_global]

    # Rows whose expert id does not change need no data movement.
    is_unchanged = old_local_expert_ids == new_local_expert_ids

    # A row can be filled from local data when its new expert is already held
    # by some row of this rank under the old mapping.
    new_valid = new_local_expert_ids != -1
    can_recv_local = np.isin(
        new_local_expert_ids, old_local_expert_ids, assume_unique=False
    )
    is_received_locally = np.logical_or(
        is_unchanged, np.logical_and(new_valid, can_recv_local)
    )

    # Send map: first src row per unique expert present locally in old mapping.
    send_count = 0
    valid_old = old_local_expert_ids != -1
    if np.any(valid_old):
        uniq_experts, first_idx = np.unique(
            old_local_expert_ids[valid_old], return_index=True
        )
        filtered_rows = local_rows[valid_old]
        src_rows = filtered_rows[first_idx]
        send_count = int(uniq_experts.shape[0])
        send_expert_ids[:send_count] = uniq_experts
        send_src_rows[:send_count] = src_rows

    # Recv map: one primary dst row per unique expert that must come from a
    # remote rank. Other rows needing the same expert are filled afterwards by
    # move_from_buffer, by copying from the primary row.
    recv_count = 0
    need_recv_mask = np.logical_and(~is_received_locally, new_valid)
    if np.any(need_recv_mask):
        desired_experts = new_local_expert_ids[need_recv_mask]
        desired_dsts = local_rows[need_recv_mask]
        uniq_recv_experts, uniq_indices = np.unique(desired_experts, return_index=True)
        dst_rows = desired_dsts[uniq_indices]
        recv_count = int(uniq_recv_experts.shape[0])
        recv_expert_ids[:recv_count] = uniq_recv_experts
        recv_dst_rows[:recv_count] = dst_rows
        recv_primary_mask[dst_rows] = True

    eligible_local_buffer_mask = np.logical_and(~is_unchanged, is_received_locally)

    # 1. Local moves into tmp buffers
    if bool(eligible_local_buffer_mask.any()) and send_count > 0:
        expert_to_src_map = dict(
            zip(
                send_expert_ids[:send_count].tolist(),
                send_src_rows[:send_count].tolist(),
            )
        )
        local_copies: list[tuple[int, int]] = []
        for dst in np.nonzero(eligible_local_buffer_mask)[0].tolist():
            src_local = expert_to_src_map.get(int(new_local_expert_ids[dst]), -1)
            if src_local != -1:
                local_copies.append((dst, src_local))
        if local_copies:
            # Enter the stream context once for the whole batch of copies;
            # copy order (row-major over dst, then over weight tensors) is kept.
            with torch.cuda.stream(cuda_stream):
                for dst, src_local in local_copies:
                    for w, b in zip(expert_weights, expert_weights_buffers):
                        b[dst].copy_(w[src_local], non_blocking=True)

    # 2. Post sends
    if send_count > 0:
        experts = send_expert_ids[:send_count]
        srcs = send_src_rows[:send_count]
        order = np.argsort(experts, kind="stable")
        experts = experts[order]
        srcs = srcs[order]

        send_map, recv_map = get_ep_ranks_with_experts_batch(
            experts,
            num_local_experts,
            old_indices,
            new_indices,
        )

        for expert, src in zip(experts.tolist(), srcs.tolist()):
            ranks_to_send = send_map[expert]
            ranks_to_recv = recv_map[expert]
            if not ranks_to_send or not ranks_to_recv:
                continue
            # Each sender serves a contiguous block of
            # len(ranks_to_recv) // len(ranks_to_send) receivers; leftover
            # receivers are handed out one each to the first senders in order.
            num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
            sender_pos = ranks_to_send.index(ep_rank)
            recv_begin = sender_pos * num_dst_per_sender
            recv_end = recv_begin + num_dst_per_sender
            recv_ranks = ranks_to_recv[recv_begin:recv_end]
            remainder_start = len(ranks_to_send) * num_dst_per_sender
            recver_pos = remainder_start + sender_pos
            if recver_pos < len(ranks_to_recv):
                recv_ranks.append(ranks_to_recv[recver_pos])
            for dst in recv_ranks:
                for w in expert_weights:
                    communicator.add_send(w[src], dst)

    # 3. Post recvs
    if recv_count > 0:
        experts = recv_expert_ids[:recv_count]
        dsts = recv_dst_rows[:recv_count]
        order = np.argsort(experts, kind="stable")
        experts = experts[order]
        dsts = dsts[order]

        send_map, recv_map = get_ep_ranks_with_experts_batch(
            experts,
            num_local_experts,
            old_indices,
            new_indices,
        )

        for expert, dst in zip(experts.tolist(), dsts.tolist()):
            ranks_to_send = send_map[expert]
            ranks_to_recv = recv_map[expert]
            if not ranks_to_send or not ranks_to_recv:
                continue
            # Invert the sender-side assignment from step 2 to find which
            # sender serves this rank.
            num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
            recver_pos = ranks_to_recv.index(ep_rank)
            remainder_start = len(ranks_to_send) * num_dst_per_sender
            if recver_pos < remainder_start:
                src = ranks_to_send[recver_pos // num_dst_per_sender]
            else:
                src = ranks_to_send[recver_pos - remainder_start]
            for b in expert_weights_buffers:
                communicator.add_recv(b[dst], src)

    # 4. Execute the P2P operations. The real communication happens here.
    # NOTE(review): whether execute() blocks until the transfers complete
    # depends on the EplbCommunicator implementation -- confirm before relying
    # on the contents of expert_weights_buffers.
    communicator.execute()

    return TransferMetadata(
        is_unchanged=is_unchanged,
        is_received_locally=is_received_locally,
        recv_primary_mask=recv_primary_mask,
        recv_count=recv_count,
        recv_expert_ids=recv_expert_ids,
        recv_dst_rows=recv_dst_rows,
    )
342
343
344


def move_from_buffer(
    expert_weights: Sequence[torch.Tensor],
    expert_weights_buffers: list[torch.Tensor],
    transfer_metadata: TransferMetadata,
    new_indices: np.ndarray,
    ep_rank: int,
) -> None:
    """
    Copies expert weights from communication buffers back to the target weight tensors
    after EPLB rebalancing.

    Unchanged rows are left untouched. Rows that were filled from local data,
    or that are the primary destination of a remote receive, are copied from
    the buffers. Any remaining rows that need the same remotely received
    expert are then filled by copying from that expert's primary row of
    ``expert_weights``.

    Args:
        expert_weights: List of the actual MoE layer weights used in the execution;
            updated in place row by row.
        expert_weights_buffers: Intermediate buffers containing the experts weights
            after the transfer is completed.
        transfer_metadata: The TransferMetadata returned by ``move_to_buffer``
            for this layer.
        new_indices: (num_physical_experts,) physical-slot -> logical-expert
            mapping after the rebalance; -1 marks an empty slot.
        ep_rank: Rank of the process in the expert parallel group.
    """
    is_unchanged = transfer_metadata.is_unchanged
    is_received_locally = transfer_metadata.is_received_locally
    recv_primary_mask = transfer_metadata.recv_primary_mask
    recv_count = transfer_metadata.recv_count
    recv_expert_ids = transfer_metadata.recv_expert_ids
    recv_dst_rows = transfer_metadata.recv_dst_rows
    num_local_experts = is_unchanged.shape[0]

    # Mask for rows to copy back from buffers:
    # copy if locally received OR remote primary recv
    copy_mask = np.logical_or(is_received_locally, recv_primary_mask)
    dest_mask_np = np.logical_and(~is_unchanged, copy_mask)
    if bool(dest_mask_np.any()):
        dest_indices = np.nonzero(dest_mask_np)[0].tolist()
        for dst in dest_indices:
            for w, b in zip(expert_weights, expert_weights_buffers):
                w[dst].copy_(b[dst], non_blocking=True)

    if recv_count == 0:
        return

    # Duplicate remote received rows to non-primary duplicate dsts: rows that
    # changed, were not filled locally, are not a primary recv row, and hold a
    # valid expert in the new mapping.
    base = ep_rank * num_local_experts
    local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)]
    duplicate_mask = np.logical_and(
        np.logical_and(~is_unchanged, ~is_received_locally),
        np.logical_and(~recv_primary_mask, local_experts != -1),
    )
    # All received experts are unique in the destination, so no need to copy duplicates
    if not bool(duplicate_mask.any()):
        return

    dup_dst_rows = np.nonzero(duplicate_mask)[0]
    dup_experts = local_experts[dup_dst_rows]

    # Look up each duplicate's expert among the primary receives via a sorted
    # search; `valid` filters out experts with no primary receive row.
    prim_experts = recv_expert_ids[:recv_count]
    prim_dsts = recv_dst_rows[:recv_count]
    order = np.argsort(prim_experts, kind="stable")
    prim_experts_sorted = prim_experts[order]
    prim_dsts_sorted = prim_dsts[order]
    pos = np.searchsorted(prim_experts_sorted, dup_experts)
    valid = np.logical_and(
        pos < prim_experts_sorted.shape[0],
        prim_experts_sorted[np.minimum(pos, prim_experts_sorted.shape[0] - 1)]
        == dup_experts,
    )
    if not bool(valid.any()):
        return

    matched_dst_rows = dup_dst_rows[valid]
    matched_src_rows = prim_dsts_sorted[pos[valid]]

    # The primary rows were written above, so copy weight-to-weight here.
    for dst, src in zip(matched_dst_rows.tolist(), matched_src_rows.tolist()):
        for w in expert_weights:
            w[dst].copy_(w[src], non_blocking=True)
419
420


421
def transfer_layer(
    old_layer_indices: torch.Tensor,
    new_layer_indices: torch.Tensor,
    expert_weights: Sequence[torch.Tensor],
    expert_weights_buffer: Sequence[torch.Tensor],
    ep_group: ProcessGroup,
    communicator: EplbCommunicator,
    is_profile: bool = False,
    cuda_stream: torch.cuda.Stream | None = None,
    rank_mapping: dict[int, int] | None = None,
) -> TransferMetadata:
    """
    Stage one layer's expert weights for rearrangement according to the new
    expert indices.

    This calls ``move_to_buffer``: locally available experts are copied into
    ``expert_weights_buffer`` and remote experts are exchanged via
    ``communicator``. ``move_from_buffer`` must be called afterwards with the
    returned metadata to write the result into ``expert_weights``.

    The values of the index arguments are logical expert ids, while their
    positions are physical expert slots.

    Args:
        old_layer_indices: Shape (num_physical_experts,).
        new_layer_indices: Shape (num_physical_experts,).
        expert_weights: Iterable of weight tensors for this layer, each with shape
            (num_local_physical_experts, hidden_size_i).
            For example, a linear layer may have up and down projection.
        expert_weights_buffer: Intermediate buffers (one per weight tensor).
        ep_group: The device process group for expert parallelism.
        communicator: EplbCommunicator instance for P2P communication.
        is_profile (bool): If `True`, do not perform any actual weight copy.
            This is used during profile run, where we only perform dummy
            communications to reserve enough memory for the buffers.
            NOTE(review): this function does not currently consult
            is_profile; the transfer always runs.
        cuda_stream: CUDA stream for async copies (can be None for sync mode).
        rank_mapping: Optional rank mapping for elastic expert parallelism.
            When it has one entry per current EP rank it is treated as a
            scale-down (new indices are remapped); otherwise as a scale-up
            (old indices are remapped).

    Returns:
        TransferMetadata: Metadata needed for completing remote weight transfers,
            including is_unchanged and is_received_locally masks.
    """
    ep_size = ep_group.size()
    if rank_mapping is not None:
        # Add a layer dimension for compatibility with mapping functions
        old_layer_indices_2d = old_layer_indices.unsqueeze(0)
        new_layer_indices_2d = new_layer_indices.unsqueeze(0)

        if len(rank_mapping) == ep_size:
            # scale down
            new_layer_indices_2d = _map_new_expert_indices_with_rank_mapping(
                new_layer_indices_2d,
                rank_mapping,
            )
        else:
            # scale up
            old_layer_indices_2d = _map_old_expert_indices_with_rank_mapping(
                old_layer_indices_2d,
                rank_mapping,
                ep_size,
            )

        # Remove the layer dimension
        old_layer_indices = old_layer_indices_2d.squeeze(0)
        new_layer_indices = new_layer_indices_2d.squeeze(0)

    assert old_layer_indices.shape == new_layer_indices.shape
    num_physical_experts = old_layer_indices.shape[0]
    # expert_weights is this layer's list of weight tensors; at least one is
    # needed to infer the number of local experts.
    assert len(expert_weights) >= 1
    num_local_physical_experts = expert_weights[0].shape[0]
    assert num_physical_experts == ep_size * num_local_physical_experts

    old_layer_indices_np = old_layer_indices.cpu().numpy()
    new_layer_indices_np = new_layer_indices.cpu().numpy()

    return move_to_buffer(
        num_local_experts=num_local_physical_experts,
        old_indices=old_layer_indices_np,
        new_indices=new_layer_indices_np,
        expert_weights=expert_weights,
        expert_weights_buffers=expert_weights_buffer,
        cuda_stream=cuda_stream,
        ep_rank=ep_group.rank(),
        communicator=communicator,
    )
500
501
502
503
504


def rearrange_expert_weights_inplace(
    old_global_expert_indices: torch.Tensor,
    new_global_expert_indices: torch.Tensor,
    expert_weights: Sequence[Sequence[torch.Tensor]],
    ep_group: ProcessGroup,
    communicator: EplbCommunicator,
    is_profile: bool = False,
    rank_mapping: dict[int, int] | None = None,
) -> None:
    """
    Rearranges the expert weights in place according to the new expert indices.

    The values of the index arguments are logical expert ids, while their
    positions are physical expert slots. Layers are processed one after
    another through a single set of intermediate buffers allocated with the
    shapes of the first layer's weights.

    Args:
        old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
        new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
        expert_weights: A sequence of shape (num_moe_layers)(weight_count)
            of tensors of shape (num_local_physical_experts, hidden_size_i).
            For example, a linear layer may have up and down projection,
            so weight_count = 2. Each weight's hidden size can be different.
        ep_group: The device process group for expert parallelism.
        communicator: EplbCommunicator instance for P2P communication.
        is_profile (bool): If `True`, do not perform any actual weight copy.
            This is used during profile run, where we only perform dummy
            communications to reserve enough memory for the buffers.
        rank_mapping: A dictionary mapping old rank to new rank. When it has
            one entry per current EP rank it is treated as a scale-down (new
            indices are remapped); otherwise as a scale-up (old indices are
            remapped).
    """
    ep_size = ep_group.size()
    if rank_mapping is not None:
        if len(rank_mapping) == ep_size:
            # scale down
            new_global_expert_indices = _map_new_expert_indices_with_rank_mapping(
                new_global_expert_indices,
                rank_mapping,
            )
        else:
            # scale up
            old_global_expert_indices = _map_old_expert_indices_with_rank_mapping(
                old_global_expert_indices,
                rank_mapping,
                ep_size,
            )

    assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1]

    num_moe_layers, num_physical_experts = old_global_expert_indices.shape
    assert len(expert_weights) == num_moe_layers
    assert len(expert_weights[0]) >= 1

    num_local_physical_experts = expert_weights[0][0].shape[0]
    assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)

    ep_rank = ep_group.rank()
    assert num_physical_experts == ep_size * num_local_physical_experts

    first_layer_weights = list(expert_weights[0])

    if is_profile:
        if communicator.needs_profile_buffer_reservation:
            # Reserve NCCL communication buffers via a dummy all_gather.
            # Backends that pre-allocate their own transfer buffers
            # skip this to avoid the extra memory spike during profiling.
            reservation_buffers: list[torch.Tensor] = [
                torch.empty_like(w) for w in first_layer_weights
            ]
            for weight, buffer in zip(expert_weights[0], reservation_buffers):
                # Every slot aliases the same buffer, so the gather needs no
                # ep_size-times allocation of its own.
                dummy_recv_buffer = [buffer for _ in range(ep_size)]
                torch.distributed.barrier()
                all_gather(
                    dummy_recv_buffer,
                    weight,
                    group=ep_group,
                )
        return

    # Buffers to hold the expert weights during the exchange.
    # NOTE: Currently we assume the same weights across different layers
    # have the same shape.
    weights_buffer = [torch.empty_like(w) for w in first_layer_weights]

    # NOTE(bowen): We need this synchronize to run, but I don't know why.
    # If you figure out the reason, please let me know -- thank you!
    torch.accelerator.synchronize()

    old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy()
    new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()

    for layer_idx in range(num_moe_layers):
        transfer_metadata = move_to_buffer(
            num_local_experts=num_local_physical_experts,
            old_indices=old_global_expert_indices_cpu[layer_idx],
            new_indices=new_global_expert_indices_cpu[layer_idx],
            expert_weights=expert_weights[layer_idx],
            expert_weights_buffers=weights_buffer,
            cuda_stream=None,
            ep_rank=ep_rank,
            communicator=communicator,
        )

        move_from_buffer(
            expert_weights=expert_weights[layer_idx],
            expert_weights_buffers=weights_buffer,
            transfer_metadata=transfer_metadata,
            new_indices=new_global_expert_indices_cpu[layer_idx],
            ep_rank=ep_rank,
        )


612
613
614
615
616
617
618
def _map_old_expert_indices_with_rank_mapping(
    old_global_expert_indices: torch.Tensor,
    rank_mapping: dict[int, int],
    new_ep_size: int,
) -> torch.Tensor:
    """
    Map the old global expert indices to the new global expert indices.
619

620
621
622
623
624
    Args:
        old_global_expert_indices:
            Shape (num_layers, old_ep_size * num_local_physical_experts).
        rank_mapping: Mapping from old rank to new rank.
        new_ep_size: New expert parallelism size.
625

626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
    Returns:
        Mapped expert indices with shape
        (num_layers, new_ep_size * num_local_physical_experts).
    """
    num_layers, old_num_physical_experts = old_global_expert_indices.shape
    assert rank_mapping, "Rank mapping is required"

    # Get sizes from parameters and rank_mapping
    old_ep_size = len(rank_mapping)
    num_local_physical_experts = old_num_physical_experts // old_ep_size
    new_num_physical_experts = new_ep_size * num_local_physical_experts

    # Create mapped tensor with new shape, initialized to -1
    mapped_expert_indices = torch.full(
        (num_layers, new_num_physical_experts),
        fill_value=-1,
        dtype=old_global_expert_indices.dtype,
        device=old_global_expert_indices.device,
    )

    # Handle rank mapping (scale up/down with rank changes)
    for old_rank in range(old_ep_size):
        new_rank = rank_mapping.get(old_rank)
        if new_rank is not None and new_rank >= 0 and new_rank < new_ep_size:
            # This old rank exists in the new configuration
            old_start_idx = old_rank * num_local_physical_experts
            old_end_idx = (old_rank + 1) * num_local_physical_experts
            new_start_idx = new_rank * num_local_physical_experts
            new_end_idx = (new_rank + 1) * num_local_physical_experts

656
            mapped_expert_indices[:, new_start_idx:new_end_idx] = (
657
                old_global_expert_indices[:, old_start_idx:old_end_idx]
658
            )
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
        # If new_rank is None or >= new_ep_size, the experts remain -1
        # (scale down case)

    return mapped_expert_indices


def _map_new_expert_indices_with_rank_mapping(
    new_global_expert_indices: torch.Tensor,
    rank_mapping: dict[int, int],
) -> torch.Tensor:
    num_layers, new_num_physical_experts = new_global_expert_indices.shape
    assert rank_mapping, "Rank mapping is required"

    # Get sizes from parameters and rank_mapping
    old_ep_size = len(rank_mapping)
    new_ep_size = sum(new_rank != -1 for new_rank in rank_mapping.values())
    num_local_physical_experts = new_num_physical_experts // new_ep_size
    old_num_physical_experts = old_ep_size * num_local_physical_experts

    mapped_expert_indices = torch.full(
        (num_layers, old_num_physical_experts),
        fill_value=-1,
        dtype=new_global_expert_indices.dtype,
        device=new_global_expert_indices.device,
    )

    for old_rank in range(old_ep_size):
        new_rank = rank_mapping[old_rank]
        if new_rank >= 0 and new_rank < new_ep_size:
            old_start_idx = old_rank * num_local_physical_experts
            old_end_idx = (old_rank + 1) * num_local_physical_experts
            new_start_idx = new_rank * num_local_physical_experts
            new_end_idx = (new_rank + 1) * num_local_physical_experts

693
            mapped_expert_indices[:, old_start_idx:old_end_idx] = (
694
                new_global_expert_indices[:, new_start_idx:new_end_idx]
695
            )
696
697
698
699

    return mapped_expert_indices


# Names exported by ``from ... import *``.
__all__ = [
    "transfer_layer",
    "move_from_buffer",
    "TransferMetadata",
]