# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Expert parallelism load balancer (EPLB) metrics and states.

# Glossary

- **Logical Expert**: An expert that is part of the model's logical structure.
  It holds a set of weights and is replicated across multiple physical
  experts.
- **Redundant Expert**: To achieve load balancing, for some popular logical
  experts, we create additional copies of the expert weights. During inference,
  each of these copies can be routed to by the same set of tokens.
- **Physical Expert**: An expert that is instantiated on a specific device.
  It is a replica of a logical expert and can be rearranged across devices.
  I.e., one logical expert may have multiple sets of weights initialized on
  different devices, and each of these sets is a physical expert.
- **Local Physical Expert**: A physical expert that is instantiated on the
  current device.

For example: DeepSeek-R1 has 256 logical experts, so each MoE layer
has 256 sets of linear layer weights in the model parameters. If we add 32
redundant experts, DeepSeek-R1 will have 256 + 32 = 288 physical experts in
total. And when deploying, we'll have 288 sets of linear layer weights for each
MoE layer. If we have 32 EP ranks, then each GPU will hold 288 / 32 = 9 local
physical experts.
"""

import time
from collections.abc import Sequence
from dataclasses import dataclass

import torch
34
from torch.distributed import ProcessGroup, all_reduce
35
36

from vllm.config import ParallelConfig
37
38
39
40
41
from vllm.distributed.parallel_state import (
    get_ep_group,
    get_node_count,
    in_the_same_node_as,
)
42
from vllm.distributed.utils import StatelessProcessGroup
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts

from .rebalance_algo import rebalance_experts
from .rebalance_execute import rearrange_expert_weights_inplace

logger = init_logger(__name__)


@dataclass
class EplbState:
    """
    EPLB metrics and expert placement state.

    Holds the physical/logical expert mappings (handed to the model through
    `model.set_eplb_state`), the per-pass and sliding-window expert load
    counters, and the step counters that decide when `rearrange` runs.
    """

    physical_to_logical_map: torch.Tensor
    """
    Mapping from physical experts to logical experts.

    Shape: (num_moe_layers, num_physical_experts)

    # Example

    For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
    EP ranks, the mapping could look like this:

    ```
    [[0, 1, 2, 3, 0, 1],
     [0, 2, 0, 1, 0, 3]]
    ```
    """
    logical_to_physical_map: torch.Tensor
    """
    Mapping from logical experts to physical experts.

    This is a sparse matrix, where -1 indicates no mapping.

    Shape: (num_moe_layers, num_logical_experts, num_redundant_experts + 1)

    # Example

    For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
    EP ranks, the mapping could look like this:

    ```
    [[[0, 4, -1],
      [1, 5, -1],
      [2, -1, -1],
      [3, -1, -1]],
     [[0, 2, 4],
      [3, -1, -1],
      [1, -1, -1],
      [5, -1, -1]]]
    ```
    """
    logical_replica_count: torch.Tensor
    """
    Number of replicas for each logical expert.
    This is exactly the non-`-1` count in the `logical_to_physical_map`.

    Shape: (num_moe_layers, num_logical_experts)

    # Example
    For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
    EP ranks, the count could look like this:

    ```
    [[2, 2, 1, 1],
     [3, 1, 1, 1]]
    ```
    """

    expert_load_pass: torch.Tensor
    """
    Expert load during this forward pass.
    We use the token count each expert processes as the load.

    Shape: (num_moe_layers, num_physical_experts)
    """
    expert_load_window: torch.Tensor
    """
    A sliding window of expert load.

    Shape: (window_size, num_moe_layers, num_physical_experts)

    NOTE: The expert_load_view now records load for all physical experts
    rather than just local experts. This ensures consistent load statistics
    across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels).
    The recorded load will be multiplied by dp_size when using naive all-to-all
    due to each DP rank contributing the same token set to the calculation.
    See:
    https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
    """
    expert_load_window_step: int = 0
    """
    Current step in the sliding window.

    Different from `expert_rearrangement_step`, each EP rank may have its own
    `expert_load_window_step`.
    """
    expert_load_window_size: int = 0
    """
    Size of the expert load sliding window.
    This is a constant and is taken from the config.
    """

    expert_rearrangement_step: int = 0
    """
    Steps after last rearrangement.
    Will trigger a rearrangement if it exceeds the threshold.

    NOTE: Keep in mind that all EP ranks need to have the same
    `expert_rearrangement_step` value to ensure synchronization.
    Otherwise, the rearrangement will hang at collective
    communication calls.
    """
    expert_rearrangement_step_interval: int = 0
    """
    Interval for expert rearrangement steps.
    This is a constant and is taken from the config.
    """

    @staticmethod
    def build_initial_global_physical_to_logical_map(
        num_routed_experts: int,
        num_redundant_experts: int,
    ) -> Sequence[int]:
        """
        Build an initial expert arrangement using the following structure:
        [original routed experts, redundant experts]

        Redundant expert `i` replicates logical expert
        `i % num_routed_experts`, so replicas are handed out round-robin
        starting from logical expert 0.

        Returns:
            physical_to_logical_map (Sequence[int]): A list of integers,
                where each integer is the index of the logical expert
                that the corresponding physical expert maps to.
        """
        global_physical_to_logical_map = list(range(num_routed_experts))
        global_physical_to_logical_map += [
            i % num_routed_experts for i in range(num_redundant_experts)
        ]
        return global_physical_to_logical_map

    @staticmethod
    def _hierarchical_num_nodes(num_gpus: int, num_nodes: int) -> int:
        """
        Return the node count to hand to `rebalance_experts`.

        The hierarchical rearrangement algorithm needs the GPUs to split
        evenly across nodes. When they do not, warn once (reporting the
        offending values) and fall back to a single node, which disables
        the hierarchical algorithm.
        """
        if num_gpus % num_nodes != 0:
            logger.warning_once(
                f"num_gpus % num_nodes != 0, "
                "not using hierarchical rearrangement algorithm.\n"
                f"{num_gpus=}, {num_nodes=}"
            )
            return 1
        return num_nodes

    @staticmethod
    def _pad_logical_to_physical_map(
        new_map: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Right-pad the last dim of `new_map` with -1 up to `target`'s width.

        `rebalance_experts` returns only as many replica slots as it needs,
        while the stored `logical_to_physical_map` has a fixed number of
        slots, so the result must be padded before an in-place `copy_`.
        """
        max_physical_slots = new_map.shape[-1]
        assert max_physical_slots <= target.shape[-1]
        return torch.nn.functional.pad(
            new_map,
            (0, target.shape[-1] - max_physical_slots),
            value=-1,
        )

    @classmethod
    def build(
        cls,
        model: MixtureOfExperts,
        device: torch.device,
        parallel_config: ParallelConfig,
        global_expert_load: torch.Tensor | None = None,
        old_global_expert_indices: torch.Tensor | None = None,
        rank_mapping: dict[int, int] | None = None,
    ) -> "EplbState":
        """
        Build the initial EPLB state.

        Args:
            model: The MoE model. Its `set_eplb_state` receives the freshly
                allocated load buffer and mapping tensors.
            device: Device on which all state tensors are allocated.
            parallel_config: Source of `eplb_config.window_size` and
                `eplb_config.step_interval`.
            global_expert_load: Optional logical expert load with shape
                `(num_moe_layers, num_logical_experts)` and dtype `int64`.
                When given, experts are rebalanced immediately and the
                weights are shuffled into the new placement.
            old_global_expert_indices: Previous physical-to-logical mapping,
                used as the source placement for the weight shuffle. Only
                used when `global_expert_load` is given.
            rank_mapping: Optional current-rank -> new-rank mapping forwarded
                to `rearrange_expert_weights_inplace`.

        Returns:
            The newly constructed `EplbState`.
        """
        physical_to_logical_map_list = cls.build_initial_global_physical_to_logical_map(
            model.num_routed_experts,
            model.num_redundant_experts,
        )
        physical_to_logical_map = torch.tensor(
            physical_to_logical_map_list,
            device=device,
        )

        # Assuming 8 GPUs per node, this supports up to
        # (1023 + 1) / 8 = 128 nodes for now.
        # TODO(rui): make this configurable
        MAX_EXPERT_REDUNDANCY = 1023
        assert model.num_redundant_experts <= MAX_EXPERT_REDUNDANCY, (
            f"num_redundant_experts {model.num_redundant_experts} "
            f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}"
        )
        max_slots_per_logical_expert = MAX_EXPERT_REDUNDANCY + 1
        logical_to_physical_map = torch.full(
            (model.num_logical_experts, max_slots_per_logical_expert),
            -1,
            device=device,
        )
        logical_replica_count = torch.zeros(
            (model.num_logical_experts,),
            device=device,
            dtype=torch.long,
        )

        # Invert the physical -> logical map: append each physical expert to
        # the next free slot of its logical expert.
        for i in range(model.num_physical_experts):
            logical_idx = physical_to_logical_map[i]
            logical_to_physical_map[logical_idx, logical_replica_count[logical_idx]] = i
            logical_replica_count[logical_idx] += 1

        # Duplicate initial mapping for all layers
        physical_to_logical_map = (
            physical_to_logical_map.unsqueeze(0)
            .expand(
                model.num_moe_layers,
                -1,
            )
            .contiguous()
        )
        logical_to_physical_map = (
            logical_to_physical_map.unsqueeze(0)
            .expand(
                model.num_moe_layers,
                -1,
                -1,
            )
            .contiguous()
        )
        logical_replica_count = (
            logical_replica_count.unsqueeze(0)
            .expand(
                model.num_moe_layers,
                -1,
            )
            .contiguous()
        )

        expert_load_pass = torch.zeros(
            (model.num_moe_layers, model.num_physical_experts),
            dtype=torch.int32,
            device=device,
        )
        expert_load_window_size = parallel_config.eplb_config.window_size
        expert_load_window = torch.zeros(
            (expert_load_window_size, model.num_moe_layers, model.num_physical_experts),
            dtype=torch.int32,
            device=device,
        )

        # Set the initial progress of rearrangement to 3/4
        eplb_step_interval = parallel_config.eplb_config.step_interval
        expert_rearrangement_step = max(0, eplb_step_interval - eplb_step_interval // 4)

        if global_expert_load is not None:
            ep_group = get_ep_group().device_group
            assert global_expert_load.shape == (
                model.num_moe_layers,
                model.num_logical_experts,
            )
            assert global_expert_load.dtype == torch.int64

            num_replicas = model.num_physical_experts
            num_groups = model.num_expert_groups
            num_nodes = get_node_count()
            num_gpus = ep_group.size()
            num_nodes = cls._hierarchical_num_nodes(num_gpus, num_nodes)

            # Get new expert mappings
            (
                new_physical_to_logical_map,
                new_logical_to_physical_map,
                new_logical_replica_count,
            ) = rebalance_experts(
                global_expert_load,
                num_replicas,
                num_groups,
                num_nodes,
                num_gpus,
            )

            physical_to_logical_map = new_physical_to_logical_map.to(device)
            logical_to_physical_map.copy_(
                cls._pad_logical_to_physical_map(
                    new_logical_to_physical_map, logical_to_physical_map
                )
            )
            logical_replica_count.copy_(new_logical_replica_count)

        model.set_eplb_state(
            expert_load_pass,
            logical_to_physical_map,
            logical_replica_count,
        )
        if global_expert_load is not None:
            rearrange_expert_weights_inplace(
                old_global_expert_indices,
                new_physical_to_logical_map,
                model.expert_weights,
                ep_group,
                False,
                rank_mapping,
            )
            # A rearrangement just happened, so restart the countdown.
            expert_rearrangement_step = 0

        return cls(
            physical_to_logical_map,
            logical_to_physical_map,
            logical_replica_count,
            expert_load_pass,
            expert_load_window,
            expert_load_window_size=expert_load_window_size,
            expert_rearrangement_step=expert_rearrangement_step,
            expert_rearrangement_step_interval=eplb_step_interval,
        )

    def step(
        self,
        model: MixtureOfExperts,
        is_dummy: bool = False,
        is_profile: bool = False,
        log_stats: bool = False,
    ) -> None:
        """
        Step the EPLB state.

        Args:
            model (MixtureOfExperts): The MoE model.
            is_dummy (bool): If `True`, this is a dummy step and the load
                metrics recorded in this forward pass will not count.
                Defaults to `False`.
            is_profile (bool): If `True`, perform a dummy rearrangement
                with maximum communication cost. This is used in
                `profile_run` to reserve enough memory
                for the communication buffer.
            log_stats (bool): If `True`, log the expert load metrics.

        # Stats
            The metrics are all summed up across layers.
            - `avg_tokens`: The average load across ranks.
            - `max_tokens`: The maximum load across ranks.
            - `balancedness`: The ratio of average load to maximum load.
        """

        if is_profile:
            self.rearrange(model, is_profile=True)
            return

        if is_dummy:
            # Do not record load metrics for dummy steps
            self.expert_load_pass.zero_()

        if log_stats:
            # total_expert_load_pass: (num_moe_layers, num_physical_experts)
            total_expert_load_pass = self.expert_load_pass.clone()

            # Collect load metrics from all ranks
            ep_group = get_ep_group().device_group
            all_reduce(total_expert_load_pass, group=ep_group)

            # num_tokens_per_rank: (num_moe_layers, num_ranks)
            num_tokens_per_rank = (
                total_expert_load_pass.reshape(
                    total_expert_load_pass.shape[0], ep_group.size(), -1
                )
                .sum(dim=-1)
                .float()
            )

            # Compute balancedness ratio:
            # for each layer, take the mean and the max load across ranks
            # (dim=1, the rank axis), then sum both over layers (dim=0).
            avg_tokens_tensor = num_tokens_per_rank.mean(dim=1).sum(dim=0)
            max_tokens_tensor = num_tokens_per_rank.max(dim=1).values.sum(dim=0)

            # Just to make type checker happy
            tokens_tensors: list[float] = torch.stack(
                [avg_tokens_tensor, max_tokens_tensor]
            ).tolist()
            avg_tokens, max_tokens = tokens_tensors
            balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0

            if ep_group.rank() == 0:
                logger.info(
                    "EPLB step: avg_tokens=%.2f, max_tokens=%d, balancedness=%.4f",
                    avg_tokens,
                    max_tokens,
                    balancedness,
                )

        # Update the expert load sliding window
        if not is_dummy:
            self.expert_load_window[self.expert_load_window_step] = (
                self.expert_load_pass.clone()
            )
            self.expert_load_window_step += 1
            if self.expert_load_window_step >= self.expert_load_window_size:
                self.expert_load_window_step = 0
            self.expert_load_pass.zero_()

        # Step the expert rearrangement step
        # Note that even if this is a dummy step, we still increment the
        # rearrangement step and perform rearrangement to ensure all ranks are
        # performing collective communication.
        self.expert_rearrangement_step += 1
        if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
            self.expert_rearrangement_step = 0
            self.rearrange(model)

    def rearrange(
        self,
        model: MixtureOfExperts,
        is_profile: bool = False,
        execute_shuffle: bool = True,
        global_expert_load: torch.Tensor | None = None,
        rank_mapping: dict[int, int] | None = None,
    ) -> torch.Tensor | None:
        """
        Rearrange the experts according to the current load.

        Args:
            model: The MoE model whose expert weights are shuffled in place.
            is_profile: Forwarded to `rearrange_expert_weights_inplace`. When
                `True`, the stored mappings are left unchanged.
            execute_shuffle: If `False` (and `global_expert_load` is `None`),
                only compute the all-reduced logical expert load, broadcast
                the metadata and the current physical-to-logical map from
                rank 0 (see `recv_state`), and return the load without
                rebalancing.
            global_expert_load: Pre-aggregated logical expert load. When
                given, local aggregation is skipped; requires
                `execute_shuffle=True`.
            rank_mapping: Optional current-rank -> new-rank mapping (-1 for
                ranks being removed). When it covers every current EP rank,
                the experts are rebalanced onto the remaining GPUs.

        Returns:
            The all-reduced logical expert load of shape
            `(num_moe_layers, num_logical_experts)` when `execute_shuffle`
            is `False`; otherwise `None`.
        """

        ep_group = get_ep_group().device_group
        ep_rank = ep_group.rank()

        time_start = None
        is_main_rank = ep_rank == 0
        if is_main_rank:
            torch.cuda.synchronize()
            time_start = time.perf_counter()
            logger.info("Rearranging experts %s...", "(profile)" if is_profile else "")

        if global_expert_load is None:
            # Map the physical expert load to global logical experts
            logical_expert_load_window = torch.zeros(
                self.expert_load_window_size,
                model.num_moe_layers,
                model.num_logical_experts,
                dtype=self.expert_load_window.dtype,
                device=self.expert_load_window.device,
            )
            logical_expert_load_window.scatter_add_(
                dim=-1,
                index=self.physical_to_logical_map.unsqueeze(0)
                .expand_as(self.expert_load_window)
                .long(),
                src=self.expert_load_window,
            )

            if not execute_shuffle:
                metadata = torch.tensor(
                    [
                        model.num_moe_layers,
                        model.num_logical_experts,
                        self.physical_to_logical_map.shape[1],
                    ],
                    dtype=torch.int32,
                    device="cpu",
                )
                torch.distributed.broadcast(
                    metadata, group=get_ep_group().cpu_group, group_src=0
                )

            # Perform all-reduce to get the expert load across all ranks
            global_expert_load_window = logical_expert_load_window.sum(dim=0)
            all_reduce(global_expert_load_window, group=ep_group)

            if not execute_shuffle:
                # (num_moe_layers, old_num_physical_experts)
                old_global_expert_indices = self.physical_to_logical_map
                torch.distributed.broadcast(
                    old_global_expert_indices, group=ep_group, group_src=0
                )
                return global_expert_load_window
        else:
            assert execute_shuffle
            global_expert_load_window = global_expert_load

        # TODO(bowen): Treat differently for prefill and decode nodes
        num_replicas = model.num_physical_experts
        num_groups = model.num_expert_groups
        if rank_mapping is not None and len(rank_mapping) == ep_group.size():
            # NOTE(yongji): scale down, we need to rebalance the experts on
            # remaining GPUs, transfer the experts while we haven't shutdown
            # the GPUs to be released.
            cpu_group = get_ep_group().cpu_group
            num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
            num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values())
            num_replicas = (
                num_replicas // ep_group.size() * num_gpus
            )  # handle num replicas change
        else:
            num_nodes = get_node_count()
            num_gpus = ep_group.size()

        # Must rebind the local: it is what gets passed to rebalance_experts.
        num_nodes = self._hierarchical_num_nodes(num_gpus, num_nodes)

        # Get new expert mappings
        (
            new_physical_to_logical_map,
            new_logical_to_physical_map,
            new_logical_replica_count,
        ) = rebalance_experts(
            global_expert_load_window,
            num_replicas,
            num_groups,
            num_nodes,
            num_gpus,
        )

        # Update expert weights
        rearrange_expert_weights_inplace(
            self.physical_to_logical_map,
            new_physical_to_logical_map,
            model.expert_weights,
            ep_group,
            is_profile,
            rank_mapping,
        )

        if not is_profile:
            if (
                self.physical_to_logical_map.shape[1]
                != new_physical_to_logical_map.shape[1]
            ):
                # The number of physical experts changed (scale up/down), so
                # the old buffer cannot be reused in place.
                self.physical_to_logical_map = new_physical_to_logical_map.to(
                    self.physical_to_logical_map.device
                )
            else:
                self.physical_to_logical_map.copy_(new_physical_to_logical_map)
            self.logical_to_physical_map.copy_(
                self._pad_logical_to_physical_map(
                    new_logical_to_physical_map, self.logical_to_physical_map
                )
            )
            self.logical_replica_count.copy_(new_logical_replica_count)

        if is_main_rank:
            assert time_start is not None
            torch.cuda.synchronize()
            time_end = time.perf_counter()
            logger.info(
                "Rearranged experts%sin %.2f seconds.",
                " (profile) " if is_profile else " ",
                time_end - time_start,
            )
        return None

    @staticmethod
    def recv_state() -> tuple[torch.Tensor, torch.Tensor]:
        """
        Receive the expert load and old placement from the master rank.

        Counterpart of `rearrange(..., execute_shuffle=False)`: it receives
        the metadata broadcast, joins the load all-reduce with a zero tensor
        and receives the old physical-to-logical map broadcast, in that order.

        Returns:
            A `(global_expert_load, old_global_expert_indices)` tuple.
        """
        ep_group = get_ep_group()
        metadata = torch.empty(3, dtype=torch.int32, device="cpu")
        torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0)
        num_moe_layers, num_logical_experts, num_old_physical_experts = (
            metadata.tolist()
        )
        global_expert_load = torch.zeros(
            (num_moe_layers, num_logical_experts),
            dtype=torch.int64,
            device=ep_group.device,
        )
        all_reduce(global_expert_load, group=ep_group.device_group)
        old_global_expert_indices = torch.empty(
            (num_moe_layers, num_old_physical_experts),
            dtype=torch.int64,
            device=ep_group.device,
        )
        torch.distributed.broadcast(
            old_global_expert_indices, group=ep_group.device_group, group_src=0
        )

        return global_expert_load, old_global_expert_indices


