# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Expert parallelism load balancer (EPLB) metrics and states.

# Glossary

- **Logical Expert**: An expert that is part of the model's logical structure.
  It holds a set of weights and is replicated across multiple physical
  experts.
- **Redundant Expert**: To achieve load balancing, we create additional copies
  of the weights of some popular logical experts. During inference, tokens
  routed to such a logical expert can be dispatched to any of its copies.
- **Physical Expert**: An expert that is instantiated on a specific device.
  It is a replica of a logical expert and can be rearranged across devices.
  I.e., one logical expert may have multiple sets of weights initialized on
  different devices, and each of these sets is a physical expert.
- **Local Physical Expert**: A physical expert that is instantiated on the
  current device.

For example: DeepSeek-R1 has 256 logical experts, so each MoE layer
has 256 sets of linear layer weights in the model parameters. If we add 32
redundant experts, DeepSeek-R1 will have 256 + 32 = 288 physical experts in
total. And when deploying, we'll have 288 sets of linear layer weights for each
MoE layer. If we have 32 EP ranks, then each GPU will hold 288 / 32 = 9 local
physical experts.
"""

29
import threading
30
31
32
from collections.abc import Sequence
from dataclasses import dataclass

33
import numpy as np
34
import torch
35
from torch.distributed import ProcessGroup, all_reduce
36

37
from vllm.config import ModelConfig, ParallelConfig
38
39
from vllm.distributed.parallel_state import (
    get_ep_group,
40
    get_eplb_group,
41
42
43
    get_node_count,
    in_the_same_node_as,
)
44
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
45
from vllm.distributed.utils import StatelessProcessGroup
46
47
48
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts

49
from .async_worker import start_async_worker
50
from .eplb_communicator import EplbCommunicator, create_eplb_communicator
Mercykid-bash's avatar
Mercykid-bash committed
51
from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy
52
53
54
55
56
from .rebalance_execute import (
    RecvMetadata,
    move_from_buffer,
    rearrange_expert_weights_inplace,
)
57
58
59
60

logger = init_logger(__name__)


61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
@dataclass
class EplbStats:
    """
    Snapshot of model statistics handed to the EPLB rebalancing algorithm.
    """

    # Expert load window. Documented shape:
    # (window_size, num_moe_layers, num_physical_experts).
    # NOTE(review): `EplbState.rearrange` sums its window over the first
    # dimension and maps it onto logical experts before rebalancing; confirm
    # which shape is actually stored in this field.
    global_expert_load_window: torch.Tensor
    # Number of physical experts (replicas) to place.
    num_replicas: int
    # Number of expert groups.
    num_groups: int
    # Number of nodes.
    num_nodes: int
    # Number of GPUs.
    num_gpus: int


90
@dataclass
class EplbModelState:
    """EPLB metrics and expert-transfer state for a single registered model."""

    physical_to_logical_map: torch.Tensor
    """
    Mapping from physical experts to logical experts.

    Shape: (num_moe_layers, num_physical_experts)

    # Example

    For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
    EP ranks, the mapping could look like this:

    ```
    [[0, 1, 2, 3, 0, 1],
     [0, 2, 0, 1, 0, 3]]
    ```
    """
    logical_to_physical_map: torch.Tensor
    """
    Mapping from logical experts to physical experts.

    This is a sparse matrix, where -1 indicates no mapping. For each logical
    expert, only the first `logical_replica_count` entries of its row are
    valid.

    Shape: (num_moe_layers, num_logical_experts, max_slots_per_logical_expert)

    `EplbState.add_model` allocates `MAX_EXPERT_REDUNDANCY + 1` slots per
    logical expert (not `num_redundant_experts + 1`), so the last dimension is
    typically wider than in the example below, which shows only three slots
    for brevity.

    # Example

    For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
    EP ranks, the mapping could look like this:

    ```
    [[[0, 4, -1],
      [1, 5, -1],
      [2, -1, -1],
      [3, -1, -1]],
     [[0, 2, 4],
      [3, -1, -1],
      [1, -1, -1],
      [5, -1, -1]]]
    ```
    """
    logical_replica_count: torch.Tensor
    """
    Number of replicas for each logical expert.
    This is exactly the non-`-1` count in the `logical_to_physical_map`.

    Shape: (num_moe_layers, num_logical_experts)

    # Example

    For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
    EP ranks, the count could look like this:

    ```
    [[2, 2, 1, 1],
     [3, 1, 1, 1]]
    ```
    """

    expert_load_pass: torch.Tensor
    """
    Expert load during this forward pass.
    We use the token count each expert processes as the load.

    Shape: (num_moe_layers, num_physical_experts)
    """
    expert_load_window: torch.Tensor
    """
    A sliding window of expert load.

    Shape: (window_size, num_moe_layers, num_physical_experts)

    NOTE: The expert_load_view now records load for all physical experts
    rather than just local experts. This ensures consistent load statistics
    across different dispatch methods (naive all-to-all, DeepEP).
    The recorded load will be multiplied by dp_size when using naive all-to-all
    due to each DP rank contributing the same token set to the calculation.
    See:
    https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
    """
    model_name: str
    model: MixtureOfExperts
    expert_buffer: list[torch.Tensor]
    """
    The buffer to store the expert weights during transfer.
    """
    buffer_lock: threading.Lock
    """
    The lock to protect the expert buffer.
    """
    buffer_consumed_event: torch.cuda.Event | None
    """
    CUDA event recorded after the main thread finishes consuming the buffer.
    The async worker waits on this before writing to the buffer again.
    """
    window_ready_event: torch.cuda.Event | None
    """
    CUDA event recorded after all-reduce and clone on the main thread.
    The async worker waits on this before accessing global_expert_load_window.
    """
    ep_buffer_ready: int
    """
    The flag indicates whether the expert buffer is ready for transfer.
    0 or 1.
    """
    layer_to_transfer: int
    """
    The layer index to transfer in async mode.
    """
    rebalanced: bool
    """
    The flag indicates whether the experts rebalance have been computed.
    """
    pending_global_ready_check: bool
    """
    Whether the async EPLB needs to poll peers for buffer readiness.
    """
    eplb_stats: EplbStats | None
    """
    EPLB stats for the model.
    """
    is_unchanged: np.ndarray
    """
    Intermediate variable between `move_to_buffer` and `move_to_workspace`.
    The size is the same as the number of physical experts in the current
    layer.
    """
    is_received_locally: np.ndarray
    """
    Intermediate variable between `move_to_buffer` and `move_to_workspace`.
    The size is the same as the number of physical experts in the current
    layer.
    """
    recv_metadata: RecvMetadata
    """
    Intermediate variable between `move_to_buffer` and `move_to_workspace`.
    """
    cuda_device_index: int | None
    """
    CUDA device index for the async EPLB worker thread.
    """
    communicator: EplbCommunicator
    """
    The communicator for expert weight transfers.
    """
    new_physical_to_logical_map: torch.Tensor | None = None
    """
    Intermediate variable between `move_to_buffer` and `move_to_workspace`.
    The size is the same as `physical_to_logical_map`.
    """
239

240
241

class EplbState:
    """
    EPLB state shared by all expert-parallel models registered via
    `add_model`. Per-model state lives in `model_states`, keyed by the model
    config hash.
    """

    def __init__(self, parallel_config: ParallelConfig, device: torch.device):
        self.parallel_config = parallel_config
        self.device = device
        # Per-model EPLB state, keyed by the model config hash (see `add_model`).
        self.model_states: dict[str, EplbModelState] = {}
        self.policy: type[AbstractEplbPolicy] = DefaultEplbPolicy
        """
        Selected EPLB algorithm class.
        """
        self.expert_load_window_step: int = 0
        """
        Current step in the sliding window.

        Different from `expert_rearrangement_step`,
        each EP rank may have its own `expert_load_window_step`.
        """
        self.expert_load_window_size: int = 0
        """
        Size of the expert load sliding window.
        This is a constant and is taken from the config.
        """
        self.expert_rearrangement_step: int = 0
        """
        Steps after last rearrangement.
        Will trigger a rearrangement once it reaches
        `expert_rearrangement_step_interval`.

        NOTE: Keep in mind that all EP ranks need to have the same
        `expert_rearrangement_step` value to ensure synchronization.
        Otherwise, the rearrangement will hang at collective
        communication calls.
        """
        self.expert_rearrangement_step_interval: int = 0
        """
        Interval for expert rearrangement steps.
        This is a constant and is taken from the config.
        """
        self.should_record_tensor: torch.Tensor | None = None
        """
        Shared scalar bool tensor for all layers.  Every
        :class:`EplbLayerState` holds a reference to the **same** object so
        a single ``.fill_()`` updates all layers at once.  Allocated by the
        first call to :meth:`_init_should_record_tensor` that finds at least
        one such layer state.
        """
        self.is_async: bool = False
        """
        The flag indicates whether the EPLB is running in async mode.
        """
        self.rearrange_event = threading.Event()
        """
        Event to signal when a new rearrangement is needed for the async thread.
        """
        self.async_worker: threading.Thread | None = None
        """
        Background thread handling async transfers.
        """
        self.cuda_device_index: int | None = None
        """
        CUDA device index for the async EPLB worker thread.
        """
        self.num_valid_physical_experts: int = 0
        """
        Number of valid physical experts.
        This is the number of physical experts that are
        actually mapped to logical experts. In elastic EP,
        newly started EP ranks may not have physical experts
        mapped yet.
        """

        if self.device.type == "cuda":
            self.cuda_device_index = self.device.index
            # A bare "cuda" device has no index; fall back to the current one.
            if self.cuda_device_index is None and torch.cuda.is_available():
                self.cuda_device_index = torch.accelerator.current_device_index()

    @staticmethod
    def build_initial_global_physical_to_logical_map(
        num_routed_experts: int,
        num_redundant_experts: int,
    ) -> Sequence[int]:
        """
        Build an initial expert arrangement using the following structure:
        [original routed experts, redundant experts]

        Returns:
            physical_to_logical_map (Sequence[int]): A list of integers,
                where each integer is the index of the logical expert
                that the corresponding physical expert maps to.
        """
        global_physical_to_logical_map = list(range(num_routed_experts))
        global_physical_to_logical_map += [
            i % num_routed_experts for i in range(num_redundant_experts)
        ]
        return global_physical_to_logical_map

337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
    def validate_ep_configuration(self, new_model: MixtureOfExperts):
        """
        Validate that the expert parallel configuration of
        the new model is the same as the existing models.
        """
        if len(self.model_states) > 0:
            model = next(iter(self.model_states.values())).model
            if (
                model.num_routed_experts != new_model.num_routed_experts
                or model.num_redundant_experts != new_model.num_redundant_experts
                or model.num_physical_experts != new_model.num_physical_experts
                or model.num_logical_experts != new_model.num_logical_experts
                or model.num_expert_groups != new_model.num_expert_groups
            ):
                raise RuntimeError(
                    "Model: {} "
                    "with config {} "
                    "{} {} {} {} "
                    "mismatch with new model {} "
                    "with config {} "
                    "{} {} {} {}".format(
                        type(model),
                        model.num_routed_experts,
                        model.num_redundant_experts,
                        model.num_physical_experts,
                        model.num_logical_experts,
                        model.num_expert_groups,
                        type(new_model),
                        new_model.num_routed_experts,
                        new_model.num_redundant_experts,
                        new_model.num_physical_experts,
                        new_model.num_logical_experts,
                        new_model.num_expert_groups,
                    )
                )

    def add_model(
        self,
        model: MixtureOfExperts,
        model_config: ModelConfig,
    ):
        """
        Build the initial EPLB state for ``model`` and register it.

        The initial arrangement places the routed experts first, followed by
        redundant copies assigned round-robin (see
        `build_initial_global_physical_to_logical_map`), and is duplicated for
        every MoE layer. The resulting `EplbModelState` is stored in
        `self.model_states` under `model_config.compute_hash()`.

        Args:
            model: The MoE model to track. Must have the same expert-parallel
                configuration as any previously added model.
            model_config: Config of ``model``; its hash keys the stored state
                and its ``model`` field is used as the model name.

        Raises:
            RuntimeError: If the expert configuration differs from a
                previously added model.
            AssertionError: If ``model.num_redundant_experts`` exceeds
                ``MAX_EXPERT_REDUNDANCY`` (assert statement).
        """
        self.validate_ep_configuration(model)
        self.is_async = self.parallel_config.eplb_config.use_async

        physical_to_logical_map_list = (
            EplbState.build_initial_global_physical_to_logical_map(
                model.num_routed_experts,
                model.num_redundant_experts,
            )
        )
        physical_to_logical_map = torch.tensor(
            physical_to_logical_map_list,
            device=self.device,
        )

        # Assuming 8 GPUs per node, this supports up to
        # (1023 + 1) / 8 = 128 nodes for now.
        # TODO(rui): make this configurable
        MAX_EXPERT_REDUNDANCY = 1023
        assert model.num_redundant_experts <= MAX_EXPERT_REDUNDANCY, (
            f"num_redundant_experts {model.num_redundant_experts} "
            f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}"
        )
        max_slots_per_logical_expert = MAX_EXPERT_REDUNDANCY + 1

        # Build the inverse mapping on the host from the Python list, then copy
        # it to the device once. A per-expert loop over device tensors indexed
        # by device scalars launches many tiny kernels and can force a host
        # sync on each iteration.
        logical_to_physical_map = torch.full(
            (model.num_logical_experts, max_slots_per_logical_expert),
            -1,
            dtype=torch.long,
        )
        replica_counts = [0] * model.num_logical_experts
        for physical_idx in range(model.num_physical_experts):
            logical_idx = physical_to_logical_map_list[physical_idx]
            logical_to_physical_map[logical_idx, replica_counts[logical_idx]] = (
                physical_idx
            )
            replica_counts[logical_idx] += 1
        logical_to_physical_map = logical_to_physical_map.to(self.device)
        logical_replica_count = torch.tensor(
            replica_counts,
            dtype=torch.long,
            device=self.device,
        )

        # Duplicate initial mapping for all layers
        physical_to_logical_map = (
            physical_to_logical_map.unsqueeze(0)
            .expand(model.num_moe_layers, -1)
            .contiguous()
        )
        logical_to_physical_map = (
            logical_to_physical_map.unsqueeze(0)
            .expand(model.num_moe_layers, -1, -1)
            .contiguous()
        )
        logical_replica_count = (
            logical_replica_count.unsqueeze(0)
            .expand(model.num_moe_layers, -1)
            .contiguous()
        )

        expert_load_pass = torch.zeros(
            (model.num_moe_layers, model.num_physical_experts),
            dtype=torch.int32,
            device=self.device,
        )
        self.expert_load_window_size = self.parallel_config.eplb_config.window_size
        expert_load_window = torch.zeros(
            (
                self.expert_load_window_size,
                model.num_moe_layers,
                model.num_physical_experts,
            ),
            dtype=torch.int32,
            device=self.device,
        )

        # Set the initial progress of rearrangement to 3/4
        eplb_step_interval = self.parallel_config.eplb_config.step_interval
        self.expert_rearrangement_step = max(
            0, eplb_step_interval - eplb_step_interval // 4
        )
        self.expert_rearrangement_step_interval = eplb_step_interval

        policy_type = self.parallel_config.eplb_config.policy
        self.policy = EPLB_POLICIES[policy_type]
        logger.debug("Selected EPLB policy: %s", policy_type)

        model.set_eplb_state(
            expert_load_pass,
            logical_to_physical_map,
            logical_replica_count,
        )
        # Must run after `set_eplb_state` so each layer's eplb_state exists.
        self._init_should_record_tensor(model)

        expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]

        communicator = create_eplb_communicator(
            group_coordinator=get_eplb_group(),
            backend=self.parallel_config.eplb_config.communicator,
            expert_weights=model.expert_weights[0],
        )

        model_state = EplbModelState(
            physical_to_logical_map=physical_to_logical_map,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
            expert_load_pass=expert_load_pass,
            expert_load_window=expert_load_window,
            model_name=model_config.model,
            model=model,
            expert_buffer=expert_buffer,
            buffer_lock=threading.Lock(),
            buffer_consumed_event=None,
            window_ready_event=None,
            ep_buffer_ready=0,
            layer_to_transfer=0,
            rebalanced=False,
            pending_global_ready_check=False,
            eplb_stats=None,
            is_unchanged=np.array([]),
            is_received_locally=np.array([]),
            recv_metadata=RecvMetadata(
                recv_primary_mask=np.array([]),
                recv_count=0,
                recv_expert_ids=np.array([]),
                recv_dst_rows=np.array([]),
            ),
            cuda_device_index=self.cuda_device_index,
            communicator=communicator,
            new_physical_to_logical_map=None,
        )
        self.model_states[model_config.compute_hash()] = model_state
        self.num_valid_physical_experts = model.num_physical_experts
    def step(
        self,
        is_dummy: bool = False,
        is_profile: bool = False,
        log_stats: bool = False,
    ) -> None:
        """
        Step the EPLB state.

        Called once per forward pass. Records the pass's expert load into the
        sliding window when recording is enabled, optionally logs load
        balancedness, hands off completed async buffers, and triggers a
        rearrangement every `expert_rearrangement_step_interval` steps.

        Args:
            is_dummy (bool): If `True`, this is a dummy step and the load
                metrics recorded in this forward pass will not count.
                Defaults to `False`.
            is_profile (bool): If `True`, perform a dummy rearrangement
                with maximum communication cost. This is used in
                `profile_run` to reserve enough memory
                for the communication buffer.
            log_stats (bool): If `True`, log the expert load metrics every
                `log_balancedness_interval` steps.

        # Stats
            For each layer, per-rank token counts are reduced across EP ranks;
            the per-layer results are then summed up across layers.
            - `avg_tokens`: The average load across ranks.
            - `max_tokens`: The maximum load across ranks.
            - `balancedness`: The ratio of average load to maximum load
              (1.0 means perfectly balanced).
        """
        ep_group = get_ep_group().device_group
        if is_profile:
            self.rearrange(is_profile=True)
            return

        if is_dummy:
            # Do not record load metrics for dummy steps
            for eplb_model_state in self.model_states.values():
                eplb_model_state.expert_load_pass.zero_()

        if (
            log_stats
            and self.expert_rearrangement_step
            % self.parallel_config.eplb_config.log_balancedness_interval
            == 0
        ):
            # Sync the expert load pass for each model (main and drafter).
            # expert_load_pass: (num_moe_layers, num_physical_experts)
            expert_load_pass_list = self._sync_load_pass()
            for expert_load_pass, eplb_model_state in zip(
                expert_load_pass_list, self.model_states.values()
            ):
                # num_tokens_per_rank: (num_moe_layers, num_ranks)
                num_tokens_per_rank = (
                    expert_load_pass.reshape(
                        expert_load_pass.shape[0], ep_group.size(), -1
                    )
                    .sum(dim=-1)
                    .float()
                )

                # Compute balancedness ratio:
                # for each layer:
                #   (mean load across ranks) / (max load across ranks)
                # with both terms summed over layers. The rank axis is dim=1;
                # reducing over dim=0 would average/max over layers instead and
                # report e.g. "all load on rank 0 in every layer" as balanced.
                avg_tokens_tensor = num_tokens_per_rank.mean(dim=1).sum(dim=0)
                max_tokens_tensor = num_tokens_per_rank.max(dim=1).values.sum(dim=0)

                # Just to make type checker happy
                tokens_tensors: list[float] = torch.stack(
                    [avg_tokens_tensor, max_tokens_tensor]
                ).tolist()
                avg_tokens, max_tokens = tokens_tensors
                balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0

                if ep_group.rank() == 0:
                    logger.info(
                        "EPLB step: %d for model %s: avg_tokens=%.2f, "
                        "max_tokens=%d, balancedness=%.4f, "
                        "steps until the next rearrangement: %d",
                        self.expert_rearrangement_step,
                        eplb_model_state.model_name,
                        avg_tokens,
                        max_tokens,
                        balancedness,
                        self.expert_rearrangement_step_interval
                        - self.expert_rearrangement_step,
                    )

        # Update the expert load sliding window
        if not is_dummy:
            should_record = self._should_record_current_step(log_stats=log_stats)
            if should_record:
                for eplb_model_state in self.model_states.values():
                    eplb_model_state.expert_load_window[
                        self.expert_load_window_step
                    ].copy_(eplb_model_state.expert_load_pass)
                    eplb_model_state.expert_load_pass.zero_()
                self.expert_load_window_step += 1
                if self.expert_load_window_step >= self.expert_load_window_size:
                    self.expert_load_window_step = 0

        # Step the expert rearrangement step
        # Note that even if this is a dummy step, we still increment the
        # rearrangement step and perform rearrangement to ensure all ranks are
        # performing collective communication.
        self.expert_rearrangement_step += 1

        if self.is_async:
            for eplb_model_state in self.model_states.values():
                all_ranks_buffer_ready = False
                if eplb_model_state.pending_global_ready_check:
                    all_ranks_buffer_ready = self._all_ranks_buffer_ready(
                        eplb_model_state
                    )
                if eplb_model_state.ep_buffer_ready and all_ranks_buffer_ready:
                    self.move_to_workspace(
                        model_state=eplb_model_state,
                        ep_group=ep_group,
                        is_profile=is_profile,
                    )

        if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
            if self.is_async and any(
                eplb_model_state.rebalanced
                for eplb_model_state in self.model_states.values()
            ):
                # Still performing asynchronous rearrangement; update
                # should_record (step >= step_interval, so always True) and
                # bail out before the step counter is reset.
                self._update_layer_should_record(log_stats=log_stats)
                return
            self.expert_rearrangement_step = 0
            self.rearrange()

        self._update_layer_should_record(log_stats=log_stats)

    def _should_record_current_step(self, log_stats: bool = False) -> bool:
        """Return whether expert-load recording should be enabled this step.

        Recording is enabled when we are close to either:
        1) The next rearrangement step, so the sliding window is ready.
        2) The next balancedness logging step, when log_stats is enabled.
        """
        steps_remaining = (
            self.expert_rearrangement_step_interval - self.expert_rearrangement_step
        )
        should_record_for_rearrange = steps_remaining <= self.expert_load_window_size

        if not log_stats:
            return should_record_for_rearrange

        log_interval = self.parallel_config.eplb_config.log_balancedness_interval
        steps_until_next_log = (
            log_interval - (self.expert_rearrangement_step % log_interval)
        ) % log_interval
        should_record_for_log = steps_until_next_log <= self.expert_load_window_size
        return should_record_for_rearrange or should_record_for_log

    def _update_layer_should_record(self, log_stats: bool = False) -> None:
        """Update the shared ``should_record_tensor`` for all layers."""
        if self.should_record_tensor is not None:
            self.should_record_tensor.fill_(
                self._should_record_current_step(log_stats=log_stats)
            )

    def _init_should_record_tensor(self, model: MixtureOfExperts) -> None:
        """Allocate (once) and propagate the shared ``should_record_tensor``.

        Must be called after :meth:`model.set_eplb_state` so that each
        layer's ``eplb_state`` is already populated with the tensor views.

        The tensor is a 0-dim ``torch.bool`` on ``self.device`` initialised to
        ``True``. It is allocated only once a model exposes at least one
        ``EplbLayerState``; later calls reuse the same object, so the layers
        of every registered model share it.

        Args:
            model: Model whose ``moe_layers`` receive the shared tensor.
                Layers without an ``EplbLayerState`` ``eplb_state`` are skipped.
        """
        layer_states = [
            layer.eplb_state
            for layer in model.moe_layers
            if isinstance(getattr(layer, "eplb_state", None), EplbLayerState)
        ]

        if self.should_record_tensor is None and layer_states:
            self.should_record_tensor = torch.ones(
                (), dtype=torch.bool, device=self.device
            )

        for layer_state in layer_states:
            layer_state.should_record_tensor = self.should_record_tensor
    def rearrange(
        self,
        is_profile: bool = False,
706
707
        rank_mapping: dict[int, int] | None = None,
    ) -> torch.Tensor | None:
708
709
        """
        Rearrange the experts according to the current load.
