# gpu_worker.py
# SPDX-License-Identifier: Apache-2.0
2
3
4
"""A GPU worker class."""
import gc
import os
5
from typing import TYPE_CHECKING, Optional
6
7
8

import torch
import torch.distributed
9
import torch.nn as nn
10

11
import vllm.envs as envs
12
from vllm.config import VllmConfig
13
from vllm.device_allocator.cumem import CuMemAllocator
14
15
16
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)
17
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
18
from vllm.distributed.parallel_state import get_pp_group
19
from vllm.logger import init_logger
20
from vllm.lora.request import LoRARequest
21
22
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
23
from vllm.utils import GiB_bytes
24
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
25
26
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
27
from vllm.v1.worker.worker_base import WorkerBase
28
29
30
31

logger = init_logger(__name__)

# Imported only for type annotations (see `Worker.execute_model`), so the
# scheduler module is not loaded at runtime.
if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput


class Worker(WorkerBase):
    """GPU worker for the V1 engine.

    Binds itself to the CUDA device ``cuda:{local_rank}``, joins the
    distributed environment, and delegates model loading, execution,
    KV-cache allocation and LoRA management to a ``GPUModelRunner``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None

    def sleep(self, level: int = 1) -> None:
        """Put the CuMemAllocator to sleep to release GPU memory.

        Level 1 passes ``("weights",)`` as ``offload_tags``; any other level
        passes no offload tags. The exact treatment of tagged vs. untagged
        allocations is defined by ``CuMemAllocator.sleep``. Logs how much
        device memory was freed and how much is still in use.
        """
        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights", ) if level == 1 else ())
        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, "
            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Undo ``sleep`` for the given allocator tags.

        ``tags`` is forwarded unchanged to ``CuMemAllocator.wake_up``; the
        meaning of ``None`` is defined there (presumably "all tags").
        """
        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

    def init_device(self):
        """Select the CUDA device, set up distributed state and the runner.

        Records the free device memory after emptying the cache in
        ``self.init_gpu_memory``; ``determine_available_memory`` later uses
        it as the pre-profiling baseline.

        Raises:
            RuntimeError: if the configured device type is not ``cuda``.
            ValueError: if the GPU cannot run the configured dtype (raised
                by ``_check_if_gpu_supports_dtype``).
        """
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            gc.collect()
            torch.cuda.empty_cache()
            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device)

    def _memory_pool_context(self, tag: str):
        """Return the context manager that allocations tagged ``tag`` use.

        With sleep mode enabled, this is ``CuMemAllocator``'s memory pool for
        ``tag``, so the allocations can later be acted on by ``sleep`` /
        ``wake_up``. Otherwise it is a no-op ``nullcontext``.
        """
        if not self.vllm_config.model_config.enable_sleep_mode:
            from contextlib import nullcontext
            return nullcontext()
        return CuMemAllocator.get_instance().use_memory_pool(tag=tag)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load the model weights (into the "weights" pool in sleep mode)."""
        if self.vllm_config.model_config.enable_sleep_mode:
            # The pool must be empty: only one sleep-mode instance per process.
            allocator = CuMemAllocator.get_instance()
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be "
                "used for one instance per process.")
        with self._memory_pool_context("weights"):
            self.model_runner.load_model()

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the free memory that can be used for KV cache in
        bytes.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        _, total_gpu_memory = torch.cuda.mem_get_info()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        free_gpu_memory, _ = torch.cuda.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_gpu_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        # Get the peak memory allocation recorded by torch
        peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]

        # Check for any memory left around that may have been allocated on the
        # gpu outside of `torch`. NCCL operations, for example, can use a few
        # GB during a forward pass
        torch.cuda.empty_cache()
        torch_allocated_bytes = torch.cuda.memory_stats(
        )["allocated_bytes.all.current"]
        # Read free and total from a single snapshot so they are consistent.
        free_now, total_now = torch.cuda.mem_get_info()
        total_allocated_bytes = total_now - free_now
        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
        if non_torch_allocations > 0:
            peak_memory += non_torch_allocations
        available_kv_cache_memory = (
            total_gpu_memory * self.cache_config.gpu_memory_utilization -
            peak_memory)

        return int(available_kv_cache_memory)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the model runner's KV-cache spec."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config."""
        # In sleep mode the KV cache goes into its own "kv_cache" pool.
        with self._memory_pool_context("kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Warm up compiled sizes, capture CUDA graphs and warm the sampler.

        Ends by re-seeding the RNG with the configured seed.
        """
        compilation_config = self.vllm_config.compilation_config
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            # Build the lookup set once rather than scanning a list per size.
            capture_sizes = set(compilation_config.cudagraph_capture_sizes)
            warmup_sizes = [x for x in warmup_sizes if x not in capture_sizes]
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size)

        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(self.scheduler_config.max_num_seqs,
                               self.scheduler_config.max_num_batched_tokens)
            self.model_runner._dummy_sampler_run(
                hidden_states=self.model_runner._dummy_run(
                    num_tokens=max_num_reqs))

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def get_model(self) -> nn.Module:
        """Return the model held by the model runner."""
        return self.model_runner.get_model()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        """Run one model step; only the driver worker returns the output."""
        output = self.model_runner.execute_model(scheduler_output)
        return output if self.is_driver_worker else None

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop the torch profiler.

        Raises:
            RuntimeError: if ``VLLM_TORCH_PROFILER_DIR`` was not set, so no
                profiler was created in ``__init__``.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()

    def execute_dummy_batch(self) -> None:
        """Run a dummy forward pass with a single token."""
        self.model_runner._dummy_run(1)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register a LoRA adapter with the model runner."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter with id ``lora_id``."""
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the ids of the LoRA adapters known to the model runner."""
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter with id ``lora_id`` in the model runner."""
        return self.model_runner.pin_lora(lora_id)

    def check_health(self) -> None:
        # worker will always be healthy as long as it's running.
        return

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save the model's state to ``path`` via ``ShardedStateLoader``."""
        from vllm.model_executor.model_loader.loader import ShardedStateLoader
        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )


def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment.

    Steps, in order:
      1. enable/disable custom all-reduce from
         ``parallel_config.disable_custom_all_reduce``;
      2. initialize the world of ``parallel_config.world_size`` ranks;
      3. set up tensor- and pipeline-parallel groups;
      4. set up KV transfer from ``vllm_config``.

    Args:
        vllm_config: Full engine config; its ``parallel_config`` supplies the
            world size and TP/PP sizes.
        rank: Global rank of this worker.
        distributed_init_method: Rendezvous method forwarded to
            ``init_distributed_environment``.
        local_rank: Rank of this worker on its node.
    """
    parallel_config = vllm_config.parallel_config
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank)

    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)

    ensure_kv_transfer_initialized(vllm_config)


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:  # noqa: SIM102
        if not current_platform.has_device_capability(80):
            capability = current_platform.get_device_capability()
            gpu_name = current_platform.get_device_name()

            if capability is None:
                compute_str = "does not have a compute capability"
            else:
                version_str = capability.as_version_str()
                compute_str = f"has compute capability {version_str}"

            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
323
                "You can use float16 instead by explicitly setting the "
324
                "`dtype` flag in CLI, for example: --dtype=half.")