eplb_state.py 43.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
"""
Expert parallelism load balancer (EPLB) metrics and states.

# Glossary

- **Logical Expert**: An expert that is part of the model's logical structure.
  It holds a set of weights and is replicated across multiple physical
  experts.
- **Redundant Expert**: To achieve load balancing, for some popular logical
  experts, we create additional copies of the expert weights. During inference,
  tokens routed to such a logical expert may be served by any of its copies.
- **Physical Expert**: An expert that is instantiated on a specific device.
  It is a replica of a logical expert and can be rearranged across devices.
  I.e., one logical expert may have multiple sets of weights initialized on
  different devices, and each of these sets is a physical expert.
- **Local Physical Expert**: A physical expert that is instantiated on the
  current device.

For example: DeepSeek-R1 has 256 logical experts, so each MoE layer
has 256 sets of linear layer weights in the model parameters. If we add 32
redundant experts, DeepSeek-R1 will have 256 + 32 = 288 physical experts in
total. And when deploying, we'll have 288 sets of linear layer weights for each
MoE layer. If we have 32 EP ranks, then each GPU will hold 288 / 32 = 9 local
physical experts.
"""

29
import threading
30
31
32
from collections.abc import Sequence
from dataclasses import dataclass

33
import numpy as np
34
import torch
35
from torch.distributed import ProcessGroup, all_reduce
36

37
from vllm.config import ModelConfig, ParallelConfig
38
39
40
41
42
from vllm.distributed.parallel_state import (
    get_ep_group,
    get_node_count,
    in_the_same_node_as,
)
43
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
44
from vllm.distributed.utils import StatelessProcessGroup
45
46
47
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts

48
from .async_worker import start_async_worker
Mercykid-bash's avatar
Mercykid-bash committed
49
from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy
50
51
52
53
54
from .rebalance_execute import (
    RecvMetadata,
    move_from_buffer,
    rearrange_expert_weights_inplace,
)
55
56
57
58

logger = init_logger(__name__)


59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@dataclass
class EplbStats:
    """
    Model stats used in EPLB rebalancing algorithm.

    A snapshot of the inputs the rebalancing policy needs. In async mode it
    is captured on the main thread, so the async worker can run the policy
    without reading tensors the main thread keeps mutating.
    """

    global_expert_load_window: torch.Tensor
    """
    Expert load summed over the sliding window and all-reduced across EP
    ranks, indexed by logical expert (physical-expert load has already been
    scattered onto the logical expert each physical expert maps to).
    Shape: (num_moe_layers, num_logical_experts)
    """
    num_replicas: int
    """
    Number of physical experts to place. During an elastic-EP scale-down this
    is rescaled to the number of GPUs that remain.
    """
    num_groups: int
    """
    Number of expert groups.
    """
    num_nodes: int
    """
    Number of nodes. Forced to 1 (disabling hierarchical balancing) when the
    GPU count is not divisible by the node count.
    """
    num_gpus: int
    """
    Number of GPUs (EP ranks) taking part in the placement.
    """


88
@dataclass
class EplbModelState:
    """
    EPLB metrics and state for a single registered model.

    Holds three kinds of data:

    - the expert placement maps, which are shared with the model;
    - the expert load statistics used to choose the next placement;
    - the intermediate state of an in-flight asynchronous rearrangement.
    """

    physical_to_logical_map: torch.Tensor
    """
    Mapping from physical experts to logical experts.

    Shape: (num_moe_layers, num_physical_experts)

    # Example

    For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
    EP ranks, the mapping could look like this:

    ```
    [[0, 1, 2, 3, 0, 1],
     [0, 2, 0, 1, 0, 3]]
    ```
    """
    logical_to_physical_map: torch.Tensor
    """
    Mapping from logical experts to physical experts.

    This is a sparse matrix, where -1 indicates no mapping. Each logical
    expert owns a fixed number of slots, chosen when the model is added. That
    number is an upper bound on its replicas, not the exact replica count.
    Unused slots are padded with -1.

    Shape: (num_moe_layers, num_logical_experts, max_slots_per_logical_expert)

    # Example

    For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
    EP ranks, the mapping could look like this (shown with only 3 slots per
    logical expert for brevity):

    ```
    [[[0, 4, -1],
      [1, 5, -1],
      [2, -1, -1],
      [3, -1, -1]],
     [[0, 2, 4],
      [3, -1, -1],
      [1, -1, -1],
      [5, -1, -1]]]
    ```
    """
    logical_replica_count: torch.Tensor
    """
    Number of replicas for each logical expert.
    This is exactly the non-`-1` count in the `logical_to_physical_map`.

    Shape: (num_moe_layers, num_logical_experts)

    # Example

    For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
    EP ranks, the count could look like this:

    ```
    [[2, 2, 1, 1],
     [3, 1, 1, 1]]
    ```
    """

    expert_load_pass: torch.Tensor
    """
    Expert load during this forward pass.
    We use the token count each expert processes as the load.

    Shape: (num_moe_layers, num_physical_experts)
    """
    expert_load_window: torch.Tensor
    """
    A sliding window of expert load.

    Shape: (window_size, num_moe_layers, num_physical_experts)

    NOTE: The expert_load_view now records load for all physical experts
    rather than just local experts. This ensures consistent load statistics
    across different dispatch methods (naive all-to-all, DeepEP).
    The recorded load will be multiplied by dp_size when using naive all-to-all
    due to each DP rank contributing the same token set to the calculation.
    See:
    https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
    """
    model_name: str
    """
    Name of the model, used in log messages.
    """
    model: MixtureOfExperts
    """
    The MoE model whose experts are being balanced.
    """
    expert_buffer: list[torch.Tensor]
    """
    The buffer to store the expert weights during transfer.
    It holds one tensor per expert-weight tensor of a single MoE layer.
    """
    buffer_lock: threading.Lock
    """
    The lock to protect the expert buffer.
    """
    buffer_consumed_event: torch.cuda.Event | None
    """
    CUDA event recorded after the main thread finishes consuming the buffer.
    The async worker waits on this before writing to the buffer again.
    """
    window_ready_event: torch.cuda.Event | None
    """
    CUDA event recorded after all-reduce and clone on the main thread.
    The async worker waits on this before accessing global_expert_load_window.
    """
    ep_buffer_ready: int
    """
    The flag indicates whether the expert buffer is ready for transfer.
    0 or 1.
    """
    layer_to_transfer: int
    """
    The layer index to transfer in async mode.
    """
    rebalanced: bool
    """
    The flag indicates whether the experts rebalance have been computed.
    """
    pending_global_ready_check: bool
    """
    Whether the async EPLB needs to poll peers for buffer readiness.
    """
    eplb_stats: EplbStats | None
    """
    EPLB stats snapshot handed to the async worker; `None` until the first
    async rearrangement is requested.
    """
    is_unchanged: np.ndarray
    """
    Intermediate result passed from `move_to_buffer` to `move_to_workspace`.
    One entry per physical expert of the layer being transferred.
    """
    is_received_locally: np.ndarray
    """
    Intermediate result passed from `move_to_buffer` to `move_to_workspace`.
    One entry per physical expert of the layer being transferred.
    """
    recv_metadata: RecvMetadata
    """
    Receive-side metadata passed from `move_to_buffer` to `move_to_workspace`.
    """
    cuda_device_index: int | None
    """
    CUDA device index for the async EPLB worker thread.
    """
    new_physical_to_logical_map: torch.Tensor | None = None
    """
    Intermediate result passed from `move_to_buffer` to `move_to_workspace`.
    Same shape as `physical_to_logical_map`.
    """
