# gpu_worker.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4
import copy
5
6
import gc
import os
7
from contextlib import AbstractContextManager, nullcontext
8
from typing import TYPE_CHECKING, Any, Optional
9
10
11

import torch
import torch.distributed
12
import torch.nn as nn
13

14
import vllm.envs as envs
15
from vllm.config import VllmConfig
16
17
18
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)
19
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
20
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
21
from vllm.logger import init_logger
22
from vllm.lora.request import LoRARequest
23
from vllm.model_executor import set_random_seed
24
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
25
from vllm.platforms import current_platform
26
from vllm.sequence import IntermediateTensors
27
from vllm.tasks import SupportedTask
28
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
29
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
30
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
31
32
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
                             ModelRunnerOutput)
33
from vllm.v1.utils import report_usage_stats
34
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
35
from vllm.v1.worker.worker_base import WorkerBase
36
37
38
39

logger = init_logger(__name__)

if TYPE_CHECKING:
    # Imported only for type annotations ("TensorizerConfig",
    # "SchedulerOutput" are referenced as string annotations below).
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.core.sched.output import SchedulerOutput


class Worker(WorkerBase):
    """GPU worker that drives a single CUDA device through a GPUModelRunner.

    Covers device and distributed setup, model loading, memory profiling for
    KV-cache sizing, warm-up, sleep/wake via the CuMem allocator, LoRA
    management, torch profiling and elastic expert-parallel reconfiguration.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Buffers saved before sleep
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            logger.debug(
                "Profiler config: record_shapes=%s,"
                "profile_memory=%s,with_stack=%s,with_flops=%s",
                envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                envs.VLLM_TORCH_PROFILER_WITH_STACK,
                envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None

    def sleep(self, level: int = 1) -> None:
        """Put the CuMem allocator to sleep to release GPU memory.

        Level 1 passes ``offload_tags=("weights",)`` to the allocator; level 2
        passes no offload tags, so the model's named buffers are first copied
        to CPU here and written back by ``wake_up``.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

        # Save the buffers before level 2 sleep
        if level == 2:
            model = self.model_runner.model
            self._sleep_saved_buffers = {
                name: buffer.cpu().clone()
                for name, buffer in model.named_buffers()
            }

        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, "
            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake the CuMem allocator and restore buffers saved by ``sleep``.

        Args:
            tags: allocator pools to wake; ``None`` is forwarded unchanged to
                the allocator (wake everything).
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # Restore the buffers after level 2 sleep. Model buffers are created
        # inside the "weights" pool (see load_model), so only copy them back
        # once that pool has been woken; otherwise keep the CPU copies for a
        # later wake_up(["weights"]) instead of writing into sleeping memory
        # and then discarding them.
        weights_woken = tags is None or "weights" in tags
        if self._sleep_saved_buffers and weights_woken:
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

    def _maybe_get_memory_pool_context(self,
                                       tag: str) -> AbstractContextManager:
        """Return a CuMem memory-pool context for ``tag`` in sleep mode,
        otherwise a no-op ``nullcontext``."""
        if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator

            allocator = CuMemAllocator.get_instance()
            if tag == "weights":
                assert allocator.get_current_usage() == 0, (
                    "Sleep mode can only be "
                    "used for one instance per process.")
            context = allocator.use_memory_pool(tag=tag)
        else:
            context = nullcontext()
        return context

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Record the KV-cache block counts on the cache config."""
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        """Bind the worker to ``cuda:<local_rank>`` and build the runner.

        Snapshots device memory (used later by determine_available_memory),
        initializes the distributed environment, seeds RNGs and constructs
        the GPUModelRunner.

        Raises:
            ValueError: if free device memory is below
                ``total * gpu_memory_utilization`` (or bfloat16 is requested
                on a GPU without compute capability 8.0).
            RuntimeError: if the configured device type is not "cuda".
        """
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            current_platform.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            gc.collect()
            torch.cuda.empty_cache()

            # take current memory snapshot
            self.init_snapshot = MemorySnapshot()
            self.requested_memory = (self.init_snapshot.total_memory *
                                     self.cache_config.gpu_memory_utilization)
            if self.init_snapshot.free_memory < self.requested_memory:
                GiB = lambda b: round(b / GiB_bytes, 2)
                raise ValueError(
                    f"Free memory on device "
                    f"({GiB(self.init_snapshot.free_memory)}/"
                    f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
                    f"is less than desired GPU memory utilization "
                    f"({self.cache_config.gpu_memory_utilization}, "
                    f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                    f"utilization or reduce GPU memory used by other processes."
                )
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank,
                                            current_platform.dist_backend)
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load model weights inside the "weights" memory pool (sleep mode)."""
        eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        with self._maybe_get_memory_pool_context(tag="weights"):
            self.model_runner.load_model(eep_scale_up=eep_scale_up)

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward config overrides to the model runner."""
        self.model_runner.update_config(overrides)

    def reload_weights(self) -> None:
        """Reload weights inside the "weights" memory pool (sleep mode)."""
        with self._maybe_get_memory_pool_context(tag="weights"):
            self.model_runner.reload_weights()

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the free memory that can be used for KV cache in
        bytes.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        GiB = lambda b: b / GiB_bytes

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
                self.init_snapshot,
                weights_memory=int(
                    self.model_runner.model_memory_usage)) as profile_result:
            self.model_runner.profile_run()

        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {GiB(free_gpu_memory)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container.")
        # KV cache gets whatever part of the requested budget is not used by
        # weights, activations and other non-KV-cache allocations.
        available_kv_cache_memory = self.requested_memory \
            - profile_result.non_kv_cache_memory

        unrequested_memory = self.init_snapshot.free_memory \
            - self.requested_memory
        logger.debug(
            "Initial free memory: %.2f GiB; "
            "Requested memory: %.2f (util), %.2f GiB",
            GiB(self.init_snapshot.free_memory),
            self.cache_config.gpu_memory_utilization,
            GiB(self.requested_memory),
        )
        logger.debug(
            "Free memory after profiling: %.2f GiB (total), "
            "%.2f GiB (within requested)",
            GiB(free_gpu_memory),
            GiB(free_gpu_memory - unrequested_memory),
        )
        logger.debug(profile_result)
        logger.info("Available KV cache memory: %.2f GiB",
                    GiB(available_kv_cache_memory))
        gc.collect()

        return int(available_kv_cache_memory)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the model runner's per-layer KV cache specs."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config."""
        # Same pool selection as weights, but tagged "kv_cache" so sleep mode
        # can release it independently.
        with self._maybe_get_memory_pool_context(tag="kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Compile/warm up extra sizes, capture CUDA graphs, warm up the
        sampler or pooler and kernels, then reset the random seed."""
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            capture_sizes = set(
                self.vllm_config.compilation_config.cudagraph_capture_sizes)
            warmup_sizes = [x for x in warmup_sizes if x not in capture_sizes]
        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True)
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(self.scheduler_config.max_num_seqs,
                               self.scheduler_config.max_num_batched_tokens)

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = \
                self.model_runner._dummy_run(
                    num_tokens=max_num_reqs,
                    skip_eplb=True,
                )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(
                    hidden_states=last_hidden_states)

        # Warmup kernels used during model execution
        kernel_warmup(self)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def get_model(self) -> nn.Module:
        """Return the underlying model module."""
        return self.model_runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the loaded model supports."""
        return self.model_runner.get_supported_tasks()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        """Run one model step, handling pipeline-parallel hand-off.

        Non-first PP ranks first receive intermediate tensors from the
        previous stage. Non-last PP ranks (except with the external_launcher
        backend) send their intermediate tensors onward and return ``None``,
        ``EMPTY_MODEL_RUNNER_OUTPUT``, or a copy of it carrying the
        ``kv_connector_output``, depending on KV-transfer progress.
        """
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)

        parallel_config = self.vllm_config.parallel_config
        if parallel_config.distributed_executor_backend != "external_launcher" \
            and not get_pp_group().is_last_rank:
            assert isinstance(output, IntermediateTensors)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())

            kv_connector_output = output.kv_connector_output
            if not kv_connector_output:
                return None

            # In case of PP with kv transfer, we need to pass through the
            # kv_connector_output
            if (not kv_connector_output.finished_sending
                    and not kv_connector_output.finished_recving):
                return EMPTY_MODEL_RUNNER_OUTPUT

            output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
            output.kv_connector_output = kv_connector_output
            return output

        assert isinstance(output, ModelRunnerOutput)
        return output

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        """Return pending speculative draft token ids from the runner."""
        return self.model_runner.take_draft_token_ids()

    def profile(self, is_start: bool = True):
        """Start or stop the torch profiler.

        On stop, prints the key-averages table sorted by self CUDA time.

        Raises:
            RuntimeError: if VLLM_TORCH_PROFILER_DIR was not set at init.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()
            print(self.profiler.key_averages().table(
                sort_by="self_cuda_time_total"))

    def execute_dummy_batch(self) -> None:
        """Run a single-token dummy forward pass."""
        self.model_runner._dummy_run(1)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

    def check_health(self) -> None:
        # worker will always be healthy as long as it's running.
        return

    def _eplb_before_scale_down(self, old_ep_size: int,
                                new_ep_size: int) -> None:
        """Move experts off ranks >= new_ep_size before the EP group shrinks.

        Old ranks that survive keep their index; removed ranks map to -1.
        """
        from vllm.distributed.parallel_state import get_ep_group
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding "
                        "before scaling down...")
        rank_mapping = {
            old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
            for old_ep_rank in range(old_ep_size)
        }
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(self.model_runner.model,
                                               execute_shuffle=True,
                                               global_expert_load=None,
                                               rank_mapping=rank_mapping)
        torch.cuda.synchronize()
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
            self, old_ep_size: int, new_ep_size: int,
            global_expert_load: Optional[torch.Tensor]) -> None:
        """Redistribute experts across the enlarged EP group.

        Existing ranks keep their index; ``global_expert_load`` is the load
        collected by ``_reconfigure_moe`` before the group was rebuilt.
        """
        from vllm.distributed.parallel_state import get_ep_group
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding "
                        "after scaling up...")
        rank_mapping = {
            old_ep_rank: old_ep_rank
            for old_ep_rank in range(old_ep_size)
        }
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=global_expert_load,
            rank_mapping=rank_mapping)
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
            self, reconfig_request: ReconfigureDistributedRequest) -> None:
        """
        Update parallel config with provided reconfig_request
        """
        parallel_config = self.vllm_config.parallel_config
        parallel_config.data_parallel_size = \
            reconfig_request.new_data_parallel_size
        if reconfig_request.new_data_parallel_rank != \
        ReconfigureRankType.KEEP_CURRENT_RANK:
            parallel_config.data_parallel_rank = \
                reconfig_request.new_data_parallel_rank
        if reconfig_request.new_data_parallel_rank_local != \
        ReconfigureRankType.KEEP_CURRENT_RANK:
            parallel_config.data_parallel_rank_local = \
                reconfig_request.new_data_parallel_rank_local
        parallel_config.data_parallel_master_ip = \
            reconfig_request.new_data_parallel_master_ip
        parallel_config.data_parallel_master_port = \
            reconfig_request.new_data_parallel_master_port

    def _reconfigure_moe(self, old_ep_size: int,
                         new_ep_size: int) -> Optional[torch.Tensor]:
        """
        Reconfigure FusedMoE modules for a new expert-parallel size.

        Return the global expert load if new_ep_size > old_ep_size,
        otherwise None
        """
        from vllm.distributed.parallel_state import (
            get_dp_group, get_ep_group, prepare_communication_buffer_for_model)
        from vllm.model_executor.layers.fused_moe.layer import (
            FusedMoEParallelConfig)

        parallel_config = self.vllm_config.parallel_config
        moe_modules = [
            module for module in self.model_runner.model.modules()
            if module.__class__.__name__ == "FusedMoE"
        ]
        num_local_experts = moe_modules[0].moe_config.num_local_experts
        assert all(module.moe_config.num_local_experts == num_local_experts
                   for module in moe_modules), (
                       "All MoE modules must have the same number of experts")
        for module in moe_modules:
            module.moe_config.num_experts = num_local_experts * new_ep_size
            module.global_num_experts = module.moe_config.num_experts
            module.moe_parallel_config = FusedMoEParallelConfig.make(
                tp_size_=get_tp_group().world_size,
                dp_size_=get_dp_group().world_size,
                vllm_parallel_config=parallel_config,
            )
            module.moe_config.moe_parallel_config = module.moe_parallel_config
        if new_ep_size < old_ep_size:
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            new_physical_experts = \
                self.model_runner.eplb_state.physical_to_logical_map.shape[1]
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts -
                self.model_runner.eplb_state.logical_replica_count.shape[1])
            global_expert_load = None
        else:
            # Newly launched ranks learn the per-rank expert count from EP
            # rank 0 over the CPU group.
            num_local_physical_experts = torch.tensor([num_local_experts],
                                                      dtype=torch.int32,
                                                      device="cpu")
            torch.distributed.broadcast(num_local_physical_experts,
                                        group=get_ep_group().cpu_group,
                                        group_src=0)
            num_local_physical_experts = num_local_physical_experts.item()
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            global_expert_load = self.model_runner.eplb_state.rearrange(
                self.model_runner.model, execute_shuffle=False)
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_load.shape[1])
        prepare_communication_buffer_for_model(self.model_runner.model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts)
        return global_expert_load

    def reinitialize_distributed(
            self, reconfig_request: ReconfigureDistributedRequest) -> None:
        """Tear down and rebuild the distributed environment for a new DP size.

        Experts are resharded before scale-down and after scale-up; a rank
        marked SHUTDOWN_CURRENT_RANK returns right after cleanup.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory, get_ep_group)

        old_ep_size = get_ep_group().world_size
        old_ep_rank = get_ep_group().rank
        new_ep_size = reconfig_request.new_data_parallel_size * get_tp_group(
        ).world_size * get_pp_group().world_size
        if new_ep_size < old_ep_size:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        if reconfig_request.new_data_parallel_rank == \
        ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
            assert old_ep_rank >= new_ep_size
            # shutdown
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            # Use the same backend as the initial setup in init_device()
            # instead of silently falling back to the "nccl" default.
            init_worker_distributed_environment(self.vllm_config, self.rank,
                                                self.distributed_init_method,
                                                self.local_rank,
                                                current_platform.dist_backend)

        global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)

        if new_ep_size > old_ep_size:
            assert global_expert_load is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size,
                                      global_expert_load)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save this rank's model shard via ShardedStateLoader."""
        from vllm.model_executor.model_loader import ShardedStateLoader
        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Serialize the model with tensorizer through the model runner."""
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config, )


def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment.

    Configures custom all-reduce, initializes the process group over
    ``parallel_config.world_size`` ranks, sets up tensor/pipeline model
    parallel groups, and initializes KV transfer.

    Args:
        vllm_config: full engine config; its ``parallel_config`` drives sizes.
        rank: global rank of this worker.
        distributed_init_method: init URL forwarded to
            ``init_distributed_environment``.
        local_rank: rank on this node.
        backend: torch.distributed backend name.
    """
    parallel_config = vllm_config.parallel_config
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank, backend)

    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)

    ensure_kv_transfer_initialized(vllm_config)


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    """Raise if the current GPU cannot run the requested model dtype.

    Only bfloat16 is checked: it needs compute capability 8.0 or newer.

    Raises:
        ValueError: if ``torch_dtype`` is bfloat16 and the device does not
            report compute capability >= 8.0.
    """
    if torch_dtype != torch.bfloat16:
        return
    if current_platform.has_device_capability(80):
        return

    capability = current_platform.get_device_capability()
    gpu_name = current_platform.get_device_name()

    if capability is None:
        compute_str = "does not have a compute capability"
    else:
        version_str = capability.as_version_str()
        compute_str = f"has compute capability {version_str}"

    raise ValueError(
        "Bfloat16 is only supported on GPUs with compute capability "
        f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
        "You can use float16 instead by explicitly setting the "
        "`dtype` flag in CLI, for example: --dtype=half.")