# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
The actual execution of the rearrangement.

This involves the exchange of expert weights between GPUs.
"""

9
from collections.abc import Sequence
10
from dataclasses import dataclass
11

12
import numpy as np
13
import torch
14
from torch.distributed import ProcessGroup, all_gather
15

16
17
18
19
20
from vllm.distributed.eplb.eplb_communicator import EplbCommunicator
from vllm.distributed.eplb.eplb_utils import CpuGpuEvent
from vllm.logger import init_logger

logger = init_logger(__name__)
21
22


23
24
25
@dataclass
class RecvMetadata:
    """Bookkeeping for the remote receives of one layer during EPLB rebalancing."""

    recv_primary_mask: np.ndarray
    """
    Boolean mask of shape (num_local_experts,). True on the local rows chosen
    as the primary destination of a remotely received expert.
    """
    recv_count: int
    """Number of distinct experts received from other ranks for the layer."""
    recv_expert_ids: np.ndarray
    """
    Array of shape (num_local_experts,). The first `recv_count` entries are the
    ids of the experts received remotely.
    """
    recv_dst_rows: np.ndarray
    """
    Array of shape (num_local_experts,). The first `recv_count` entries are the
    local rows that receive the corresponding entry of `recv_expert_ids`.
    """
35
36


37
38
# Type alias for the triple returned by move_to_buffer and transfer_layer:
# (is_unchanged, is_received_locally, recv_metadata).
MoveToBufferResult = tuple[np.ndarray, np.ndarray, RecvMetadata]
39

40

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
@dataclass
class AsyncEplbLayerResult:
    """Outcome of a single MoE layer whose async EPLB transfer has completed."""

    layer_idx: int
    """Index of the MoE layer this result belongs to."""
    new_physical_to_logical_map: torch.Tensor
    """
    Physical→logical expert mapping for `layer_idx` after the transfer, on CPU.
    Shape: (num_physical_experts)
    """
    is_unchanged: np.ndarray
    """Flag per physical expert: True when its weight was not moved."""
    is_received_locally: np.ndarray
    """Flag per physical expert: True when its weight was received on this rank."""
    recv_metadata: RecvMetadata
    """Receive bookkeeping produced by transfer_layer for this layer."""
    consumed_event: CpuGpuEvent
    """
    Event guarding the intermediate buffer. The async worker calls wait() once it
    has finished transferring weights into the intermediate buffer; the main thread
    calls record() once it has finished transferring weights out of the intermediate
    buffer in _move_to_workspace()
    """


69
70
def _ranks_hosting_experts(
    indices: np.ndarray,
    query_experts: np.ndarray,
    num_local_experts: int,
) -> dict[int, list[int]]:
    """
    Group the positions of `indices` by expert and report the hosting ranks.

    Args:
        indices: 1D array whose value at position p is the expert stored at
            position p; position p belongs to rank p // num_local_experts.
        query_experts: Expert ids to look for.
        num_local_experts: Number of positions owned by each rank.

    Returns:
        A dict keyed by every queried expert that occurs in `indices`
        (inserted in ascending expert order) whose value is the list of
        distinct ranks holding it, in ascending order.
    """
    positions = np.flatnonzero(np.isin(indices, query_experts))
    if positions.size == 0:
        return {}

    experts = indices[positions]
    ranks = positions // num_local_experts

    # `positions` is ascending, so a stable sort on the expert id keeps the
    # positions (and therefore the ranks) non-decreasing inside each group.
    order = np.argsort(experts, kind="stable")
    sorted_experts = experts[order]
    sorted_ranks = ranks[order]

    # Start offset of every run of equal expert ids, plus the matching ends.
    group_starts = np.concatenate(
        ([0], np.flatnonzero(np.diff(sorted_experts)) + 1)
    )
    group_ends = np.append(group_starts[1:], sorted_experts.size)

    ranks_by_expert: dict[int, list[int]] = {}
    for start, end in zip(group_starts.tolist(), group_ends.tolist()):
        # Ranks are non-decreasing within the group, so the sorted unique
        # values are also the first-appearance order.
        ranks_by_expert[int(sorted_experts[start])] = np.unique(
            sorted_ranks[start:end]
        ).tolist()
    return ranks_by_expert


def get_ep_ranks_with_experts_batch(
    expert_ids: np.ndarray,
    num_local_experts: int,
    old_indices: np.ndarray,
    new_indices: np.ndarray,
) -> tuple[dict[int, list[int]], dict[int, list[int]]]:
    """
    Get the ranks of the experts that need to be exchanged.

    Args:
        expert_ids: 1D array of expert indices to query.
        num_local_experts: The number of local experts.
        old_indices: The old indices of the experts.
        new_indices: The new indices of the experts.

    Returns:
        A tuple of two dictionaries mapping expert_id to:
        - ranks_to_send: The ranks that have this expert and need to send.
        - ranks_to_recv: The ranks that need to receive this expert, i.e. the
          ranks holding it in `new_indices` minus those already holding it in
          `old_indices`.
        Every queried expert is a key of both dictionaries; an expert with
        no matching rank maps to an empty list. Both are empty when
        `expert_ids` is empty.
    """
    expert_ids = np.asarray(expert_ids)
    ranks_to_send_map: dict[int, list[int]] = {}
    ranks_to_recv_map: dict[int, list[int]] = {}

    # Fast path: if no experts, return empty dicts
    if expert_ids.size == 0:
        return ranks_to_send_map, ranks_to_recv_map

    unique_experts = np.unique(expert_ids)

    # Ranks currently holding each expert are the potential senders.
    ranks_to_send_map = _ranks_hosting_experts(
        np.asarray(old_indices), unique_experts, num_local_experts
    )

    # Ranks that will hold each expert need to receive it, unless they
    # already own a copy in the old layout.
    new_holders = _ranks_hosting_experts(
        np.asarray(new_indices), unique_experts, num_local_experts
    )
    for expert, ranks in new_holders.items():
        already_local = set(ranks_to_send_map.get(expert, ()))
        ranks_to_recv_map[expert] = [
            rank for rank in ranks if rank not in already_local
        ]

    # Experts that only appear in old (send only), new (recv only) or
    # neither still get an (empty) entry in both maps.
    for expert in map(int, unique_experts):
        ranks_to_send_map.setdefault(expert, [])
        ranks_to_recv_map.setdefault(expert, [])

    return ranks_to_send_map, ranks_to_recv_map
174
175


176
def move_to_buffer(
    num_local_experts: int,
    old_indices: np.ndarray,
    new_indices: np.ndarray,
    expert_weights: Sequence[torch.Tensor],
    expert_weights_buffers: Sequence[torch.Tensor],
    cuda_stream: torch.cuda.Stream | None,
    ep_rank: int,
    communicator: EplbCommunicator,
) -> MoveToBufferResult:
    """
    Stage the weights this rank needs after rebalancing into the buffers.

    Rows whose new expert already lives on this rank are copied locally from
    `expert_weights` into `expert_weights_buffers`. Rows whose expert lives
    elsewhere are received into the buffers through `communicator`, and the
    locally held experts required by other ranks are sent from
    `expert_weights`. Nothing is written to `expert_weights` here.

    Args:
        num_local_experts: Number of expert rows held by each rank.
        old_indices: 1D array with the expert id currently held by every
            slot across all ranks; -1 is treated as "no expert".
        new_indices: 1D array of the same shape with the expert id each slot
            must hold after the rebalance; -1 is treated as "no expert".
        expert_weights: Weight tensors of the layer, indexed by local row.
        expert_weights_buffers: Intermediate buffers (one per weight tensor).
        cuda_stream: CUDA stream used for the local copies (may be None).
        ep_rank: Rank of this process in the expert parallel group.
        communicator: EplbCommunicator used to queue and run the P2P ops.

    Returns:
        is_unchanged (np.ndarray): (num_local_experts,), True where the row
            holds the same expert id before and after.
        is_received_locally (np.ndarray): (num_local_experts,), True where
            the row is unchanged or its new expert is present on this rank
            in the old layout.
        RecvMetadata: Describes which rows were filled by remote receives.
    """
    assert old_indices.shape == new_indices.shape

    # Fixed-size bookkeeping arrays; slots past send_count / recv_count
    # keep their -1 (or False) fill value.
    recv_primary_mask = np.zeros((num_local_experts,), dtype=np.bool_)
    send_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
    send_src_rows = np.full((num_local_experts,), -1, dtype=np.int32)
    recv_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
    recv_dst_rows = np.full((num_local_experts,), -1, dtype=np.int32)

    # Slice of the global layout owned by this rank.
    rows = np.arange(num_local_experts, dtype=np.int32)
    owned_slots = ep_rank * num_local_experts + rows
    old_here = old_indices[owned_slots]
    new_here = new_indices[owned_slots]

    is_unchanged = old_here == new_here
    new_valid = new_here != -1
    # A changed row can be served locally when its new expert is already
    # somewhere in this rank's old rows.
    held_before = np.isin(new_here, old_here)
    is_received_locally = is_unchanged | (new_valid & held_before)

    # Send table: one source row (the first) per distinct local old expert.
    send_count = 0
    old_valid = old_here != -1
    if old_valid.any():
        held_experts, first_hit = np.unique(old_here[old_valid], return_index=True)
        send_count = int(held_experts.shape[0])
        send_expert_ids[:send_count] = held_experts
        send_src_rows[:send_count] = rows[old_valid][first_hit]

    # Recv table: one primary destination row per distinct remote expert.
    recv_count = 0
    needs_remote = ~is_received_locally & new_valid
    if needs_remote.any():
        wanted_experts, first_hit = np.unique(
            new_here[needs_remote], return_index=True
        )
        primary_rows = rows[needs_remote][first_hit]
        recv_count = int(wanted_experts.shape[0])
        recv_expert_ids[:recv_count] = wanted_experts
        recv_dst_rows[:recv_count] = primary_rows
        recv_primary_mask[primary_rows] = True

    # 1. Local moves into tmp buffers
    staged_locally = ~is_unchanged & is_received_locally
    if send_count > 0 and bool(staged_locally.any()):
        src_row_of = dict(
            zip(send_expert_ids[:send_count], send_src_rows[:send_count])
        )
        for dst in np.nonzero(staged_locally)[0].tolist():
            src_local = src_row_of.get(new_here[dst], -1)
            if src_local == -1:
                continue
            with torch.cuda.stream(cuda_stream):
                for weight, buffer in zip(expert_weights, expert_weights_buffers):
                    buffer[dst].copy_(weight[src_local], non_blocking=True)

    # 2. Post sends
    if send_count > 0:
        order = np.argsort(send_expert_ids[:send_count], kind="stable")
        experts = send_expert_ids[:send_count][order]
        srcs = send_src_rows[:send_count][order]

        send_map, recv_map = get_ep_ranks_with_experts_batch(
            experts,
            num_local_experts,
            old_indices,
            new_indices,
        )

        for expert, src in zip(experts.tolist(), srcs.tolist()):
            senders = send_map[expert]
            receivers = recv_map[expert]
            if not (senders and receivers):
                continue
            # Each sender serves a contiguous share of the receivers; the
            # leftover receivers go one each to the first senders.
            share = len(receivers) // len(senders)
            my_pos = senders.index(ep_rank)
            targets = receivers[my_pos * share : (my_pos + 1) * share]
            leftover_idx = len(senders) * share + my_pos
            if leftover_idx < len(receivers):
                targets.append(receivers[leftover_idx])
            for dst in targets:
                for weight in expert_weights:
                    communicator.add_send(weight[src], dst)

    # 3. Post recvs
    if recv_count > 0:
        order = np.argsort(recv_expert_ids[:recv_count], kind="stable")
        experts = recv_expert_ids[:recv_count][order]
        dsts = recv_dst_rows[:recv_count][order]

        send_map, recv_map = get_ep_ranks_with_experts_batch(
            experts,
            num_local_experts,
            old_indices,
            new_indices,
        )

        for expert, dst in zip(experts.tolist(), dsts.tolist()):
            senders = send_map[expert]
            receivers = recv_map[expert]
            if not (senders and receivers):
                continue
            # Mirror of the sender-side assignment above.
            share = len(receivers) // len(senders)
            my_pos = receivers.index(ep_rank)
            evenly_served = len(senders) * share
            if my_pos < evenly_served:
                src = senders[my_pos // share]
            else:
                src = senders[my_pos - evenly_served]
            for buffer in expert_weights_buffers:
                communicator.add_recv(buffer[dst], src)

    # 4. Execute the P2P operations. The real communication happens here.
    communicator.execute()

    recv_metadata = RecvMetadata(
        recv_primary_mask=recv_primary_mask,
        recv_count=recv_count,
        recv_expert_ids=recv_expert_ids,
        recv_dst_rows=recv_dst_rows,
    )
    return is_unchanged, is_received_locally, recv_metadata
352
353
354


def move_from_buffer(
    expert_weights: Sequence[torch.Tensor],
    expert_weights_buffers: list[torch.Tensor],
    is_unchanged: np.ndarray,
    is_received_locally: np.ndarray,
    recv_metadata: RecvMetadata,
    new_indices: np.ndarray,
    ep_rank: int,
) -> None:
    """
    Write the staged expert weights from the buffers into the live weights.

    Every changed row that was filled locally or is a primary remote receive
    is copied from its buffer row. Remaining changed rows that need an expert
    which was received remotely into another row are then filled by copying
    from that primary row of `expert_weights`.

    Args:
        expert_weights: The MoE layer weights used in execution; updated in place.
        expert_weights_buffers: Intermediate buffers holding the expert weights
            once the transfer has completed.
        is_unchanged: (num_local_experts,), True where an expert row is unchanged.
        is_received_locally: (num_local_experts,), True where a row is updated
            from local data.
        recv_metadata: RecvMetadata describing the remote receives.
        new_indices: (num_experts_total,) expert id wanted by every slot after
            the rebalance; -1 marks a slot with no expert.
        ep_rank: Rank of the process in the expert parallel group.
    """
    primary_mask = recv_metadata.recv_primary_mask
    num_received = recv_metadata.recv_count
    num_local_experts = is_unchanged.shape[0]
    changed = ~is_unchanged

    # Rows backed by a buffer row: local moves and primary remote receives.
    buffered_rows = np.flatnonzero(changed & (is_received_locally | primary_mask))
    for row in buffered_rows.tolist():
        for weight, buffer in zip(expert_weights, expert_weights_buffers):
            weight[row].copy_(buffer[row], non_blocking=True)

    if num_received == 0:
        return

    # Rows wanting a remotely received expert that are not its primary row.
    owned_slots = ep_rank * num_local_experts + np.arange(
        num_local_experts, dtype=np.int32
    )
    wanted = new_indices[owned_slots]
    needs_duplicate = (
        changed & ~is_received_locally & ~primary_mask & (wanted != -1)
    )
    # All received experts are unique in the destination, nothing to duplicate
    if not bool(needs_duplicate.any()):
        return

    dup_rows = np.flatnonzero(needs_duplicate)
    dup_experts = wanted[dup_rows]

    # Look up each duplicate's expert among the received experts.
    primary_experts = recv_metadata.recv_expert_ids[:num_received]
    primary_rows = recv_metadata.recv_dst_rows[:num_received]
    order = np.argsort(primary_experts, kind="stable")
    sorted_experts = primary_experts[order]
    sorted_rows = primary_rows[order]

    pos = np.searchsorted(sorted_experts, dup_experts)
    last = sorted_experts.shape[0] - 1
    # Clamp before indexing so out-of-range positions do not raise; they are
    # rejected by the `pos <= last` half of the test.
    found = (pos <= last) & (sorted_experts[np.minimum(pos, last)] == dup_experts)
    if not bool(found.any()):
        return

    copies = zip(dup_rows[found].tolist(), sorted_rows[pos[found]].tolist())
    for dst, src in copies:
        for weight in expert_weights:
            weight[dst].copy_(weight[src], non_blocking=True)
431
432
433


async def transfer_layer(
    old_layer_indices: torch.Tensor,
    new_layer_indices: torch.Tensor,
    expert_weights: Sequence[torch.Tensor],
    expert_weights_buffer: Sequence[torch.Tensor],
    ep_group: ProcessGroup,
    communicator: EplbCommunicator,
    is_profile: bool = False,
    cuda_stream: torch.cuda.Stream | None = None,
    rank_mapping: dict[int, int] | None = None,
) -> MoveToBufferResult:
    """
    Stage one layer's rearranged expert weights into the intermediate buffers.

    The values of the indices arguments are logical expert indices, while
    the positions are physical. This function only fills
    `expert_weights_buffer` (through move_to_buffer); the live weights are
    not modified here.

    Args:
        old_layer_indices: Shape (num_physical_experts,).
        new_layer_indices: Shape (num_physical_experts,).
        expert_weights: Iterable of weight tensors for this layer, each with shape
            (num_local_physical_experts, hidden_size_i).
            For example, a linear layer may have up and down projection.
        expert_weights_buffer: Intermediate buffers (one per weight tensor).
        ep_group: The device process group for expert parallelism; its size
            and rank are read here.
        communicator: EplbCommunicator instance for P2P communication.
        is_profile (bool): Accepted for interface compatibility; it is not
            read anywhere in this function body.
        cuda_stream: CUDA stream for async copies (can be None for sync mode).
        rank_mapping: Optional rank mapping for elastic expert parallelism.
            When its length equals the group size the new indices are
            remapped (scale down), otherwise the old indices are (scale up).

    Returns:
        is_unchanged (np.ndarray): (num_local_experts,), True where expert
            is left unchanged.
        is_received_locally (np.ndarray): (num_local_experts,), True where expert
            can be received locally.
        RecvMetadata: Metadata needed for completing remote weight transfers.
    """
    ep_size = ep_group.size()

    if rank_mapping is not None:
        # The mapping helpers work on (num_layers, num_physical_experts)
        # tensors, so wrap the single layer in a leading dimension and strip
        # it again afterwards.
        if len(rank_mapping) == ep_size:
            # scale down
            new_layer_indices = _map_new_expert_indices_with_rank_mapping(
                new_layer_indices.unsqueeze(0),
                rank_mapping,
            ).squeeze(0)
        else:
            # scale up
            old_layer_indices = _map_old_expert_indices_with_rank_mapping(
                old_layer_indices.unsqueeze(0),
                rank_mapping,
                ep_size,
            ).squeeze(0)

    assert old_layer_indices.shape == new_layer_indices.shape
    num_physical_experts = old_layer_indices.shape[0]
    assert len(expert_weights[0]) >= 1
    num_local_physical_experts = expert_weights[0].shape[0]
    assert num_physical_experts == ep_size * num_local_physical_experts

    # move_to_buffer works on host-side numpy arrays.
    old_np = old_layer_indices.cpu().numpy()
    new_np = new_layer_indices.cpu().numpy()

    is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
        num_local_experts=num_local_physical_experts,
        old_indices=old_np,
        new_indices=new_np,
        expert_weights=expert_weights,
        expert_weights_buffers=expert_weights_buffer,
        cuda_stream=cuda_stream,
        ep_rank=ep_group.rank(),
        communicator=communicator,
    )
    return is_unchanged, is_received_locally, recv_metadata
516
517
518
519
520


def rearrange_expert_weights_inplace(
    old_global_expert_indices: torch.Tensor,
    new_global_expert_indices: torch.Tensor,
    expert_weights: Sequence[Sequence[torch.Tensor]],
    ep_group: ProcessGroup,
    communicator: EplbCommunicator,
    is_profile: bool = False,
    rank_mapping: dict[int, int] | None = None,
) -> None:
    """
    Rearrange the expert weights of every MoE layer in place.

    The values of the indices arguments are logical expert indices, while
    the positions are physical. Layers are processed one after another:
    each is staged into a shared set of buffers with move_to_buffer and
    then written back with move_from_buffer.

    Args:
        old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
        new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
        expert_weights: A sequence of shape (num_moe_layers)(weight_count)
            of tensors of shape (num_local_physical_experts, hidden_size_i).
            For example, a linear layer may have up and down projection,
            so weight_count = 2. Each weight's hidden size can be different.
        ep_group: The device process group for expert parallelism.
        communicator: EplbCommunicator instance for P2P communication.
        is_profile (bool): If `True`, do not perform any actual weight copy.
            This is used during profile run, where we only perform dummy
            communications to reserve enough memory for the buffers.
        rank_mapping: A dictionary mapping old rank to new rank.
    """
    if rank_mapping is not None:
        scaling_down = len(rank_mapping) == ep_group.size()
        if scaling_down:
            new_global_expert_indices = _map_new_expert_indices_with_rank_mapping(
                new_global_expert_indices,
                rank_mapping,
            )
        else:
            # scale up
            old_global_expert_indices = _map_old_expert_indices_with_rank_mapping(
                old_global_expert_indices,
                rank_mapping,
                ep_group.size(),
            )

    assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1]

    num_moe_layers, num_physical_experts = old_global_expert_indices.shape
    assert len(expert_weights) == num_moe_layers
    assert len(expert_weights[0]) >= 1

    num_local_physical_experts = expert_weights[0][0].shape[0]
    assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)

    ep_size = ep_group.size()
    ep_rank = ep_group.rank()
    assert num_physical_experts == ep_size * num_local_physical_experts

    first_layer_weights = list(expert_weights[0])

    if is_profile:
        # Backends that pre-allocate their own transfer buffers skip the
        # reservation to avoid the extra memory spike during profiling.
        if not communicator.needs_profile_buffer_reservation:
            return
        # Reserve NCCL communication buffers via a dummy all_gather.
        profile_buffers: list[torch.Tensor] = [
            torch.empty_like(w) for w in first_layer_weights
        ]
        for weight, buffer in zip(expert_weights[0], profile_buffers):
            dummy_recv_buffer = [buffer] * ep_size
            torch.distributed.barrier()
            all_gather(dummy_recv_buffer, weight, group=ep_group)
        return

    # Buffers to hold the expert weights during the exchange.
    # NOTE: Currently we assume the same weights across different layers
    # have the same shape.
    weights_buffer = [torch.empty_like(w) for w in first_layer_weights]

    # NOTE(bowen): We need this synchronize to run, but I don't know why.
    # If you figure out the reason, please let me know -- thank you!
    torch.accelerator.synchronize()

    old_layers_np = old_global_expert_indices.cpu().numpy()
    new_layers_np = new_global_expert_indices.cpu().numpy()

    for layer_idx in range(num_moe_layers):
        layer_weights = expert_weights[layer_idx]
        layer_new_indices = new_layers_np[layer_idx]

        # Stage this layer into the shared buffers (includes the P2P exchange).
        is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
            num_local_experts=num_local_physical_experts,
            old_indices=old_layers_np[layer_idx],
            new_indices=layer_new_indices,
            expert_weights=layer_weights,
            expert_weights_buffers=weights_buffer,
            cuda_stream=None,
            ep_rank=ep_rank,
            communicator=communicator,
        )

        # Write the staged rows back before the buffers are reused.
        move_from_buffer(
            expert_weights=layer_weights,
            expert_weights_buffers=weights_buffer,
            is_unchanged=is_unchanged,
            is_received_locally=is_received_locally,
            recv_metadata=recv_metadata,
            new_indices=layer_new_indices,
            ep_rank=ep_rank,
        )


630
631
632
633
634
635
636
def _map_old_expert_indices_with_rank_mapping(
    old_global_expert_indices: torch.Tensor,
    rank_mapping: dict[int, int],
    new_ep_size: int,
) -> torch.Tensor:
    """
    Map the old global expert indices to the new global expert indices.