710
711
712
713
714
715
716

        Args:
            is_profile (bool): If `True`, perform a dummy rearrangement.
                This is used in `profile_run` to reserve enough memory,
                no memory movement will be performed. Default is False.
            rank_mapping (dict[int, int] | None): The rank mapping
                when scaling is done in EEP.
717
718
719
720
721
        """

        ep_group = get_ep_group().device_group
        ep_rank = ep_group.rank()

722
723
        start_event = None
        end_event = None
724
725
        is_main_rank = ep_rank == 0
        if is_main_rank:
726
727
728
729
            if not self.is_async or is_profile:
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                start_event.record()
730
731
732
733
734
            logger.info(
                "Rearranging experts %s %s...",
                "(async mode)" if self.is_async else "sync mode",
                "(profile)" if is_profile else "",
            )
735

736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
        # Map the physical expert load to global logical experts
        global_expert_load_windows = []
        for eplb_model_state in self.model_states.values():
            expert_load_window = eplb_model_state.expert_load_window[
                :, :, : self.num_valid_physical_experts
            ]
            logical_expert_load_window = torch.zeros(
                self.expert_load_window_size,
                eplb_model_state.model.num_moe_layers,
                eplb_model_state.model.num_logical_experts,
                dtype=eplb_model_state.expert_load_window.dtype,
                device=eplb_model_state.expert_load_window.device,
            )
            logical_expert_load_window.scatter_add_(
                dim=-1,
                index=eplb_model_state.physical_to_logical_map[
                    :, : self.num_valid_physical_experts
                ]
                .unsqueeze(0)
                .expand_as(expert_load_window)
                .long(),
                src=expert_load_window,
758
            )
759
760
761
762
763

            global_expert_load_window = logical_expert_load_window.sum(dim=0)
            global_expert_load_windows.append(global_expert_load_window)
        # Perform all-reduce to get the expert load across all ranks for each model
        global_expert_load_windows = self._allreduce_list(global_expert_load_windows)
764
765

        # TODO(bowen): Treat differently for prefill and decode nodes
766
767
        eplb_model_state = next(iter(self.model_states.values()))
        model = eplb_model_state.model
768
769
        num_replicas = model.num_physical_experts
        num_groups = model.num_expert_groups
770

771
772
773
774
        if rank_mapping is not None and len(rank_mapping) == ep_group.size():
            # NOTE(yongji): scale down, we need to rebalance the experts on
            # remaining GPUs, transfer the experts while we haven't shutdown
            # the GPUs to be released.
775
776
777
778
            coordinator = get_ep_group()
            assert isinstance(coordinator, StatelessGroupCoordinator)
            tcp_store_group = coordinator.tcp_store_group
            num_nodes = _node_count_with_rank_mapping(tcp_store_group, rank_mapping)
779
780
781
782
            num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values())
            num_replicas = (
                num_replicas // ep_group.size() * num_gpus
            )  # handle num replicas change
783
784
785
        else:
            num_nodes = get_node_count()
            num_gpus = ep_group.size()
786
787

        if num_gpus % num_nodes != 0:
788
            num_nodes = 1
789
790
791
            logger.warning_once(
                f"num_gpus % num_nodes != 0, "
                "not using hierarchical rearrangement algorithm.\n"
792
793
                f"{num_gpus=}, {num_nodes=}"
            )
794

Mercykid-bash's avatar
Mercykid-bash committed
795
        # Get new expert mappings
796
797
798
        for eplb_model_state, global_expert_load_window in zip(
            self.model_states.values(), global_expert_load_windows
        ):
799
            if not self.is_async or is_profile:
800
                # Get new expert mappings for the model
801
802
                new_physical_to_logical_map = self.policy.rebalance_experts(
                    global_expert_load_window.cpu(),
803
804
805
806
                    num_replicas,
                    num_groups,
                    num_nodes,
                    num_gpus,
807
808
809
                    eplb_model_state.physical_to_logical_map.cpu(),
                )

810
811
812
813
814
815
                # Update expert weights
                rearrange_expert_weights_inplace(
                    eplb_model_state.physical_to_logical_map,
                    new_physical_to_logical_map,
                    eplb_model_state.model.expert_weights,
                    ep_group,
816
                    eplb_model_state.communicator,
817
818
819
                    is_profile,
                    rank_mapping,
                )
820

821
                if not is_profile:
822
823
824
                    _commit_eplb_maps(
                        eplb_model_state,
                        new_physical_to_logical_map=new_physical_to_logical_map,
825
                    )
826

827
                if is_main_rank:
828
829
830
831
832
                    assert start_event is not None
                    assert end_event is not None
                    end_event.record()
                    end_event.synchronize()
                    gpu_elapsed = start_event.elapsed_time(end_event) / 1000.0
833
                    logger.info(
834
                        "Rearranged experts %s in %.2f s.",
835
                        " (profile) " if is_profile else " ",
836
                        gpu_elapsed,
837
838
                    )
            else:
839
840
841
842
843
844
845
846
847
                eplb_model_state.eplb_stats = EplbStats(
                    # We copy the tensor to snapshot the global_expert_load_window
                    # on the main thread so that async worker can access it safely
                    # while the main thread is running.
                    global_expert_load_window=global_expert_load_window.clone(),
                    num_replicas=num_replicas,
                    num_groups=num_groups,
                    num_nodes=num_nodes,
                    num_gpus=num_gpus,
848
                )
849
850
851
852
853
                # Record event after clone to signal async worker
                # that load stats data is ready
                sync_event = torch.cuda.Event()
                sync_event.record()
                eplb_model_state.window_ready_event = sync_event
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904

                eplb_model_state.rebalanced = True
                eplb_model_state.layer_to_transfer = 0
                eplb_model_state.pending_global_ready_check = True
        # Signal async thread to start transferring layers
        if self.is_async and (not is_profile):
            self.rearrange_event.set()
        return None

    def start_async_loop(
        self,
        rank_mapping: dict[int, int] | None = None,
        is_profile: bool = False,
    ):
        if not self.is_async:
            return
        if self.async_worker is None:
            self.async_worker = start_async_worker(
                self,
                is_profile=is_profile,
            )

    def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool:
        """
        Return True only when ``ep_buffer_ready`` is set on every EP rank.

        Each rank contributes its local flag as an int32 (0/1) to an
        ``all_reduce`` (default op: sum); the buffers are ready everywhere
        exactly when the reduced value equals the group size.

        The EP group's ``cpu_group`` is used when it exists and spans more than
        one rank. Otherwise the ``device_group`` is used, and a single-rank
        device group short-circuits to the local flag without any collective.
        """
        ep = get_ep_group()
        cpu_group = getattr(ep, "cpu_group", None)

        if cpu_group is not None and cpu_group.size() > 1:
            group = cpu_group
            flag_device = "cpu"
        else:
            group = ep.device_group
            if group.size() <= 1:
                return bool(model_state.ep_buffer_ready)
            flag_device = getattr(
                ep, "device", model_state.physical_to_logical_map.device
            )

        ready = torch.tensor(
            [int(model_state.ep_buffer_ready)], dtype=torch.int32, device=flag_device
        )
        all_reduce(ready, group=group)
        return int(ready.item()) == group.size()

    def move_to_workspace(
        self,
        model_state: EplbModelState,
        ep_group: ProcessGroup,
        is_profile: bool = False,
    ):
        """
        Consume one staged layer of rearranged expert weights on the main thread.

        Under ``model_state.buffer_lock`` this:
          1. copies the staged weights from ``model_state.expert_buffer`` into
             the live weights of layer ``model_state.layer_to_transfer``
             (via ``move_from_buffer``),
          2. commits the new expert maps for that layer,
          3. records ``buffer_consumed_event`` so the async worker knows the
             staging buffer may be overwritten,
          4. advances ``layer_to_transfer`` and clears ``ep_buffer_ready``.
        After the last MoE layer it calls ``post_eplb`` and resets the
        rebalance bookkeeping (``rebalanced``, ``layer_to_transfer``,
        ``pending_global_ready_check``).

        Args:
            model_state: Per-model EPLB state holding the staged buffer and the
                pending ``new_physical_to_logical_map``.
            ep_group: Expert-parallel process group; only ``rank()`` is used.
            is_profile: Not used by this method.

        Raises:
            RuntimeError: if ``buffer_lock`` cannot be acquired after
                ``max_retries`` attempts of 10 s each.
        """
        # We call move_to_workspace only when ep_buffer_ready is 1.
        # It means we only need to wait for the lock for a short time.
        max_retries = 6  # 1 minute max
        retries = 0
        while not model_state.buffer_lock.acquire(blocking=True, timeout=10.0):
            retries += 1
            if retries >= max_retries:
                # Both literals are f-strings so the total wait is interpolated.
                raise RuntimeError(
                    f"Rank {ep_group.rank()}: buffer_lock timeout after "
                    f"{max_retries * 10}s"
                )
            logger.warning(
                "Rank %d: EPLB buffer_lock acquire failed, retrying (%d/%d)",
                ep_group.rank(),
                retries,
                max_retries,
            )

        try:
            # Read the pending map once while holding the lock.
            new_physical_to_logical_map = model_state.new_physical_to_logical_map
            assert new_physical_to_logical_map is not None

            expert_weights = model_state.model.expert_weights[
                model_state.layer_to_transfer
            ]
            expert_weights_buffer = model_state.expert_buffer
            new_indices = new_physical_to_logical_map[
                model_state.layer_to_transfer
            ].numpy()
            move_from_buffer(
                expert_weights=expert_weights,
                expert_weights_buffers=expert_weights_buffer,
                is_unchanged=model_state.is_unchanged,
                is_received_locally=model_state.is_received_locally,
                recv_metadata=model_state.recv_metadata,
                new_indices=new_indices,
                ep_rank=ep_group.rank(),
            )

            transferred_layer = model_state.layer_to_transfer
            _commit_eplb_maps_for_layer(
                model_state,
                new_physical_to_logical_map=new_physical_to_logical_map,
                layer=transferred_layer,
            )

            # Record event after consuming buffer to signal async thread
            # that it's safe to overwrite the intermediate buffer
            consumed_event = torch.cuda.Event()
            consumed_event.record()
            model_state.buffer_consumed_event = consumed_event

            # After the main thread consumes, advance layer_to_transfer
            model_state.layer_to_transfer += 1
            model_state.ep_buffer_ready = 0
            logger.debug(
                "model %s successfully move_to_workspace layer %d",
                model_state.model_name,
                transferred_layer,
            )
            if model_state.layer_to_transfer >= model_state.model.num_moe_layers:
                # All MoE layers committed: drop the pending map and reset the
                # per-rebalance bookkeeping.
                self.post_eplb(model_state)
                model_state.rebalanced = False
                model_state.layer_to_transfer = 0
                model_state.pending_global_ready_check = False
                logger.info(
                    "finish async transfer for model %s rank %d layer %d",
                    model_state.model_name,
                    ep_group.rank(),
                    model_state.model.num_moe_layers,
                )
        finally:
            try:
                model_state.buffer_lock.release()
            except Exception as e:
                logger.error(
                    "Rank %d: buffer_lock release failed in move_to_workspace: %s",
                    ep_group.rank(),
                    str(e),
                )

987
    def post_eplb(self, model_state: EplbModelState) -> None:
988
989
        assert model_state.new_physical_to_logical_map is not None
        model_state.new_physical_to_logical_map = None
990

991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
    def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
        """
        All-reduce a list of tensors across the EP device group.

        - Empty list: returned unchanged and no collective is issued
          (``torch.cat`` would otherwise raise on an empty list).
        - One tensor: reduced in place; the same list object is returned.
        - Several tensors: each must be 2D with a common ``shape[1]``. They are
          concatenated along dim 0 so only one collective is issued, then split
          back into views of the reduced buffer with the original row counts.
        """
        if not tensor_list:
            return tensor_list
        if len(tensor_list) == 1:
            all_reduce(tensor_list[0], group=get_ep_group().device_group)
            return tensor_list
        assert all(t.dim() == 2 for t in tensor_list), "All tensors must be 2D."
        assert all(t.shape[1] == tensor_list[0].shape[1] for t in tensor_list), (
            "All tensors must have the same shape[1]."
        )
        # shape[1] is the per-layer expert axis and depends on the caller: for
        # the load windows reduced during rearrangement it is the number of
        # logical experts.
        row_counts = [t.shape[0] for t in tensor_list]
        concat_tensor = torch.cat(tensor_list, dim=0)

        all_reduce(concat_tensor, group=get_ep_group().device_group)

        # torch.split returns views into concat_tensor, matching the previous
        # row-slicing behavior.
        return list(torch.split(concat_tensor, row_counts, dim=0))

    def _sync_load_pass(self) -> list[torch.Tensor]:
        """
        Return the EP-wide all-reduced ``expert_load_pass`` of every model.

        Each model's counter is cloned before reduction, so the tensors stored
        in ``self.model_states`` are not modified; the result is for log stats.
        """
        snapshots = [
            model_state.expert_load_pass.clone()
            for model_state in self.model_states.values()
        ]
        return self._allreduce_list(snapshots)
1027

1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
    @classmethod
    def from_mapping(
        cls,
        model: MixtureOfExperts,
        model_config: ModelConfig,
        device: torch.device,
        parallel_config: ParallelConfig,
        expanded_physical_to_logical: torch.Tensor,
        num_valid_physical_experts: int,
    ) -> "EplbState":
        """
        Build an ``EplbState`` for one model from an explicit expert placement.

        Registers ``model`` via ``add_model`` and records
        ``num_valid_physical_experts``. It then overwrites the model state's
        physical_to_logical_map, logical_to_physical_map and
        logical_replica_count with values derived from
        ``expanded_physical_to_logical``. The derived logical_to_physical_map is
        right-padded with -1 to the replica width already allocated on the
        model state.
        """
        state = cls(parallel_config=parallel_config, device=device)
        state.add_model(model=model, model_config=model_config)
        state.num_valid_physical_experts = num_valid_physical_experts

        model_state = state.model_states[model_config.compute_hash()]
        model_state.physical_to_logical_map.copy_(expanded_physical_to_logical)

        l2p_cpu, replica_count_cpu = compute_logical_maps(
            expanded_physical_to_logical.cpu(), model.num_logical_experts
        )

        # Fill the unused replica slots with -1 so the tensor matches the
        # preallocated width of the model state's map.
        pad_width = model_state.logical_to_physical_map.shape[-1] - l2p_cpu.shape[-1]
        padded_l2p = torch.nn.functional.pad(
            l2p_cpu, (0, pad_width), value=-1
        ).to(device)
        replica_count = replica_count_cpu.to(device)

        model_state.logical_to_physical_map.copy_(padded_l2p)
        model_state.logical_replica_count.copy_(replica_count)
        return state

1071

1072
1073
1074
1075
1076
1077
1078
@dataclass
class EplbLayerState:
    """Runtime EPLB data stored in the MoE layer."""

    expert_load_view: torch.Tensor | None = None
    logical_to_physical_map: torch.Tensor | None = None
    logical_replica_count: torch.Tensor | None = None
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
    should_record_tensor: torch.Tensor | None = None
    """
    Shared scalar bool tensor controlling whether to accumulate expert load
    metrics during this forward pass.  All layers reference the **same**
    tensor object, which is owned and updated by :class:`EplbState`.

    Set to ``False`` for the first ``step_interval - window_size`` steps of
    each rearrangement period: those steps would be overwritten in the
    sliding window before the next rearrangement, so recording them wastes
    GPU work.
    """
1090
1091


1092
def _node_count_with_rank_mapping(
    pg: ProcessGroup | StatelessProcessGroup,
    rank_mapping: dict[int, int],
) -> int:
    """
    Count the nodes that still host at least one surviving rank.

    A rank whose ``rank_mapping`` entry is -1 is pending shutdown and never
    starts a new node. It may still be absorbed into a node started by a
    surviving peer that ``in_the_same_node_as`` reports as co-located.
    Every rank in ``range(world_size)`` must be a key of ``rank_mapping``.

    NOTE(review): ``in_the_same_node_as`` is presumably a collective over
    ``pg``. The loop below is deterministic, so every rank issues the same
    sequence of calls. Confirm before reordering anything.
    """
    world_size = (
        torch.distributed.get_world_size(group=pg)
        if isinstance(pg, ProcessGroup)
        else pg.world_size
    )
    if world_size == 1:
        return 1

    # node_of[r] is the 1-based node id of rank r, or 0 if not yet assigned.
    node_of = [0] * world_size
    nodes_seen = 0

    for rank in range(world_size):
        if node_of[rank]:
            continue  # Already placed on a node.

        assert rank in rank_mapping
        if rank_mapping[rank] == -1:
            continue  # Pending shutdown.

        nodes_seen += 1
        node_of[rank] = nodes_seen

        # Put every not-yet-placed co-located rank on the same node.
        colocated = in_the_same_node_as(pg, rank)
        for peer, shares_node in enumerate(colocated):
            if shares_node and not node_of[peer]:
                node_of[peer] = nodes_seen

    return nodes_seen
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205


def compute_logical_maps(
    physical_to_logical_map: torch.Tensor,
    num_logical_experts: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Invert a physical -> logical expert map.

    Accepts a 2D map ``[num_layers, num_physical_experts]`` or a single
    layer's 1D map ``[num_physical_experts]``. The input must live on the CPU.
    Slots holding a negative id (scale-up marks new physical experts with -1)
    are ignored.

    Args:
        physical_to_logical_map: logical expert index for each physical slot
        num_logical_experts: total number of logical experts

    Returns:
        logical_to_physical_map: ``[num_layers, num_logical_experts,
            max_replicas]``; the physical slots of each logical expert in
            ascending slot order, padded with -1
        logical_replica_count: ``[num_layers, num_logical_experts]``; number
            of physical replicas per logical expert
        For 1D input, the leading layer dimension is dropped from both.
        Both outputs keep the input's dtype.
    """
    assert physical_to_logical_map.device.type == "cpu"
    device = physical_to_logical_map.device
    dtype = physical_to_logical_map.dtype

    single_layer = physical_to_logical_map.dim() == 1
    p2l = (
        physical_to_logical_map.unsqueeze(0)
        if single_layer
        else physical_to_logical_map
    )
    assert len(p2l.shape) == 2
    num_layers = p2l.shape[0]

    # Count replicas per (layer, logical expert); invalid slots add 0.
    is_valid = p2l >= 0
    replica_count = torch.zeros(
        num_layers, num_logical_experts, dtype=dtype, device=device
    )
    replica_count.scatter_add_(1, p2l.clamp(min=0), is_valid.to(dtype))

    width = int(replica_count.max().item())
    l2p = torch.full(
        (num_layers, num_logical_experts, width), -1, dtype=dtype, device=device
    )

    # Walk physical slots left to right; next_slot tracks how many replicas of
    # each (layer, expert) have been placed so far.
    next_slot = torch.zeros_like(replica_count)
    all_layers = torch.arange(num_layers, device=device)
    for phys_idx, column in enumerate(p2l.t()):
        keep = column >= 0
        if not keep.any():
            continue
        layers = all_layers[keep]
        experts = column[keep]
        slots = next_slot[layers, experts]
        l2p[layers, experts, slots] = phys_idx
        next_slot[layers, experts] += 1

    if single_layer:
        return l2p.squeeze(0), replica_count.squeeze(0)
    return l2p, replica_count
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286