233

234
235

class EplbState:
236
    """
237
    EPLB state for the expert-parallel models on this rank. Per-model state
    is stored in `model_states`, keyed by the model config hash.
238
239
    """

240
241
242
243
    def __init__(self, parallel_config: ParallelConfig, device: torch.device):
        """
        Create an empty EPLB state. Models are registered with `add_model`.

        Args:
            parallel_config: Parallel config; its `eplb_config` supplies the
                EPLB settings.
            device: Device on which EPLB tensors are allocated. For a CUDA
                device without an explicit index, the current accelerator
                device index is recorded for the async worker thread.
        """
        self.parallel_config = parallel_config
        self.device = device
        # Per-model state, keyed by the model config hash.
        self.model_states: dict[str, EplbModelState] = {}
        # Selected EPLB algorithm class; replaced by the configured policy
        # when a model is added.
        self.policy: type[AbstractEplbPolicy] = DefaultEplbPolicy

        # Current slot in the sliding load window. Unlike
        # `expert_rearrangement_step`, this may differ between EP ranks.
        self.expert_load_window_step: int = 0
        # Length of the sliding load window; a constant taken from the config.
        self.expert_load_window_size: int = 0
        # Steps since the last rearrangement; a rearrangement is triggered
        # once it reaches the interval. Every EP rank must hold the same value,
        # otherwise the rearrangement hangs in its collective calls.
        self.expert_rearrangement_step: int = 0
        # Steps between rearrangements; a constant taken from the config.
        self.expert_rearrangement_step_interval: int = 0

        # Whether the EPLB is running in async mode.
        self.is_async: bool = False
        # Signals the async thread that a new rearrangement is needed.
        self.rearrange_event = threading.Event()
        # Background thread handling async transfers.
        self.async_worker: threading.Thread | None = None

        # Number of physical experts actually mapped to logical experts. In
        # elastic EP, newly started EP ranks may not have any mapped yet.
        self.num_valid_physical_experts: int = 0

        # CUDA device index for the async EPLB worker thread (None off-CUDA).
        resolved_index: int | None = None
        if device.type == "cuda":
            resolved_index = device.index
            if resolved_index is None and torch.cuda.is_available():
                resolved_index = torch.accelerator.current_device_index()
        self.cuda_device_index: int | None = resolved_index
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323

    @staticmethod
    def build_initial_global_physical_to_logical_map(
        num_routed_experts: int,
        num_redundant_experts: int,
    ) -> Sequence[int]:
        """
        Build an initial expert arrangement using the following structure:
        [original routed experts, redundant experts]

        Returns:
            physical_to_logical_map (Sequence[int]): A list of integers,
                where each integer is the index of the logical expert
                that the corresponding physical expert maps to.
        """
        global_physical_to_logical_map = list(range(num_routed_experts))
        global_physical_to_logical_map += [
            i % num_routed_experts for i in range(num_redundant_experts)
        ]
        return global_physical_to_logical_map

