gpu_worker.py 11.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
"""A GPU worker class."""
import gc
import os
5
from typing import TYPE_CHECKING, List, Optional
6
7
8

import torch
import torch.distributed
9
import torch.nn as nn
10

11
import vllm.envs as envs
12
from vllm.config import ParallelConfig, VllmConfig
13
from vllm.device_allocator.cumem import CuMemAllocator
14
15
16
17
18
19
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
20
from vllm.utils import GiB_bytes
21
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
22
23
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
24
from vllm.v1.worker.worker_base import WorkerBase
25
26
27
28

logger = init_logger(__name__)

if TYPE_CHECKING:
29
    from vllm.v1.core.scheduler_output import SchedulerOutput
30
31


32
class Worker(WorkerBase):
    """GPU worker that owns one CUDA device and drives a ``GPUModelRunner``.

    The methods below cover device and distributed setup, model loading,
    KV-cache memory profiling and allocation, warm-up / CUDA-graph capture,
    model execution, optional torch profiling, and sleep-mode memory
    management through ``CuMemAllocator``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None

    def sleep(self, level: int = 1) -> None:
        """Put the sleep-mode allocator to sleep and log the memory freed.

        Args:
            level: At level 1 the allocator is called with
                ``offload_tags=("weights",)``; at any other level with an
                empty tuple. What happens to each tag is decided by
                ``CuMemAllocator.sleep``.

        Raises:
            AssertionError: free device memory decreased across the call.
        """
        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, "
            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes)

    def wake_up(self) -> None:
        """Wake the sleep-mode allocator (counterpart of :meth:`sleep`)."""
        allocator = CuMemAllocator.get_instance()
        allocator.wake_up()

    def _allocation_context(self, tag: str, require_empty: bool = False):
        """Return the context manager that GPU allocations should run under.

        With sleep mode enabled, this is the ``CuMemAllocator`` memory pool
        for ``tag``. Otherwise it is a no-op ``nullcontext``.

        Args:
            tag: Pool tag passed to ``use_memory_pool``.
            require_empty: If True (and sleep mode is on), assert that the
                allocator currently reports zero usage.
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            from contextlib import nullcontext
            return nullcontext()
        allocator = CuMemAllocator.get_instance()
        if require_empty:
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be "
                "used for one instance per process.")
        return allocator.use_memory_pool(tag=tag)

    def init_device(self):
        """Bind this worker to its CUDA device and build the model runner.

        Records the free device memory as ``self.init_gpu_memory`` (used
        later by :meth:`determine_available_memory`), initializes the
        distributed environment, and seeds the RNG.

        Raises:
            RuntimeError: the configured device type is not ``"cuda"``.
        """
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            gc.collect()
            torch.cuda.empty_cache()
            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.parallel_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device)

    def load_model(self) -> None:
        """Load model weights, inside the ``"weights"`` pool in sleep mode."""
        with self._allocation_context("weights", require_empty=True):
            self.model_runner.load_model()

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the free memory that can be used for KV cache in
        bytes:

            total_gpu_memory * gpu_memory_utilization - peak_memory

        where ``peak_memory`` is torch's peak allocation during the dummy
        forward pass, plus any memory held outside torch's allocator
        afterwards (e.g. by NCCL).

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        _, total_gpu_memory = torch.cuda.mem_get_info()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        free_gpu_memory, _ = torch.cuda.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_gpu_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        # Get the peak memory allocation recorded by torch
        peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]

        # Check for any memory left around that may have been allocated on the
        # gpu outside of `torch`. NCCL operations, for example, can use a few
        # GB during a forward pass
        torch.cuda.empty_cache()
        torch_allocated_bytes = torch.cuda.memory_stats(
        )["allocated_bytes.all.current"]
        # Query free/total once so both numbers describe the same snapshot.
        free_bytes_now, total_bytes_now = torch.cuda.mem_get_info()
        total_allocated_bytes = total_bytes_now - free_bytes_now
        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
        if non_torch_allocations > 0:
            peak_memory += non_torch_allocations
        available_kv_cache_memory = (
            total_gpu_memory * self.cache_config.gpu_memory_utilization -
            peak_memory)

        return int(available_kv_cache_memory)

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """Return the KV-cache spec reported by the model runner."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        ``kv_cache_configs`` is indexed by this worker's ``rank``. In sleep
        mode the allocation happens inside the ``"kv_cache"`` memory pool.
        """
        kv_cache_config = kv_cache_configs[self.rank]
        with self._allocation_context("kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Run warm-up passes, capture CUDA graphs, then reset the RNG seed."""
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        compilation_config = self.vllm_config.compilation_config
        # No copy needed: the list is only filtered/sorted into new lists.
        warmup_sizes = compilation_config.compile_sizes
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x for x in warmup_sizes
                if x not in compilation_config.cudagraph_capture_sizes
            ]
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size)
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def get_model(self) -> nn.Module:
        """Return the model held by the model runner."""
        return self.model_runner.get_model()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        """Run one model step; only rank 0 returns the output, others None."""
        output = self.model_runner.execute_model(scheduler_output)
        return output if self.rank == 0 else None

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop the torch profiler.

        Raises:
            RuntimeError: profiling was not enabled via
                ``VLLM_TORCH_PROFILER_DIR``.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()

    def check_health(self) -> None:
        # worker will always be healthy as long as it's running.
        return

241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275

def init_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Set up distributed state for this worker process.

    Steps, in order:
      1. enable custom all-reduce unless
         ``parallel_config.disable_custom_all_reduce`` is set;
      2. initialize the global distributed environment for
         ``parallel_config.world_size`` ranks;
      3. ensure tensor- and pipeline-parallel groups exist.
    """
    use_custom_all_reduce = not parallel_config.disable_custom_all_reduce
    set_custom_all_reduce(use_custom_all_reduce)

    world_size = parallel_config.world_size
    init_distributed_environment(world_size, rank, distributed_init_method,
                                 local_rank)

    tp_size = parallel_config.tensor_parallel_size
    pp_size = parallel_config.pipeline_parallel_size
    ensure_model_parallel_initialized(tp_size, pp_size)


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:  # noqa: SIM102
        if not current_platform.has_device_capability(80):
            capability = current_platform.get_device_capability()
            gpu_name = current_platform.get_device_name()

            if capability is None:
                compute_str = "does not have a compute capability"
            else:
                version_str = capability.as_version_str()
                compute_str = f"has compute capability {version_str}"

            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
                "You can use float16 instead by explicitly setting the"
                "`dtype` flag in CLI, for example: --dtype=half.")