637

638
639
640
641
642
    Args:
        old_global_expert_indices:
            Shape (num_layers, old_ep_size * num_local_physical_experts).
        rank_mapping: Mapping from old rank to new rank.
        new_ep_size: New expert parallelism size.
643

644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
    Returns:
        Mapped expert indices with shape
        (num_layers, new_ep_size * num_local_physical_experts).
    """
    num_layers, old_num_physical_experts = old_global_expert_indices.shape
    assert rank_mapping, "Rank mapping is required"

    # Get sizes from parameters and rank_mapping
    old_ep_size = len(rank_mapping)
    num_local_physical_experts = old_num_physical_experts // old_ep_size
    new_num_physical_experts = new_ep_size * num_local_physical_experts

    # Create mapped tensor with new shape, initialized to -1
    mapped_expert_indices = torch.full(
        (num_layers, new_num_physical_experts),
        fill_value=-1,
        dtype=old_global_expert_indices.dtype,
        device=old_global_expert_indices.device,
    )

    # Handle rank mapping (scale up/down with rank changes)
    for old_rank in range(old_ep_size):
        new_rank = rank_mapping.get(old_rank)
        if new_rank is not None and new_rank >= 0 and new_rank < new_ep_size:
            # This old rank exists in the new configuration
            old_start_idx = old_rank * num_local_physical_experts
            old_end_idx = (old_rank + 1) * num_local_physical_experts
            new_start_idx = new_rank * num_local_physical_experts
            new_end_idx = (new_rank + 1) * num_local_physical_experts

674
            mapped_expert_indices[:, new_start_idx:new_end_idx] = (
675
                old_global_expert_indices[:, old_start_idx:old_end_idx]
676
            )
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
        # If new_rank is None or >= new_ep_size, the experts remain -1
        # (scale down case)

    return mapped_expert_indices


def _map_new_expert_indices_with_rank_mapping(
    new_global_expert_indices: torch.Tensor,
    rank_mapping: dict[int, int],
) -> torch.Tensor:
    num_layers, new_num_physical_experts = new_global_expert_indices.shape
    assert rank_mapping, "Rank mapping is required"

    # Get sizes from parameters and rank_mapping
    old_ep_size = len(rank_mapping)
    new_ep_size = sum(new_rank != -1 for new_rank in rank_mapping.values())
    num_local_physical_experts = new_num_physical_experts // new_ep_size
    old_num_physical_experts = old_ep_size * num_local_physical_experts

    mapped_expert_indices = torch.full(
        (num_layers, old_num_physical_experts),
        fill_value=-1,
        dtype=new_global_expert_indices.dtype,
        device=new_global_expert_indices.device,
    )

    for old_rank in range(old_ep_size):
        new_rank = rank_mapping[old_rank]
        if new_rank >= 0 and new_rank < new_ep_size:
            old_start_idx = old_rank * num_local_physical_experts
            old_end_idx = (old_rank + 1) * num_local_physical_experts
            new_start_idx = new_rank * num_local_physical_experts
            new_end_idx = (new_rank + 1) * num_local_physical_experts

711
            mapped_expert_indices[:, old_start_idx:old_end_idx] = (
712
                new_global_expert_indices[:, new_start_idx:new_end_idx]
713
            )
714
715
716
717

    return mapped_expert_indices


718
# Names exported by a star-import of this module.
__all__ = ["transfer_layer", "move_from_buffer", "RecvMetadata"]