324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
    def validate_ep_configuration(self, new_model: MixtureOfExperts):
        """
        Validate that the expert parallel configuration of
        the new model is the same as the existing models.
        """
        if len(self.model_states) > 0:
            model = next(iter(self.model_states.values())).model
            if (
                model.num_routed_experts != new_model.num_routed_experts
                or model.num_redundant_experts != new_model.num_redundant_experts
                or model.num_physical_experts != new_model.num_physical_experts
                or model.num_logical_experts != new_model.num_logical_experts
                or model.num_expert_groups != new_model.num_expert_groups
            ):
                raise RuntimeError(
                    "Model: {} "
                    "with config {} "
                    "{} {} {} {} "
                    "mismatch with new model {} "
                    "with config {} "
                    "{} {} {} {}".format(
                        type(model),
                        model.num_routed_experts,
                        model.num_redundant_experts,
                        model.num_physical_experts,
                        model.num_logical_experts,
                        model.num_expert_groups,
                        type(new_model),
                        new_model.num_routed_experts,
                        new_model.num_redundant_experts,
                        new_model.num_physical_experts,
                        new_model.num_logical_experts,
                        new_model.num_expert_groups,
                    )
                )

    def add_model(
        self,
        model: MixtureOfExperts,
        model_config: ModelConfig,
    ):
        """
        Build the initial EPLB state for ``model`` and register it.

        Every MoE layer starts from the same placement: routed experts in
        order, followed by redundant experts assigned round-robin (see
        `build_initial_global_physical_to_logical_map`). The load-pass tensor
        and logical-to-physical tables are handed to the model through
        `model.set_eplb_state`. The same tensor objects are kept in a new
        `EplbModelState`, stored under `model_config.compute_hash()`.

        This also (re)initializes the shared counters:

        - the load window size;
        - the rearrangement interval;
        - the rearrangement step, which starts three quarters of the way
          through the interval.

        Args:
            model: MoE model to register; its expert configuration must match
                any model registered before it.
            model_config: Config of ``model``; provides the model name and the
                registry key.

        Raises:
            RuntimeError: From `validate_ep_configuration` on a mismatch.
            AssertionError: If ``model.num_redundant_experts`` exceeds the
                supported redundancy.
        """
        self.validate_ep_configuration(model)
        eplb_config = self.parallel_config.eplb_config
        self.is_async = eplb_config.use_async

        num_layers = model.num_moe_layers

        # Single-layer template of the initial physical -> logical placement.
        phys_to_log = torch.tensor(
            EplbState.build_initial_global_physical_to_logical_map(
                model.num_routed_experts,
                model.num_redundant_experts,
            ),
            device=self.device,
        )

        # Assuming 8 GPUs per node, this supports up to
        # (1023 + 1) / 8 = 128 nodes for now.
        # TODO(rui): make this configurable
        MAX_EXPERT_REDUNDANCY = 1023
        assert model.num_redundant_experts <= MAX_EXPERT_REDUNDANCY, (
            f"num_redundant_experts {model.num_redundant_experts} "
            f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}"
        )
        max_slots_per_logical_expert = MAX_EXPERT_REDUNDANCY + 1

        # Invert the template: slot k of row e holds the k-th physical replica
        # of logical expert e; unused slots stay -1.
        log_to_phys = torch.full(
            (model.num_logical_experts, max_slots_per_logical_expert),
            -1,
            device=self.device,
        )
        replica_count = torch.zeros(
            (model.num_logical_experts,),
            device=self.device,
            dtype=torch.long,
        )
        for phys_idx in range(model.num_physical_experts):
            log_idx = phys_to_log[phys_idx]
            log_to_phys[log_idx, replica_count[log_idx]] = phys_idx
            replica_count[log_idx] += 1

        def per_layer(template: torch.Tensor) -> torch.Tensor:
            # Give every MoE layer its own contiguous copy of the template.
            trailing_dims = (-1,) * template.dim()
            return template.unsqueeze(0).expand(num_layers, *trailing_dims).contiguous()

        physical_to_logical_map = per_layer(phys_to_log)
        logical_to_physical_map = per_layer(log_to_phys)
        logical_replica_count = per_layer(replica_count)

        expert_load_pass = torch.zeros(
            (num_layers, model.num_physical_experts),
            dtype=torch.int32,
            device=self.device,
        )
        self.expert_load_window_size = eplb_config.window_size
        expert_load_window = torch.zeros(
            (self.expert_load_window_size, num_layers, model.num_physical_experts),
            dtype=torch.int32,
            device=self.device,
        )

        # Start the rearrangement counter at 3/4 of the interval, so the first
        # rearrangement happens after only a quarter of an interval.
        interval = eplb_config.step_interval
        self.expert_rearrangement_step = max(0, interval - interval // 4)
        self.expert_rearrangement_step_interval = interval

        policy_type = eplb_config.policy
        self.policy = EPLB_POLICIES[policy_type]
        logger.debug("Selected EPLB policy: %s", policy_type)

        model.set_eplb_state(
            expert_load_pass,
            logical_to_physical_map,
            logical_replica_count,
        )

        # Transfer staging buffers, shaped like the first layer's expert weights.
        expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]

        self.model_states[model_config.compute_hash()] = EplbModelState(
            physical_to_logical_map=physical_to_logical_map,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
            expert_load_pass=expert_load_pass,
            expert_load_window=expert_load_window,
            model_name=model_config.model,
            model=model,
            expert_buffer=expert_buffer,
            buffer_lock=threading.Lock(),
            buffer_consumed_event=None,
            window_ready_event=None,
            ep_buffer_ready=0,
            layer_to_transfer=0,
            rebalanced=False,
            pending_global_ready_check=False,
            eplb_stats=None,
            is_unchanged=np.array([]),
            is_received_locally=np.array([]),
            recv_metadata=RecvMetadata(
                recv_primary_mask=np.array([]),
                recv_count=0,
                recv_expert_ids=np.array([]),
                recv_dst_rows=np.array([]),
            ),
            cuda_device_index=self.cuda_device_index,
            new_physical_to_logical_map=None,
        )
        self.num_valid_physical_experts = model.num_physical_experts
498

