elastic_execute.py 22.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import gc
import weakref
from collections.abc import Iterable, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import P2POp

from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.wrapper import reset_compile_wrapper
from vllm.config import (
    CompilationMode,
    set_current_vllm_config,
)
from vllm.distributed import (
    get_dp_group,
    get_ep_group,
    get_pcp_group,
    get_tp_group,
)
from vllm.distributed.elastic_ep.standby_state import (
    create_standby_groups,
    get_standby_dp_group,
    get_standby_ep_group,
    pop_standby_groups,
)
from vllm.distributed.parallel_state import (
    _replace_active_groups,
    prepare_communication_buffer_for_model,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.workspace import lock_workspace, unlock_workspace

# Module-level logger obtained from vLLM's init_logger factory.
logger = init_logger(__name__)


def batch_transfer_weights(
    model: nn.Module,
    is_sender: bool,
    peer_rank: int,
    dp_group: StatelessGroupCoordinator,
    expert_weights: Sequence[Iterable[torch.Tensor]],
) -> None:
    """Send or receive every non-expert tensor of ``model`` to/from one peer.

    Tensors whose storage pointer matches one of ``expert_weights`` are
    skipped, as are state-dict entries whose name ends with ``expert_map``.
    All remaining tensors are exchanged in state-dict order in a single
    batched point-to-point call on ``dp_group``'s device communicator.

    Args:
        model: Model whose state dict supplies the tensors.
        is_sender: ``True`` to issue sends, ``False`` to issue receives.
        peer_rank: Rank of the peer within ``dp_group``.
        dp_group: Group whose device communicator performs the transfer.
        expert_weights: Groups of expert tensors to exclude.

    Raises:
        ValueError: if ``dp_group`` has no device communicator.
    """
    communicator = dp_group.device_communicator
    if communicator is None:
        raise ValueError("No device communicator found")

    # Identify expert tensors by storage pointer so they can be excluded.
    expert_ptrs = {
        tensor.data_ptr() for group in expert_weights for tensor in group
    }

    non_expert_tensors = [
        value.data
        for key, value in model.state_dict().items()
        if not key.endswith("expert_map") and value.data_ptr() not in expert_ptrs
    ]
    assert len(non_expert_tensors) > 0

    transfer_fn = torch.distributed.isend if is_sender else torch.distributed.irecv
    ops = []
    for tensor in non_expert_tensors:
        # P2POp is built without calling its __init__; the needed fields are
        # assigned directly. NOTE(review): presumably this skips constructor
        # validation that does not fit the stateless group — confirm.
        op = object.__new__(P2POp)
        op.op = transfer_fn
        op.tensor = tensor
        op.group_peer = peer_rank
        ops.append(op)
    communicator.batch_isend_irecv(ops)


def broadcast_expert_mapping(
    physical_to_logical: torch.Tensor | None,
    num_local_physical_experts: int | None,
    num_logical_experts: int | None,
    dp_group: StatelessGroupCoordinator,
    device: torch.device,
    src_rank: int = 0,
) -> tuple[torch.Tensor, int, int]:
    """Broadcast the physical-to-logical expert map from ``src_rank``.

    The map's shape and the two expert counts are broadcast as CPU int64
    tensors over ``dp_group.tcp_store_group``. The map itself is then
    broadcast with ``dp_group.broadcast``. Non-source ranks allocate an int64
    receive buffer on ``device``.

    Args:
        physical_to_logical: 2-D int64 map on the source rank. Ignored on
            other ranks, which may pass ``None``.
        num_local_physical_experts: Source-rank value. ``None`` elsewhere.
        num_logical_experts: Source-rank value. ``None`` elsewhere.
        dp_group: Group over which to broadcast.
        device: Device for the receive buffer on non-source ranks.
        src_rank: Rank within ``dp_group`` that owns the mapping.

    Returns:
        ``(physical_to_logical, num_local_physical_experts,
        num_logical_experts)`` as seen on every rank.
    """
    if dp_group.rank_in_group == src_rank:
        assert physical_to_logical is not None
        assert num_local_physical_experts is not None
        assert num_logical_experts is not None
        assert physical_to_logical.dtype == torch.int64
        # Receivers allocate a fixed 2-element shape buffer below, so a
        # non-2-D map would make the shape broadcast mismatch.
        assert physical_to_logical.dim() == 2, "physical_to_logical must be 2-D"
        shape_tensor = torch.tensor(
            list(physical_to_logical.shape), dtype=torch.int64, device="cpu"
        )
        metadata_tensor = torch.tensor(
            [num_local_physical_experts, num_logical_experts],
            dtype=torch.int64,
            device="cpu",
        )
    else:
        shape_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
        metadata_tensor = torch.empty(2, dtype=torch.int64, device="cpu")

    shape_tensor = dp_group.tcp_store_group.broadcast(shape_tensor, src_rank)
    metadata_tensor = dp_group.tcp_store_group.broadcast(metadata_tensor, src_rank)

    if dp_group.rank_in_group != src_rank:
        assert device is not None
        physical_to_logical = torch.empty(
            tuple(shape_tensor.tolist()),
            dtype=torch.int64,
            device=device,
        )

    assert physical_to_logical is not None
    physical_to_logical = dp_group.broadcast(physical_to_logical, src_rank)
    num_local_physical_experts = int(metadata_tensor[0].item())
    num_logical_experts = int(metadata_tensor[1].item())

    return physical_to_logical, num_local_physical_experts, num_logical_experts


class ElasticEPScalingExecutor:
    """Per-worker driver for elastic expert-parallel (EP) scale-up/down.

    Each public method performs one step of a reconfiguration and is invoked
    by name through :meth:`execute`. Only a weak reference to the worker is
    held, so the executor does not keep the worker alive.
    """

    def __init__(self, worker):
        self.worker_ref = weakref.ref(worker)
        # Set by create_standby_groups() and read by switch_and_prepare().
        self.reconfig_request = None

    @property
    def worker(self):
        """The owning worker.

        Raises:
            RuntimeError: if the worker has already been garbage collected.
        """
        worker = self.worker_ref()
        if worker is None:
            raise RuntimeError("Worker has been garbage collected")
        return worker

    def execute(self, execute_method: str, *args, **kwargs):
        """Call the attribute named ``execute_method`` with ``args``/``kwargs``.

        Raises:
            ValueError: if this executor has no attribute with that name.
        """
        # NOTE(review): getattr() resolves any attribute name, including
        # private helpers; callers are presumably trusted — confirm.
        method = getattr(self, execute_method, None)
        if method is None:
            raise ValueError(f"Unknown execute method: {execute_method}")
        return method(*args, **kwargs)

    def _set_eplb_suppressed(self, suppressed: bool) -> None:
        """Set the runner's ``eep_eplb_suppressed`` flag and log on EP rank 0."""
        self.worker.model_runner.eep_eplb_suppressed = suppressed
        # Use the standby EP group when one exists, else the active one.
        ep_group = get_standby_ep_group() or get_ep_group()
        if ep_group.rank == 0:
            logger.info(
                "[Elastic EP] EPLB %s elastic scaling transition",
                "disabled during" if suppressed else "re-enabled after",
            )

    def load_model(self) -> None:
        """Load a joining worker's model and seed EPLB from the broadcast map.

        Receives the expert mapping from DP rank 0 and sets the redundant-
        expert count to match it. Loads the model with dummy weights and
        initializes EPLB from the mapping. EPLB is left suppressed.
        """
        (
            expanded_physical_to_logical,
            num_logical_experts,
            old_num_physical_experts,
        ) = self.receive_expert_mapping()
        num_physical_experts = expanded_physical_to_logical.shape[1]
        self.worker.parallel_config.eplb_config.num_redundant_experts = (
            num_physical_experts - num_logical_experts
        )
        self.worker.load_model(load_dummy_weights=True)
        self.worker.model_runner.setup_eplb_from_mapping(
            expanded_physical_to_logical, old_num_physical_experts
        )
        self._set_eplb_suppressed(True)

    def create_standby_groups(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Create standby process groups sized for the requested DP size.

        The request is stored for switch_and_prepare(). Groups are created
        under a config copy whose parallel_config carries the new DP size.
        On scale-up, EPLB is suppressed.
        """
        self.reconfig_request = reconfig_request
        new_dp_size = reconfig_request.new_data_parallel_size
        old_dp_size = get_dp_group().world_size
        # parallel_config.world_size is per DP replica; scale by new DP size.
        world_size = self.worker.vllm_config.parallel_config.world_size
        new_world_size_across_dp = world_size * new_dp_size
        # Shallow-copy the config and deep-copy only parallel_config, so the
        # live config is not modified here.
        updated_config = copy.copy(self.worker.vllm_config)
        updated_config.parallel_config = copy.deepcopy(
            self.worker.vllm_config.parallel_config
        )
        updated_config.parallel_config.data_parallel_size = new_dp_size
        with set_current_vllm_config(updated_config):
            create_standby_groups(
                new_dp_size=new_dp_size,
                new_world_size_across_dp=new_world_size_across_dp,
                master_ip=reconfig_request.new_data_parallel_master_ip,
                coord_store_port=reconfig_request.coord_store_port,
                enable_eplb=updated_config.parallel_config.enable_eplb,
            )
        if new_dp_size > old_dp_size:
            self._set_eplb_suppressed(True)

    def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None:
        """Send this existing worker's non-expert weights to its new workers.

        Also broadcasts ``old_dp_size`` from standby rank 0 so that joining
        workers can compute the same sender/receiver pairing in
        receive_weights().
        """
        standby_dp_group = get_standby_dp_group()
        assert standby_dp_group is not None
        # Broadcast old_dp_size to all workers in standby group
        if standby_dp_group.rank_in_group < old_dp_size:
            old_dp_size_tensor = torch.tensor(
                [old_dp_size], dtype=torch.int64, device="cpu"
            )
        else:
            old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
        old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast(
            old_dp_size_tensor, 0
        )

        num_new_workers = new_dp_size - old_dp_size
        dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank

        # Sender-receiver pairing: the first new_workers % old_dp_size
        # senders get (k+1) contiguous receivers, the rest get k
        # receivers.
        num_dst_per_sender = num_new_workers // old_dp_size
        remainder = num_new_workers % old_dp_size

        if dp_rank < remainder:
            recv_begin = dp_rank * (num_dst_per_sender + 1)
            recv_end = recv_begin + num_dst_per_sender + 1
        else:
            recv_begin = (
                remainder * (num_dst_per_sender + 1)
                + (dp_rank - remainder) * num_dst_per_sender
            )
            recv_end = recv_begin + num_dst_per_sender

        # New workers occupy ranks old_dp_size .. new_dp_size-1.
        ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end))

        model = self.worker.model_runner.get_model()
        for new_worker_rank in sorted(ranks_to_send):
            batch_transfer_weights(
                model=model,
                is_sender=True,
                peer_rank=new_worker_rank,
                dp_group=standby_dp_group,
                expert_weights=model.expert_weights,
            )
        torch.accelerator.synchronize()

    def broadcast_expert_mapping(self) -> None:
        """Broadcast this worker's EPLB expert map over the standby DP group.

        The map is sent from standby DP rank 0 via the module-level
        broadcast_expert_mapping().
        """
        standby_dp_group = get_standby_dp_group()
        assert standby_dp_group is not None
        model_config = self.worker.model_runner.model_config
        eplb_state = self.worker.model_runner.eplb_state
        assert eplb_state is not None
        eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
        physical_to_logical = eplb_model_state.physical_to_logical_map
        num_physical_experts = physical_to_logical.shape[1]
        num_local_physical_experts = num_physical_experts // get_ep_group().world_size
        num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
        broadcast_expert_mapping(
            physical_to_logical=physical_to_logical,
            num_local_physical_experts=num_local_physical_experts,
            num_logical_experts=num_logical_experts,
            dp_group=standby_dp_group,
            src_rank=0,
            device=self.worker.device,
        )

    def _release_cuda_graphs(self) -> None:
        """Drop captured CUDA graphs and compile state, then free memory.

        Raises:
            RuntimeError: if the model is wrapped in a UBatchWrapper.
        """
        model = self.worker.model_runner.model
        if isinstance(model, CUDAGraphWrapper):
            model.concrete_cudagraph_entries = {}
        elif isinstance(model, UBatchWrapper):
            raise RuntimeError("DBO is not yet supported in elastic EP")

        torch.compiler.reset()
        with set_current_vllm_config(self.worker.vllm_config):
            reset_compile_wrapper(self.worker.model_runner.get_model())

        gc.collect()
        torch.accelerator.synchronize()
        torch.accelerator.empty_cache()

    def switch_and_remove(self) -> None:
        """Release graph state and clear this worker's active process groups."""
        self._release_cuda_graphs()
        _replace_active_groups(world=None, dp=None, ep=None, eplb=None, node_count=None)

    def switch_and_prepare(self) -> None:
        """Activate the standby groups and reconfigure this worker for them.

        Steps:
        - Update the parallel config from the stored reconfigure request.
        - Rebuild each MoE module's parallel config for the new EP size.
        - Resize the EPLB state and recreate communication buffers.
        - Re-run compile/warm-up, preserving the block tables.
        """
        # Capture sizes from the groups that are about to be replaced.
        old_dp_size = get_dp_group().world_size
        old_ep_size = get_ep_group().world_size

        self._release_cuda_graphs()
        _replace_active_groups(**pop_standby_groups())

        parallel_config = self.worker.vllm_config.parallel_config
        reconfig_request = self.reconfig_request
        assert reconfig_request is not None
        new_dp_size = reconfig_request.new_data_parallel_size
        new_ep_size = get_ep_group().world_size

        parallel_config.data_parallel_size = new_dp_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        if (
            reconfig_request.new_data_parallel_rank_local
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank_local = (
                reconfig_request.new_data_parallel_rank_local
            )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )

        # Reconfigure MoE modules with new EP size
        moe_modules = [
            module
            for module in self.worker.model_runner.model.modules()
            if module.__class__.__name__ in ("FusedMoE", "SharedFusedMoE")
        ]
        num_local_experts = moe_modules[0].moe_config.num_local_experts
        assert all(
            module.moe_config.num_local_experts == num_local_experts
            for module in moe_modules
        ), "All MoE modules must have the same number of experts"
        # Group sizes and the SP setting are the same for every module.
        tp_size = get_tp_group().world_size
        pcp_size = get_pcp_group().world_size
        dp_size = get_dp_group().world_size
        sp_size = tp_size if parallel_config.use_sequence_parallel_moe else 1
        for module in moe_modules:
            module.moe_config.num_experts = num_local_experts * new_ep_size
            module.global_num_experts = module.moe_config.num_experts
            module.moe_parallel_config = FusedMoEParallelConfig.make(
                tp_size_=tp_size,
                pcp_size_=pcp_size,
                dp_size_=dp_size,
                sp_size_=sp_size,
                vllm_parallel_config=parallel_config,
            )
            module.moe_config.moe_parallel_config = module.moe_parallel_config

        # Update EPLB state
        eplb_state = self.worker.model_runner.eplb_state
        assert eplb_state is not None
        model_config = self.worker.model_runner.model_config
        eplb_model_state = eplb_state.model_states[model_config.compute_hash()]

        num_physical_experts = num_local_experts * new_ep_size
        num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
        parallel_config.eplb_config.num_redundant_experts = (
            num_physical_experts - num_logical_experts
        )
        old_physical_to_logical = eplb_model_state.physical_to_logical_map
        num_moe_layers = old_physical_to_logical.shape[0]
        # Re-derived from the pre-resize load buffer. NOTE(review): expected
        # to equal the MoE-config value above — confirm.
        num_local_experts = eplb_model_state.expert_load_pass.shape[1] // old_ep_size
        if new_dp_size > old_dp_size:
            # Scale-up: extend the map; new physical slots are marked -1.
            expanded_physical_to_logical = torch.full(
                (num_moe_layers, num_local_experts * new_ep_size),
                -1,
                dtype=old_physical_to_logical.dtype,
                device=old_physical_to_logical.device,
            )
            expanded_physical_to_logical[:, : num_local_experts * old_ep_size] = (
                old_physical_to_logical
            )
            eplb_model_state.physical_to_logical_map = expanded_physical_to_logical

        old_num_physical_experts = eplb_model_state.expert_load_pass.shape[1]
        pad_size = num_physical_experts - old_num_physical_experts
        if new_dp_size > old_dp_size:
            assert pad_size > 0
            expanded_expert_load_pass = F.pad(
                eplb_model_state.expert_load_pass, (0, pad_size), value=0
            )
            expanded_expert_load_window = F.pad(
                eplb_model_state.expert_load_window, (0, pad_size), value=0
            )
            eplb_model_state.expert_load_pass = expanded_expert_load_pass
            eplb_model_state.expert_load_window = expanded_expert_load_window
            # Only the pre-existing physical experts are valid until the
            # next reshuffle.
            eplb_state.num_valid_physical_experts = old_num_physical_experts
        else:
            assert pad_size < 0
            eplb_model_state.expert_load_pass = eplb_model_state.expert_load_pass[
                :, :num_physical_experts
            ]
            eplb_model_state.expert_load_window = eplb_model_state.expert_load_window[
                :, :, :num_physical_experts
            ]
            eplb_state.num_valid_physical_experts = num_physical_experts

        model = self.worker.model_runner.get_model()
        model.expert_weights = []
        with set_current_vllm_config(self.worker.vllm_config):
            model.set_eplb_state(
                eplb_model_state.expert_load_pass,
                eplb_model_state.logical_to_physical_map,
                eplb_model_state.logical_replica_count,
            )
            model.update_physical_experts_metadata(
                num_physical_experts=num_physical_experts,
                num_local_physical_experts=num_local_experts,
            )
            # Force re-creation of the modular kernel (and all2all manager)
            # for the new EP size by resetting quant_method to base
            for module in moe_modules:
                if hasattr(module.quant_method, "old_quant_method"):
                    module.quant_method = module.quant_method.old_quant_method
                    module.runner = module._init_runner()
            prepare_communication_buffer_for_model(self.worker.model_runner.model)
        if (
            self.worker.vllm_config.compilation_config.mode
            == CompilationMode.STOCK_TORCH_COMPILE
        ):
            # NOTE(yongji): when using stock torch.compile,
            # torch.compile is triggered during GPUModelRunner's load_model()
            # TODO(yongji):check do we need to re-trigger torch.compile here?
            # any changes to the tensor shapes in execution should already
            # be handled internally by torch.compile.
            backend = self.worker.vllm_config.compilation_config.init_backend(
                self.worker.vllm_config
            )
            compilation_counter.stock_torch_compile_count += 1
            self.worker.model_runner.model.compile(fullgraph=True, backend=backend)

        # Snapshot the block tables; they are cleared for warm-up and
        # restored afterwards.
        multi_block_table = self.worker.model_runner.input_batch.block_table
        saved_block_tables: list[tuple[torch.Tensor, torch.Tensor]] = [
            (bt.block_table.gpu.clone(), bt.block_table.cpu.clone())
            for bt in multi_block_table.block_tables
        ]
        multi_block_table.clear()

        unlock_workspace()
        try:
            self.worker.compile_or_warm_up_model()
        finally:
            # Re-lock even if warm-up raises so the workspace is not left
            # unlocked.
            lock_workspace()

        for bt, (saved_gpu, saved_cpu) in zip(
            multi_block_table.block_tables, saved_block_tables
        ):
            bt.block_table.gpu.copy_(saved_gpu)
            bt.block_table.cpu.copy_(saved_cpu)

        if new_dp_size < old_dp_size:
            self._set_eplb_suppressed(False)

    def _perform_eplb_reshuffle(
        self, rank_mapping: dict[int, int] | None = None
    ) -> None:
        """Run one synchronous EPLB rearrangement and reset its step counter.

        ``eplb_state.is_async`` is forced to ``False`` for the rearrangement
        and restored afterwards, even if the rearrangement raises.

        Args:
            rank_mapping: Optional old-EP-rank to new-EP-rank mapping passed
                through to ``eplb_state.rearrange``.
        """
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding...")

        eplb_state = self.worker.model_runner.eplb_state
        assert eplb_state is not None

        model_config = self.worker.model_runner.model_config
        eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
        is_async_enabled = eplb_state.is_async
        eplb_state.is_async = False
        try:
            if rank_mapping is None:
                eplb_state.rearrange()
            else:
                eplb_state.rearrange(rank_mapping=rank_mapping)
            # NOTE(yongji): check whether we need to synchronize here
            torch.accelerator.synchronize()
            # reset expert_rearrangement_step to ensure all ranks are synchronized
            eplb_state.expert_rearrangement_step = 0
            eplb_state.num_valid_physical_experts = (
                eplb_model_state.physical_to_logical_map.shape[1]
            )
        finally:
            eplb_state.is_async = is_async_enabled
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed")

    def perform_eplb_reshuffle(self) -> None:
        """Reshuffle experts across the current EP group, then re-enable EPLB."""
        self._perform_eplb_reshuffle()
        self._set_eplb_suppressed(False)

    def perform_scale_down_eplb_reshuffle(self, new_dp_size: int) -> None:
        """Suppress EPLB and reshuffle experts ahead of a DP scale-down.

        EP ranks below ``new_dp_size * tensor_parallel_size`` map to
        themselves. Higher ranks map to -1 (presumably marking them as
        removed — confirm against ``rearrange``).
        """
        self._set_eplb_suppressed(True)
        parallel_config = self.worker.vllm_config.parallel_config
        tp_size = parallel_config.tensor_parallel_size
        old_ep_size = parallel_config.data_parallel_size * tp_size
        new_ep_size = new_dp_size * tp_size
        rank_mapping = {
            old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
            for old_ep_rank in range(old_ep_size)
        }
        self._perform_eplb_reshuffle(rank_mapping=rank_mapping)

    def receive_weights(self) -> None:
        """Receive non-expert weights on a joining worker from its sender.

        The sender is derived with the same pairing rule that
        transfer_weights() uses on the existing workers.
        """
        dp_group = get_dp_group()
        assert isinstance(dp_group, StatelessGroupCoordinator)
        new_dp_size = dp_group.world_size
        dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank

        # Receive old_dp_size broadcasted during transfer_weights
        old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
        old_dp_size_tensor = dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0)
        old_dp_size = int(old_dp_size_tensor[0].item())

        # Calculate which existing worker will send to this new worker
        num_new_workers = new_dp_size - old_dp_size
        new_worker_idx = dp_rank - old_dp_size
        num_dst_per_sender = num_new_workers // old_dp_size
        remainder = num_new_workers % old_dp_size

        # When num_dst_per_sender == 0 every index falls in the first branch,
        # so the division in the else branch never sees a zero divisor.
        if new_worker_idx < remainder * (num_dst_per_sender + 1):
            sender_rank = new_worker_idx // (num_dst_per_sender + 1)
        else:
            sender_rank = (
                remainder
                + (new_worker_idx - remainder * (num_dst_per_sender + 1))
                // num_dst_per_sender
            )

        model = self.worker.model_runner.get_model()
        batch_transfer_weights(
            model=model,
            is_sender=False,
            peer_rank=sender_rank,
            dp_group=dp_group,
            expert_weights=model.expert_weights,
        )
        torch.accelerator.synchronize()

    def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]:
        """Receive the expert map on a joining worker and widen it.

        Returns:
            ``(expanded_physical_to_logical, num_logical_experts,
            old_num_physical_experts)``. The expanded map has
            ``num_local_physical_experts * new_ep_size`` columns. Columns
            beyond the received map are filled with -1.
        """
        dp_group = get_dp_group()
        assert isinstance(dp_group, StatelessGroupCoordinator)
        physical_to_logical, num_local_physical_experts, num_logical_experts = (
            broadcast_expert_mapping(
                physical_to_logical=None,
                num_local_physical_experts=None,
                num_logical_experts=None,
                dp_group=dp_group,
                src_rank=0,
                device=self.worker.device,
            )
        )
        num_moe_layers = physical_to_logical.shape[0]
        new_dp_size = dp_group.world_size
        tp_size = self.worker.vllm_config.parallel_config.tensor_parallel_size
        new_ep_size = new_dp_size * tp_size
        expanded_physical_to_logical = torch.full(
            (num_moe_layers, num_local_physical_experts * new_ep_size),
            -1,
            dtype=physical_to_logical.dtype,
            device=physical_to_logical.device,
        )
        old_num_physical_experts = physical_to_logical.shape[1]
        expanded_physical_to_logical[:, :old_num_physical_experts] = physical_to_logical
        return (
            expanded_physical_to_logical,
            num_logical_experts,
            old_num_physical_experts,
        )

    def prepare_new_worker(self) -> None:
        """Set up communication buffers for a joining worker's model."""
        with set_current_vllm_config(self.worker.vllm_config):
            prepare_communication_buffer_for_model(self.worker.model_runner.get_model())