def _pad_out_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
    """
    Copy ``src`` into ``dst``, right-padding src's last dimension with -1.

    ``dst``'s last dimension must be at least as large as ``src``'s.
    """
    extra = dst.shape[-1] - src.shape[-1]
    assert extra >= 0
    dst.copy_(torch.nn.functional.pad(src, (0, extra), value=-1))


def _commit_eplb_maps_for_layer(
    model_state: EplbModelState,
    new_physical_to_logical_map: torch.Tensor,
    layer: int,
) -> None:
    """
    Commit the new expert maps for a single layer (async EPLB path).

    Copies row ``layer`` of ``new_physical_to_logical_map`` into the model
    state. It then recomputes and commits that layer's logical_to_physical_map
    (padded with -1) and logical_replica_count. After this returns, the new
    mappings for ``layer`` are the current ones seen by the model.

    ``new_physical_to_logical_map`` must be a CPU tensor, because
    ``compute_logical_maps`` asserts this. The number of physical experts must
    not change.
    """
    new_p2l = new_physical_to_logical_map[layer]
    cur_p2l = model_state.physical_to_logical_map[layer]
    assert new_p2l.shape == cur_p2l.shape, (
        "The number of physical experts must stay the same while running Async EPLB. "
        f"Current number of physical experts: {cur_p2l.shape[0]}. New number of physical "
        f"experts {new_p2l.shape[0]}."
    )
    cur_p2l.copy_(new_p2l, non_blocking=True)

    new_l2p, new_counts = compute_logical_maps(
        new_p2l, model_state.logical_to_physical_map.shape[1]
    )
    _pad_out_tensor(src=new_l2p, dst=model_state.logical_to_physical_map[layer])

    cur_counts = model_state.logical_replica_count[layer]
    assert new_counts.shape == cur_counts.shape
    cur_counts.copy_(new_counts, non_blocking=True)