499
500
501
502
503
504
    def step(
        self,
        is_dummy: bool = False,
        is_profile: bool = False,
        log_stats: bool = False,
    ) -> None:
        """
        Step the EPLB state.

        Each call does the following:

        1. Consumes the load accumulated in `expert_load_pass` since the
           previous call and stores it in the sliding window.
        2. Optionally logs load-balance metrics.
        3. Finishes async transfers whose buffers are ready on every rank.
        4. Triggers a rearrangement once `expert_rearrangement_step` reaches
           `expert_rearrangement_step_interval`.

        Args:
            is_dummy (bool): If `True`, this is a dummy step and the load
                metrics recorded in this forward pass will not count.
                Defaults to `False`.
            is_profile (bool): If `True`, perform a dummy rearrangement
                with maximum communication cost. This is used in
                `profile_run` to reserve enough memory
                for the communication buffer.
            log_stats (bool): If `True`, log the expert load metrics.

        # Stats
            Each metric is computed per layer across EP ranks and then summed
            over layers.
            - `avg_tokens`: The average load across ranks.
            - `max_tokens`: The maximum load across ranks.
            - `balancedness`: The ratio of average load to maximum load
              (1.0 means perfectly balanced; 0.0 if there was no load).
        """
        ep_group = get_ep_group().device_group
        if is_profile:
            self.rearrange(is_profile=True)
            return

        if is_dummy:
            # Do not record load metrics for dummy steps
            for eplb_model_state in self.model_states.values():
                eplb_model_state.expert_load_pass.zero_()

        if (
            log_stats
            and self.expert_rearrangement_step
            % self.parallel_config.eplb_config.log_balancedness_interval
            == 0
        ):
            # Sync the expert load pass for each model (main and drafter).
            # expert_load_pass: (num_moe_layers, num_physical_experts)
            expert_load_pass_list = self._sync_load_pass()
            for expert_load_pass, eplb_model_state in zip(
                expert_load_pass_list, self.model_states.values()
            ):
                # num_tokens_per_rank: (num_moe_layers, num_ranks)
                num_tokens_per_rank = (
                    expert_load_pass.reshape(
                        expert_load_pass.shape[0], ep_group.size(), -1
                    )
                    .sum(dim=-1)
                    .float()
                )

                # Compute balancedness ratio:
                # for each layer:
                #   (mean load across ranks) / (max load across ranks)
                # Reduce across ranks (dim=1) first, then sum the per-layer
                # values over layers (dim=0). Reducing dim=0 first would mix
                # layers and report imbalance even when every layer is
                # perfectly balanced across ranks but layers differ in load.
                avg_tokens_tensor = num_tokens_per_rank.mean(dim=1).sum(dim=0)
                max_tokens_tensor = num_tokens_per_rank.max(dim=1).values.sum(dim=0)

                # Stack so both scalars come to the host in one transfer.
                tokens_tensors: list[float] = torch.stack(
                    [avg_tokens_tensor, max_tokens_tensor]
                ).tolist()
                avg_tokens, max_tokens = tokens_tensors
                balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0

                if ep_group.rank() == 0:
                    logger.info(
                        "EPLB step: %d for model %s: avg_tokens=%.2f, "
                        "max_tokens=%d, balancedness=%.4f, "
                        "steps until the next rearrangement: %d",
                        self.expert_rearrangement_step,
                        eplb_model_state.model_name,
                        avg_tokens,
                        max_tokens,
                        balancedness,
                        self.expert_rearrangement_step_interval
                        - self.expert_rearrangement_step,
                    )

        # Update the expert load sliding window
        if not is_dummy:
            for eplb_model_state in self.model_states.values():
                # Copy this pass's load into the current window slot in place
                # (no temporary clone), then reset the accumulator.
                window = eplb_model_state.expert_load_window
                window[self.expert_load_window_step].copy_(
                    eplb_model_state.expert_load_pass
                )
                eplb_model_state.expert_load_pass.zero_()

            self.expert_load_window_step += 1
            if self.expert_load_window_step >= self.expert_load_window_size:
                self.expert_load_window_step = 0

        # Step the expert rearrangement step
        # Note that even if this is a dummy step, we still increment the
        # rearrangement step and perform rearrangement to ensure all ranks are
        # performing collective communication.
        self.expert_rearrangement_step += 1

        if self.is_async:
            for eplb_model_state in self.model_states.values():
                # Only poll peers (a collective) when a transfer is pending.
                all_ranks_buffer_ready = False
                if eplb_model_state.pending_global_ready_check:
                    all_ranks_buffer_ready = self._all_ranks_buffer_ready(
                        eplb_model_state
                    )
                if eplb_model_state.ep_buffer_ready and all_ranks_buffer_ready:
                    self.move_to_workspace(
                        model_state=eplb_model_state,
                        ep_group=ep_group,
                        is_profile=is_profile,
                    )

        if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
            if self.is_async and any(
                eplb_model_state.rebalanced
                for eplb_model_state in self.model_states.values()
            ):
                # Still performing asynchronous rearrangement
                return
            self.expert_rearrangement_step = 0
            self.rearrange()
624

625
626
627
    def rearrange(
        self,
        is_profile: bool = False,
628
629
        rank_mapping: dict[int, int] | None = None,
    ) -> torch.Tensor | None:
630
631
        """
        Rearrange the experts according to the current load.
632
633
634
635
636
637
638

        Args:
            is_profile (bool): If `True`, perform a dummy rearrangement.
                This is used in `profile_run` to reserve enough memory,
                no memory movement will be performed. Default is False.
            rank_mapping (dict[int, int] | None): The rank mapping
                when scaling is done in EEP.
639
640
641
642
643
        """

        ep_group = get_ep_group().device_group
        ep_rank = ep_group.rank()

644
645
        start_event = None
        end_event = None
646
647
        is_main_rank = ep_rank == 0
        if is_main_rank:
648
649
650
651
            if not self.is_async or is_profile:
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                start_event.record()
652
653
654
655
656
            logger.info(
                "Rearranging experts %s %s...",
                "(async mode)" if self.is_async else "sync mode",
                "(profile)" if is_profile else "",
            )
657

