# gpu_worker.py 25.9 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""A GPU worker class."""
4
import copy
5
6
import gc
import os
7
from typing import TYPE_CHECKING, Any, Optional
8
9
10

import torch
import torch.distributed
11
import torch.nn as nn
12

13
import vllm.envs as envs
14
from vllm.config import VllmConfig
15
16
17
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)
18
19
20
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                          get_kv_transfer_group,
                                          has_kv_transfer_group)
21
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
22
from vllm.logger import init_logger
23
from vllm.lora.request import LoRARequest
24
25
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
26
from vllm.pooling_params import PoolingTask
27
from vllm.sequence import IntermediateTensors
28
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
29
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
30
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
31
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
32
from vllm.v1.utils import report_usage_stats
33
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
34
from vllm.v1.worker.worker_base import WorkerBase
35
36
37
38

logger = init_logger(__name__)

if TYPE_CHECKING:
    # Only needed for the string (forward-reference) annotations used by the
    # Worker methods; these modules are never imported at runtime.
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.core.sched.output import SchedulerOutput
43
class Worker(WorkerBase):
    """Worker that owns one CUDA device and drives a GPUModelRunner on it.

    Responsibilities visible in this class: device and distributed-group
    initialization, model loading, memory profiling to size the KV cache,
    per-step model execution (including pipeline-parallel tensor hand-off and
    KV-transfer bookkeeping), sleep/wake-up through the CuMemAllocator, LoRA
    management, and elastic expert-parallel (EP) reconfiguration.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Model buffers copied to CPU by a level-2 sleep; restored and then
        # cleared by wake_up().
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None

    def sleep(self, level: int = 1) -> None:
        """Release GPU memory through the CuMemAllocator.

        Args:
            level: With level 1 the allocator is asked to offload the
                ``"weights"`` tag. With any other level no tag is offloaded.
                Level 2 additionally snapshots every model buffer to CPU
                first, so that ``wake_up`` can restore it.

        Raises:
            AssertionError: if free device memory went down while sleeping.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

        # Save the buffers before level 2 sleep
        if level == 2:
            model = self.model_runner.model
            self._sleep_saved_buffers = {
                name: buffer.cpu().clone()
                for name, buffer in model.named_buffers()
            }

        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, "
            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Bring back memory released by ``sleep``.

        Args:
            tags: Forwarded unchanged to ``CuMemAllocator.wake_up``.
                Presumably ``None`` wakes every tag; confirm in the
                allocator.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # Restore the buffers snapshotted by a level-2 sleep, then drop the
        # CPU copies so a later wake_up does not restore stale values.
        if self._sleep_saved_buffers:
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Record the GPU/CPU KV-cache block counts in the cache config."""
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        """Bind this worker to ``cuda:{local_rank}`` and set up the runner.

        The method performs these steps in order:

        1. Check dtype support.
        2. Snapshot device memory and verify that enough is free for
           ``gpu_memory_utilization``.
        3. Initialize the distributed groups.
        4. Seed the RNGs.
        5. Build the GPUModelRunner.

        Raises:
            ValueError: if free memory at startup is below the requested
                share, or if the dtype is unsupported.
            RuntimeError: if the configured device type is not CUDA.
        """
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            current_platform.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            gc.collect()
            torch.cuda.empty_cache()

            # take current memory snapshot
            self.init_snapshot = MemorySnapshot()
            self.requested_memory = (self.init_snapshot.total_memory *
                                     self.cache_config.gpu_memory_utilization)
            if self.init_snapshot.free_memory < self.requested_memory:
                GiB = lambda b: round(b / GiB_bytes, 2)
                raise ValueError(
                    f"Free memory on device "
                    f"({GiB(self.init_snapshot.free_memory)}/"
                    f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
                    f"is less than desired GPU memory utilization "
                    f"({self.cache_config.gpu_memory_utilization}, "
                    f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                    f"utilization or reduce GPU memory used by other processes."
                )
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank,
                                            current_platform.dist_backend)
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load model weights through the model runner.

        With sleep mode enabled, the weights are allocated inside the
        allocator's ``"weights"`` memory pool. In that case the allocator
        must not yet hold any memory, which enforces one instance per
        process.
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator

            allocator = CuMemAllocator.get_instance()
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be "
                "used for one instance per process.")
            context = allocator.use_memory_pool(tag="weights")
        else:
            from contextlib import nullcontext
            context = nullcontext()
        # Presumably set when this process is launched to join an elastic-EP
        # scale-up; confirm with the launcher.
        eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        with context:
            self.model_runner.load_model(eep_scale_up=eep_scale_up)

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward config ``overrides`` to the model runner."""
        self.model_runner.update_config(overrides)

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the free memory that can be used for KV cache in
        bytes.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            ``requested_memory`` (set by ``init_device``) minus the non-KV-cache
            memory measured during the profiling run, in bytes.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        GiB = lambda b: b / GiB_bytes

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
                self.init_snapshot,
                weights_memory=int(
                    self.model_runner.model_memory_usage)) as profile_result:
            self.model_runner.profile_run()

        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {GiB(free_gpu_memory)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container.")
        available_kv_cache_memory = (self.requested_memory -
                                     profile_result.non_kv_cache_memory)

        logger.debug(
            "Initial free memory: %.2f GiB, free memory: %.2f GiB, "
            "requested GPU memory: %.2f GiB",
            GiB(self.init_snapshot.free_memory), GiB(free_gpu_memory),
            GiB(self.requested_memory))
        logger.debug(profile_result)
        logger.info("Available KV cache memory: %.2f GiB",
                    GiB(available_kv_cache_memory))
        gc.collect()

        return int(available_kv_cache_memory)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the model runner's KV-cache spec."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        With sleep mode enabled, the allocation goes into the allocator's
        ``"kv_cache"`` memory pool.
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator

            allocator = CuMemAllocator.get_instance()
            context = allocator.use_memory_pool(tag="kv_cache")
        else:
            from contextlib import nullcontext
            context = nullcontext()
        with context:
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Compile/warm up extra sizes, capture CUDA graphs, warm the sampler.

        The seed is reset at the end so warm-up does not perturb the RNG
        state.
        """
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x for x in warmup_sizes if x not in
                self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True)
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(self.scheduler_config.max_num_seqs,
                               self.scheduler_config.max_num_batched_tokens)

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(
                    hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def get_model(self) -> nn.Module:
        """Return the model held by the model runner."""
        return self.model_runner.get_model()

    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
        """Return the pooling tasks the loaded model supports."""
        return self.model_runner.get_supported_pooling_tasks()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        """Run one scheduler step on this worker.

        Non-first PP ranks receive their input intermediate tensors from the
        previous stage. Non-last PP ranks send their output to the next stage
        and report ``EMPTY_MODEL_RUNNER_OUTPUT``; the exception is the
        ``external_launcher`` backend, which skips the send.

        Returns:
            The runner output when a KV-transfer group exists (every worker
            reports). Otherwise the output on the driver worker and ``None``
            elsewhere.
        """
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)

        parallel_config = self.vllm_config.parallel_config
        if (parallel_config.distributed_executor_backend != "external_launcher"
                and not get_pp_group().is_last_rank):
            assert isinstance(output, IntermediateTensors)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            output = EMPTY_MODEL_RUNNER_OUTPUT

        assert isinstance(output, ModelRunnerOutput)
        if has_kv_transfer_group():
            finished_sending, finished_recving = (
                get_kv_transfer_group().get_finished(
                    scheduler_output.finished_req_ids))
            if finished_sending or finished_recving:
                # Copy first so the shared module-level sentinel is never
                # mutated.
                if output is EMPTY_MODEL_RUNNER_OUTPUT:
                    output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
                output.finished_sending = finished_sending
                output.finished_recving = finished_recving

            # Clear KVConnector state for this step.
            get_kv_transfer_group().clear_connector_metadata()

            # with a connector, the scheduler expects output from all workers
            return output

        # return output only from the driver worker
        return output if self.is_driver_worker else None

    def profile(self, is_start: bool = True):
        """Start or stop the torch profiler.

        When stopping, a key-average table sorted by self CUDA time is
        printed.

        Raises:
            RuntimeError: if VLLM_TORCH_PROFILER_DIR was not set at init.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()
            print(self.profiler.key_averages().table(
                sort_by="self_cuda_time_total"))

    def execute_dummy_batch(self) -> None:
        """Run a single-token dummy forward pass."""
        self.model_runner._dummy_run(1)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to the model runner's ``add_lora``."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner's ``remove_lora``."""
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Delegate to the model runner's ``list_loras``."""
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner's ``pin_lora``."""
        return self.model_runner.pin_lora(lora_id)

    def check_health(self) -> None:
        # worker will always be healthy as long as it's running.
        return

    def _eplb_before_scale_down(self, old_ep_size: int,
                                new_ep_size: int) -> None:
        """Reshard experts away from EP ranks that are about to be removed.

        Old ranks ``>= new_ep_size`` are mapped to ``-1``; all others keep
        their index.
        """
        from vllm.distributed.parallel_state import get_ep_group
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding "
                        "before scaling down...")
        rank_mapping = {
            old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
            for old_ep_rank in range(old_ep_size)
        }
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(self.model_runner.model,
                                               execute_shuffle=True,
                                               global_expert_load=None,
                                               rank_mapping=rank_mapping)
        torch.cuda.synchronize()
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _eplb_after_scale_up(
            self, old_ep_size: int, new_ep_size: int,
            global_expert_load: Optional[torch.Tensor]) -> None:
        """Reshard experts onto the enlarged EP group.

        Existing ranks keep their index. ``new_ep_size`` is accepted for
        symmetry with the scale-down path but is not used here.
        """
        from vllm.distributed.parallel_state import get_ep_group
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding "
                        "after scaling up...")
        rank_mapping = {
            old_ep_rank: old_ep_rank
            for old_ep_rank in range(old_ep_size)
        }
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=global_expert_load,
            rank_mapping=rank_mapping)
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")

    def _reconfigure_parallel_config(
            self, reconfig_request: ReconfigureDistributedRequest) -> None:
        """
        Update parallel config with provided reconfig_request.

        DP ranks equal to ``ReconfigureRankType.KEEP_CURRENT_RANK`` leave the
        current value untouched. Size, master IP and master port are always
        overwritten.
        """
        parallel_config = self.vllm_config.parallel_config
        parallel_config.data_parallel_size = \
            reconfig_request.new_data_parallel_size
        if (reconfig_request.new_data_parallel_rank
                != ReconfigureRankType.KEEP_CURRENT_RANK):
            parallel_config.data_parallel_rank = \
                reconfig_request.new_data_parallel_rank
        if (reconfig_request.new_data_parallel_rank_local
                != ReconfigureRankType.KEEP_CURRENT_RANK):
            parallel_config.data_parallel_rank_local = \
                reconfig_request.new_data_parallel_rank_local
        parallel_config.data_parallel_master_ip = \
            reconfig_request.new_data_parallel_master_ip
        parallel_config.data_parallel_master_port = \
            reconfig_request.new_data_parallel_master_port

    def _reconfigure_moe(self, old_ep_size: int,
                         new_ep_size: int) -> Optional[torch.Tensor]:
        """
        Reconfigure every FusedMoE module for an EP group of ``new_ep_size``.

        Updates expert counts and parallel configs on each module, recomputes
        ``num_redundant_experts``, rebuilds the communication buffers, and
        refreshes the model's physical-expert metadata.

        Returns:
            ``None`` when the EP group shrinks (``new_ep_size < old_ep_size``).
            Otherwise (grow or same size) returns the global expert load
            computed by ``eplb_state.rearrange(..., execute_shuffle=False)``.

        Raises:
            RuntimeError: if the model contains no FusedMoE modules.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group, get_ep_group, prepare_communication_buffer_for_model)
        from vllm.model_executor.layers.fused_moe.layer import (
            FusedMoEParallelConfig)

        parallel_config = self.vllm_config.parallel_config
        moe_modules = [
            module for module in self.model_runner.model.modules()
            if module.__class__.__name__ == "FusedMoE"
        ]
        if not moe_modules:
            # Previously surfaced as a bare IndexError on moe_modules[0].
            raise RuntimeError(
                "Elastic EP reconfiguration requires a model with FusedMoE "
                "modules, but none were found.")
        num_local_experts = moe_modules[0].moe_config.num_local_experts
        assert all(module.moe_config.num_local_experts == num_local_experts
                   for module in moe_modules), (
                       "All MoE modules must have the same number of experts")
        for module in moe_modules:
            module.moe_config.num_experts = num_local_experts * new_ep_size
            module.global_num_experts = module.moe_config.num_experts
            module.moe_parallel_config = FusedMoEParallelConfig.make(
                tp_size_=get_tp_group().world_size,
                dp_size_=get_dp_group().world_size,
                vllm_parallel_config=parallel_config,
            )
            module.moe_config.moe_parallel_config = module.moe_parallel_config
        if new_ep_size < old_ep_size:
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            new_physical_experts = \
                self.model_runner.eplb_state.physical_to_logical_map.shape[1]
            parallel_config.num_redundant_experts = (
                new_physical_experts -
                self.model_runner.eplb_state.logical_replica_count.shape[1])
            global_expert_load = None
        else:
            # Rank 0's local expert count is broadcast so that newly joined
            # ranks agree on it.
            num_local_physical_experts = torch.tensor([num_local_experts],
                                                      dtype=torch.int32,
                                                      device="cpu")
            torch.distributed.broadcast(num_local_physical_experts,
                                        group=get_ep_group().cpu_group,
                                        group_src=0)
            num_local_physical_experts = num_local_physical_experts.item()
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            global_expert_load = self.model_runner.eplb_state.rearrange(
                self.model_runner.model, execute_shuffle=False)
            parallel_config.num_redundant_experts = (
                new_physical_experts - global_expert_load.shape[1])
        prepare_communication_buffer_for_model(self.model_runner.model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts)
        return global_expert_load

    def reinitialize_distributed(
            self, reconfig_request: ReconfigureDistributedRequest) -> None:
        """Tear down and rebuild the distributed groups for a new DP size.

        The new EP size is ``new_data_parallel_size * TP * PP``. The steps are:

        1. Reshard experts first when scaling down, then tear down the
           distributed environment.
        2. A rank marked ``SHUTDOWN_CURRENT_RANK`` stops there.
        3. The others update the parallel config, re-initialize the groups,
           reconfigure the MoE modules, and reshard experts when scaling up.
        """
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory, get_ep_group)

        old_ep_size = get_ep_group().world_size
        old_ep_rank = get_ep_group().rank
        new_ep_size = (reconfig_request.new_data_parallel_size *
                       get_tp_group().world_size * get_pp_group().world_size)
        if new_ep_size < old_ep_size:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        if (reconfig_request.new_data_parallel_rank ==
                ReconfigureRankType.SHUTDOWN_CURRENT_RANK):
            assert old_ep_rank >= new_ep_size
            # shutdown
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            # Use the same backend as init_device(); omitting it silently
            # fell back to the "nccl" default of
            # init_worker_distributed_environment.
            init_worker_distributed_environment(self.vllm_config, self.rank,
                                                self.distributed_init_method,
                                                self.local_rank,
                                                current_platform.dist_backend)

        global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)

        if new_ep_size > old_ep_size:
            assert global_expert_load is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size,
                                      global_expert_load)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save the model's sharded state to ``path`` via ShardedStateLoader."""
        from vllm.model_executor.model_loader import ShardedStateLoader
        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Delegate tensorizer serialization to the model runner."""
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config, )

565
566

def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment.

    Steps, in order:

    1. Enable or disable custom all-reduce.
    2. Initialize the global process group.
    3. Initialize the tensor/pipeline model-parallel groups.
    4. Initialize the KV-transfer group.

    Args:
        vllm_config: Engine config. Its ``parallel_config`` supplies the world
            size, the TP/PP sizes and the custom all-reduce switch.
        rank: Global rank of this process.
        distributed_init_method: Process-group init URL, forwarded unchanged.
        local_rank: Per-node rank, forwarded to
            ``init_distributed_environment``.
        backend: torch.distributed backend name; defaults to ``"nccl"``.
    """
    parallel_config = vllm_config.parallel_config
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank, backend)

    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)

    ensure_kv_transfer_initialized(vllm_config)

def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:  # noqa: SIM102
        if not current_platform.has_device_capability(80):
            capability = current_platform.get_device_capability()
            gpu_name = current_platform.get_device_name()

            if capability is None:
                compute_str = "does not have a compute capability"
            else:
                version_str = capability.as_version_str()
                compute_str = f"has compute capability {version_str}"

            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
602
                "You can use float16 instead by explicitly setting the "
603
                "`dtype` flag in CLI, for example: --dtype=half.")