def _commit_eplb_maps(
    model_state: EplbModelState,
    new_physical_to_logical_map: torch.Tensor,
) -> None:
    """
    Commit new expert maps for all layers into ``model_state``.

    Installs ``new_physical_to_logical_map``, then recomputes and commits the
    logical_to_physical_map (padded with -1) and logical_replica_count. After
    this returns, the new mappings are the current ones seen by the model.
    ``new_physical_to_logical_map`` must be a CPU tensor, because
    ``compute_logical_maps`` asserts this.
    """
    current = model_state.physical_to_logical_map
    if new_physical_to_logical_map.shape[1] == current.shape[1]:
        current.copy_(new_physical_to_logical_map, non_blocking=True)
    else:
        # Rare: the physical expert count changed (the number of GPUs available
        # to vLLM changed at runtime), so the old tensor cannot hold the new
        # map. Replace it instead of copying in place.
        model_state.physical_to_logical_map = new_physical_to_logical_map.to(
            current.device
        )

    new_l2p, new_counts = compute_logical_maps(
        new_physical_to_logical_map, model_state.logical_to_physical_map.shape[1]
    )
    _pad_out_tensor(src=new_l2p, dst=model_state.logical_to_physical_map)
    model_state.logical_replica_count.copy_(new_counts, non_blocking=True)