658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
        # Map the physical expert load to global logical experts
        global_expert_load_windows = []
        for eplb_model_state in self.model_states.values():
            expert_load_window = eplb_model_state.expert_load_window[
                :, :, : self.num_valid_physical_experts
            ]
            logical_expert_load_window = torch.zeros(
                self.expert_load_window_size,
                eplb_model_state.model.num_moe_layers,
                eplb_model_state.model.num_logical_experts,
                dtype=eplb_model_state.expert_load_window.dtype,
                device=eplb_model_state.expert_load_window.device,
            )
            logical_expert_load_window.scatter_add_(
                dim=-1,
                index=eplb_model_state.physical_to_logical_map[
                    :, : self.num_valid_physical_experts
                ]
                .unsqueeze(0)
                .expand_as(expert_load_window)
                .long(),
                src=expert_load_window,
680
            )
681
682
683
684
685

            global_expert_load_window = logical_expert_load_window.sum(dim=0)
            global_expert_load_windows.append(global_expert_load_window)
        # Perform all-reduce to get the expert load across all ranks for each model
        global_expert_load_windows = self._allreduce_list(global_expert_load_windows)
686
687

        # TODO(bowen): Treat differently for prefill and decode nodes
688
689
        eplb_model_state = next(iter(self.model_states.values()))
        model = eplb_model_state.model
690
691
        num_replicas = model.num_physical_experts
        num_groups = model.num_expert_groups
692

693
694
695
696
        if rank_mapping is not None and len(rank_mapping) == ep_group.size():
            # NOTE(yongji): scale down, we need to rebalance the experts on
            # remaining GPUs, transfer the experts while we haven't shutdown
            # the GPUs to be released.
697
698
699
700
            coordinator = get_ep_group()
            assert isinstance(coordinator, StatelessGroupCoordinator)
            tcp_store_group = coordinator.tcp_store_group
            num_nodes = _node_count_with_rank_mapping(tcp_store_group, rank_mapping)
701
702
703
704
            num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values())
            num_replicas = (
                num_replicas // ep_group.size() * num_gpus
            )  # handle num replicas change
705
706
707
        else:
            num_nodes = get_node_count()
            num_gpus = ep_group.size()
708
709

        if num_gpus % num_nodes != 0:
710
            num_nodes = 1
711
712
713
            logger.warning_once(
                f"num_gpus % num_nodes != 0, "
                "not using hierarchical rearrangement algorithm.\n"
714
715
                f"{num_gpus=}, {num_nodes=}"
            )
716

Mercykid-bash's avatar
Mercykid-bash committed
717
        # Get new expert mappings
718
719
720
        for eplb_model_state, global_expert_load_window in zip(
            self.model_states.values(), global_expert_load_windows
        ):
721
            if not self.is_async or is_profile:
722
                # Get new expert mappings for the model
723
724
                new_physical_to_logical_map = self.policy.rebalance_experts(
                    global_expert_load_window.cpu(),
725
726
727
728
                    num_replicas,
                    num_groups,
                    num_nodes,
                    num_gpus,
729
730
731
732
733
734
735
736
                    eplb_model_state.physical_to_logical_map.cpu(),
                )

                num_logical_experts = global_expert_load_window.shape[-1]
                (new_logical_to_physical_map, new_logical_replica_count) = (
                    compute_logical_maps(
                        new_physical_to_logical_map, num_logical_experts
                    )
737
738
                )

739
740
741
742
743
744
745
746
747
                # Update expert weights
                rearrange_expert_weights_inplace(
                    eplb_model_state.physical_to_logical_map,
                    new_physical_to_logical_map,
                    eplb_model_state.model.expert_weights,
                    ep_group,
                    is_profile,
                    rank_mapping,
                )
748

749
750
751
752
753
754
755
756
757
                if not is_profile:
                    if (
                        eplb_model_state.physical_to_logical_map.shape[1]
                        != new_physical_to_logical_map.shape[1]
                    ):
                        eplb_model_state.physical_to_logical_map = (
                            new_physical_to_logical_map.to(
                                eplb_model_state.physical_to_logical_map.device
                            )
758
                        )
759
760
761
762
763
764
765
766
                    else:
                        eplb_model_state.physical_to_logical_map.copy_(
                            new_physical_to_logical_map
                        )
                    max_physical_slots = new_logical_to_physical_map.shape[-1]
                    assert (
                        max_physical_slots
                        <= eplb_model_state.logical_to_physical_map.shape[-1]
767
                    )
768
769
770
771
772
773
774
775
                    new_logical_to_physical_map = torch.nn.functional.pad(
                        new_logical_to_physical_map,
                        (
                            0,
                            eplb_model_state.logical_to_physical_map.shape[-1]
                            - max_physical_slots,
                        ),
                        value=-1,
776
                    )
777
778
779
780
781
782
783
                    eplb_model_state.logical_to_physical_map.copy_(
                        new_logical_to_physical_map
                    )
                    eplb_model_state.logical_replica_count.copy_(
                        new_logical_replica_count
                    )
                if is_main_rank:
784
785
786
787
788
                    assert start_event is not None
                    assert end_event is not None
                    end_event.record()
                    end_event.synchronize()
                    gpu_elapsed = start_event.elapsed_time(end_event) / 1000.0
789
                    logger.info(
790
                        "Rearranged experts %s in %.2f s.",
791
                        " (profile) " if is_profile else " ",
792
                        gpu_elapsed,
793
794
                    )
            else:
795
796
797
798
799
800
801
802
803
                eplb_model_state.eplb_stats = EplbStats(
                    # We copy the tensor to snapshot the global_expert_load_window
                    # on the main thread so that async worker can access it safely
                    # while the main thread is running.
                    global_expert_load_window=global_expert_load_window.clone(),
                    num_replicas=num_replicas,
                    num_groups=num_groups,
                    num_nodes=num_nodes,
                    num_gpus=num_gpus,
804
                )