def _node_count_with_rank_mapping(
    pg: ProcessGroup | StatelessProcessGroup,
    rank_mapping: dict[int, int],
) -> int:
    """
    Count the nodes that still host at least one surviving rank.

    A rank "survives" when its entry in `rank_mapping` is not -1. Ranks are
    grouped into nodes with `in_the_same_node_as`. A node whose ranks all
    map to -1 is not counted. A single-rank group always reports 1.

    NOTE(review): `in_the_same_node_as` is called once per newly discovered
    node, in ascending rank order. If it is a collective, every rank must
    call this with an identical `rank_mapping` — confirm.

    Raises:
        AssertionError: if a rank of `pg` is missing from `rank_mapping`.
    """
    world_size = (
        torch.distributed.get_world_size(group=pg)
        if isinstance(pg, ProcessGroup)
        else pg.world_size
    )
    if world_size == 1:
        return 1

    # grouped[r] is True once rank r has been attributed to a counted node.
    grouped = [False] * world_size
    node_count = 0

    for rank in range(world_size):
        if grouped[rank]:
            continue

        assert rank in rank_mapping
        if rank_mapping[rank] == -1:
            continue  # Pending shutdown: does not open a node by itself.

        # First surviving rank seen on a new node.
        node_count += 1
        grouped[rank] = True
        for peer, shares_node in enumerate(in_the_same_node_as(pg, rank)):
            if shares_node:
                grouped[peer] = True

    return node_count