805
806
807
808
809
                # Record event after clone to signal async worker
                # that load stats data is ready
                sync_event = torch.cuda.Event()
                sync_event.record()
                eplb_model_state.window_ready_event = sync_event
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834

                eplb_model_state.rebalanced = True
                eplb_model_state.layer_to_transfer = 0
                eplb_model_state.pending_global_ready_check = True
        # Signal async thread to start transferring layers
        if self.is_async and (not is_profile):
            self.rearrange_event.set()
        return None

    def start_async_loop(
        self,
        rank_mapping: dict[int, int] | None = None,
        is_profile: bool = False,
    ):
        if not self.is_async:
            return
        if self.async_worker is None:
            self.async_worker = start_async_worker(
                self,
                is_profile=is_profile,
            )

    def _update_layer_mapping_from_new(
        self, model_state: EplbModelState, layer: int
    ) -> None:
        """Install the staged expert mapping of one MoE layer into *model_state*.

        Reads ``model_state.new_physical_to_logical_map`` (no-op when it is
        ``None``) and updates, for *layer*:

        - ``physical_to_logical_map``: copied per layer when the number of
          physical experts is unchanged. Otherwise the whole map object is
          replaced.
        - ``logical_to_physical_map``: recomputed with ``compute_logical_maps``
          and right-padded with -1 to the existing replica-slot width.
        - ``logical_replica_count``: recomputed alongside it.

        ``compute_logical_maps`` asserts a CPU input, so the staged map is
        expected to live on the CPU.
        """
        staged = model_state.new_physical_to_logical_map
        if staged is None:
            return

        current = model_state.physical_to_logical_map
        if current.shape[1] == staged.shape[1]:
            current[layer].copy_(staged[layer].to(current.device, non_blocking=True))
        else:
            # The number of physical experts changed: the new map is moved over
            # synchronously (whole tensor) to avoid a race condition with the
            # async worker.
            model_state.physical_to_logical_map = staged.to(current.device)

        l2p = model_state.logical_to_physical_map
        layer_l2p, layer_counts = compute_logical_maps(staged[layer], l2p.shape[1])
        missing_slots = l2p.shape[-1] - layer_l2p.shape[-1]
        if missing_slots > 0:
            layer_l2p = torch.nn.functional.pad(
                layer_l2p, (0, missing_slots), value=-1
            )
        l2p[layer].copy_(layer_l2p.to(l2p.device))

        counts = model_state.logical_replica_count
        counts[layer].copy_(layer_counts.to(counts.device))

    def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool:
        """Report whether ``ep_buffer_ready`` is set on every rank of the EP group.

        Each rank contributes ``int(model_state.ep_buffer_ready)`` to an
        ``all_reduce`` (default op: sum), and the total is compared with the
        group size.

        Group selection:
        - The CPU group is used when it exists and has more than one member.
        - Otherwise the device group is used.
        - A device group of size <= 1 returns the local flag without any
          communication.
        """
        ep = get_ep_group()

        def _every_rank_ready(group, device) -> bool:
            votes = torch.tensor(
                (int(model_state.ep_buffer_ready),), dtype=torch.int32, device=device
            )
            all_reduce(votes, group=group)
            return int(votes.item()) == group.size()

        cpu_group = getattr(ep, "cpu_group", None)
        if cpu_group is not None and cpu_group.size() > 1:
            return _every_rank_ready(cpu_group, "cpu")

        device_group = ep.device_group
        if device_group.size() <= 1:
            return bool(model_state.ep_buffer_ready)

        fallback_device = model_state.physical_to_logical_map.device
        return _every_rank_ready(device_group, getattr(ep, "device", fallback_device))

    def move_to_workspace(
        self,
        model_state: EplbModelState,
        ep_group: ProcessGroup,
        is_profile: bool = False,
    ):
        """Apply the buffered expert weights for the next layer of *model_state*.

        Steps:
        1. Acquire ``model_state.buffer_lock``.
        2. Move the weights staged in ``model_state.expert_buffer`` into the
           expert weights of layer ``model_state.layer_to_transfer`` via
           ``move_from_buffer``.
        3. Record a CUDA event marking the buffer as consumed.
        4. Update that layer's expert mappings and advance
           ``layer_to_transfer``.

        Once the last MoE layer has been transferred, the staged mapping is
        cleared (``post_eplb``) and the rebalance bookkeeping is reset.

        Args:
            model_state: per-model EPLB state holding the staged transfer.
            ep_group: process group whose rank is used for logging and passed
                to ``move_from_buffer`` as ``ep_rank``.
            is_profile: accepted for interface compatibility; unused here.

        Raises:
            RuntimeError: if ``buffer_lock`` cannot be acquired within
                ``max_retries`` attempts of ``lock_timeout_s`` seconds each.
        """
        # We call move_to_workspace only when ep_buffer_ready is 1.
        # It means we only need to wait for the lock for a short time.
        lock_timeout_s = 10.0
        max_retries = 6  # 1 minute max
        retries = 0
        while not model_state.buffer_lock.acquire(
            blocking=True, timeout=lock_timeout_s
        ):
            retries += 1
            if retries >= max_retries:
                # Both pieces must be f-strings: in implicit string
                # concatenation a plain literal is not interpolated.
                raise RuntimeError(
                    f"Rank {ep_group.rank()}: buffer_lock timeout after "
                    f"{max_retries * lock_timeout_s:g}s"
                )
            logger.warning(
                "Rank %d: EPLB buffer_lock acquire failed, retrying (%d/%d)",
                ep_group.rank(),
                retries,
                max_retries,
            )

        # The lock is held from here on; it is released in ``finally``.
        try:
            assert model_state.new_physical_to_logical_map is not None
            expert_weights = model_state.model.expert_weights[
                model_state.layer_to_transfer
            ]
            expert_weights_buffer = model_state.expert_buffer
            new_indices = model_state.new_physical_to_logical_map[
                model_state.layer_to_transfer
            ].numpy()
            move_from_buffer(
                expert_weights=expert_weights,
                expert_weights_buffers=expert_weights_buffer,
                is_unchanged=model_state.is_unchanged,
                is_received_locally=model_state.is_received_locally,
                recv_metadata=model_state.recv_metadata,
                new_indices=new_indices,
                ep_rank=ep_group.rank(),
            )
            # Record event after consuming buffer to signal async thread
            # that it's safe to overwrite the intermediate buffer
            consumed_event = torch.cuda.Event()
            consumed_event.record()
            model_state.buffer_consumed_event = consumed_event

            transferred_layer = model_state.layer_to_transfer
            self._update_layer_mapping_from_new(model_state, transferred_layer)
            # After the main thread consumes, advance layer_to_transfer and
            # mark the buffer as no longer ready.
            model_state.layer_to_transfer += 1
            model_state.ep_buffer_ready = 0
            logger.debug(
                "model %s successfully move_to_workspace layer %d",
                model_state.model_name,
                transferred_layer,
            )
            if model_state.layer_to_transfer >= model_state.model.num_moe_layers:
                # All MoE layers transferred: drop the staged map and reset
                # the rebalance state for the next round.
                self.post_eplb(model_state)
                model_state.rebalanced = False
                model_state.layer_to_transfer = 0
                model_state.pending_global_ready_check = False
                logger.info(
                    "finish async transfer for model %s rank %d layer %d",
                    model_state.model_name,
                    ep_group.rank(),
                    model_state.model.num_moe_layers,
                )
        finally:
            try:
                model_state.buffer_lock.release()
            except Exception as e:
                logger.error(
                    "Rank %d: buffer_lock release failed in move_to_workspace: %s",
                    ep_group.rank(),
                    str(e),
                )

    def post_eplb(self, model_state: EplbModelState) -> None:
        """Discard the staged ``new_physical_to_logical_map`` of *model_state*.

        A staged map must be present when this is called, so a second call in
        a row trips the assertion.
        """
        staged_map = model_state.new_physical_to_logical_map
        assert staged_map is not None
        model_state.new_physical_to_logical_map = None

    def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
        """
        All-reduce a list of tensors over the EP device group.

        - An empty list is returned unchanged, without any communication.
        - A single tensor is reduced in place, and the same list is returned.
        - Several tensors are concatenated along dim 0 and reduced with one
          collective. The results come back as row slices of that
          concatenated buffer, so the input tensors are not modified.

        All tensors must be 2D and share ``shape[1]``
        (``num_physical_experts``) when more than one is given.
        """
        if not tensor_list:
            # torch.cat would raise on an empty list; nothing to reduce.
            return tensor_list
        if len(tensor_list) == 1:
            all_reduce(tensor_list[0], group=get_ep_group().device_group)
            return tensor_list
        assert all(t.dim() == 2 for t in tensor_list), "All tensors must be 2D."
        assert all(t.shape[1] == tensor_list[0].shape[1] for t in tensor_list), (
            "All tensors must have the same shape[1]."
        )
        # One collective over the concatenation instead of one per tensor.
        concat_tensor = torch.cat(tensor_list, dim=0)
        all_reduce(concat_tensor, group=get_ep_group().device_group)

        # Unpack into views matching each input's row count.
        row_counts = [t.shape[0] for t in tensor_list]
        return list(torch.split(concat_tensor, row_counts, dim=0))

    def _sync_load_pass(self) -> list[torch.Tensor]:
        """
        Return every model's expert load pass, all-reduced across ranks.

        Each ``expert_load_pass`` is cloned before the reduction, so the
        per-model state is left untouched. The result is meant for logging
        stats.
        """
        snapshots = [
            state.expert_load_pass.clone() for state in self.model_states.values()
        ]
        return self._allreduce_list(snapshots)

    @classmethod
    def from_mapping(
        cls,
        model: MixtureOfExperts,
        model_config: ModelConfig,
        device: torch.device,
        parallel_config: ParallelConfig,
        expanded_physical_to_logical: torch.Tensor,
        num_valid_physical_experts: int,
    ) -> "EplbState":
        """Build an ``EplbState`` for *model* initialised from an explicit mapping.

        Args:
            model: the MoE model registered via ``add_model``.
            model_config: its config; ``compute_hash()`` keys the model state.
            device: device passed to the constructor; the derived logical maps
                are moved here.
            parallel_config: passed to the constructor.
            expanded_physical_to_logical: copied into the model state's
                ``physical_to_logical_map`` and used to derive the logical
                maps (presumably already padded to the allocated
                physical-expert width — TODO confirm).
            num_valid_physical_experts: stored on the returned state as-is.

        Returns:
            The new ``EplbState``.

        Raises:
            AssertionError: if the mapping needs more replica slots per logical
                expert than ``logical_to_physical_map`` has allocated.
        """
        eplb_state = cls(
            parallel_config=parallel_config,
            device=device,
        )
        eplb_state.add_model(
            model=model,
            model_config=model_config,
        )
        eplb_state.num_valid_physical_experts = num_valid_physical_experts
        eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
        eplb_model_state.physical_to_logical_map.copy_(expanded_physical_to_logical)

        (logical_to_physical_map_cpu, logical_replica_count_cpu) = compute_logical_maps(
            expanded_physical_to_logical.cpu(), model.num_logical_experts
        )

        max_num_replicas = eplb_model_state.logical_to_physical_map.shape[-1]
        num_replicas = logical_to_physical_map_cpu.shape[-1]
        # F.pad with a negative amount crops instead of padding, which would
        # silently drop physical replicas from the map. Fail loudly instead.
        assert num_replicas <= max_num_replicas, (
            f"Mapping needs {num_replicas} replica slots per logical expert, "
            f"but only {max_num_replicas} are allocated."
        )
        logical_to_physical_map = torch.nn.functional.pad(
            logical_to_physical_map_cpu,
            (0, max_num_replicas - num_replicas),
            value=-1,
        ).to(device)
        logical_replica_count = logical_replica_count_cpu.to(device)

        eplb_model_state.logical_to_physical_map.copy_(logical_to_physical_map)
        eplb_model_state.logical_replica_count.copy_(logical_replica_count)
        return eplb_state


@dataclass
class EplbLayerState:
    """Runtime EPLB data stored in the MoE layer."""

    expert_load_view: torch.Tensor | None = None
    logical_to_physical_map: torch.Tensor | None = None
    logical_replica_count: torch.Tensor | None = None


1064
def _node_count_with_rank_mapping(
    pg: ProcessGroup | StatelessProcessGroup,
    rank_mapping: dict[int, int],
) -> int:
    """Count the nodes hosting ranks of *pg* that are not pending shutdown.

    Ranks are visited in increasing order:
    - A rank already placed on a node is skipped.
    - A rank whose ``rank_mapping`` entry is -1 (pending shutdown) is also
      skipped.
    - Any other rank opens a new node, and every still-unplaced rank that
      ``in_the_same_node_as`` reports as co-located joins that node.

    A rank mapped to -1 therefore never opens a node itself, but it can still
    be absorbed into a node opened by an earlier active rank.

    A single-rank group returns 1 without consulting ``rank_mapping``.
    Raises AssertionError if a visited rank is missing from ``rank_mapping``.

    NOTE(review): ``in_the_same_node_as`` is presumably a collective call, so
    every rank should pass the same ``rank_mapping`` — confirm.
    """
    world_size = (
        torch.distributed.get_world_size(group=pg)
        if isinstance(pg, ProcessGroup)
        else pg.world_size
    )
    if world_size == 1:
        return 1

    # node_of[r] == 0 means rank r has not been placed yet; node ids start at 1.
    node_of = [0] * world_size
    nodes_seen = 0
    for rank in range(world_size):
        if node_of[rank]:
            continue
        assert rank in rank_mapping
        if rank_mapping[rank] == -1:
            continue
        nodes_seen += 1
        node_of[rank] = nodes_seen
        peers = in_the_same_node_as(pg, rank)
        for peer, shares_node in enumerate(peers):
            if shares_node and not node_of[peer]:
                node_of[peer] = nodes_seen
    return nodes_seen


def compute_logical_maps(
    physical_to_logical_map: torch.Tensor,
    num_logical_experts: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Derive logical_to_physical_map and logical_replica_count from
    physical_to_logical_map.

    Args:
        physical_to_logical_map: CPU tensor of shape
            [num_layers, num_physical_experts], or [num_physical_experts] for
            a single layer. Each entry is the logical expert index for that
            physical expert slot; negative entries (e.g. -1 for newly added
            slots) are ignored.
        num_logical_experts: total number of logical experts

    Returns:
        logical_to_physical_map: [num_layers, num_logical_experts, max_replicas],
            physical slots per logical expert in increasing slot order; -1 where
            unused. max_replicas is the largest replica count found (0 when
            there are no layers or no valid entries).
        logical_replica_count: [num_layers, num_logical_experts], number of
            physical replicas per logical expert

        For a 1-D input, the leading num_layers dimension is dropped from both.
    """
    device = physical_to_logical_map.device
    assert physical_to_logical_map.device.type == "cpu"

    dtype = physical_to_logical_map.dtype

    # If computing maps for a single layer, unsqueeze a single element layer dimension
    per_layer = physical_to_logical_map.dim() == 1
    physical_to_logical_map_view = physical_to_logical_map
    if per_layer:
        physical_to_logical_map_view = physical_to_logical_map.unsqueeze(0)
    assert len(physical_to_logical_map_view.shape) == 2
    num_layers, num_physical = physical_to_logical_map_view.shape

    # Count replicas per logical expert. Invalid (negative) entries are clamped
    # to index 0, but they contribute 0 because their mask value is 0.
    valid_mask = physical_to_logical_map_view >= 0
    logical_replica_count = torch.zeros(
        num_layers,
        num_logical_experts,
        dtype=dtype,
        device=device,
    )
    logical_replica_count.scatter_add_(
        1,
        physical_to_logical_map_view.clamp(min=0),
        valid_mask.to(dtype),
    )

    # ``max()`` raises on an empty tensor (e.g. zero layers), so treat that
    # case as "no replicas".
    max_replicas = (
        int(logical_replica_count.max().item())
        if logical_replica_count.numel() > 0
        else 0
    )
    logical_to_physical_map_out = torch.full(
        (num_layers, num_logical_experts, max_replicas),
        -1,
        dtype=dtype,
        device=device,
    )

    running_count = torch.zeros_like(logical_replica_count)
    layer_indices = torch.arange(num_layers, device=device)
    for phys_idx in range(num_physical):
        # Logical expert at physical slot phys_idx for each layer
        logical_expert_ids = physical_to_logical_map_view[:, phys_idx]  # [num_layers]

        # Scale up will set the logical expert ids to -1 for all new physical experts.
        # Only consider "valid" experts when setting up the logical_to_physical map.
        valid_expert_mask = logical_expert_ids >= 0
        if not valid_expert_mask.any():
            continue
        valid_layers = layer_indices[valid_expert_mask]
        valid_experts = logical_expert_ids[valid_expert_mask]

        # Use the current running count as the replica index, then increment it.
        # Each layer appears at most once per phys_idx, so the (layer, expert)
        # index pairs are unique and ``+= 1`` is safe.
        replica_idx = running_count[valid_layers, valid_experts]
        logical_to_physical_map_out[valid_layers, valid_experts, replica_idx] = phys_idx
        running_count[valid_layers, valid_experts] += 1

    # If computing maps for a single layer, squeeze out the extra layer dimension
    # before returning
    if per_layer:
        return logical_to_physical_map_out.squeeze(0), logical_replica_count.squeeze(0)
    return logical_to_physical_map_out, logical_replica_count