# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""
import gc
import os
from contextlib import nullcontext
from typing import TYPE_CHECKING, Optional

import torch
import torch.distributed
import torch.nn as nn

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.device_allocator.cumem import CuMemAllocator
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.v1.gpu_model_runner import TBO_GPUModelRunner
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase

logger = init_logger(__name__)

if TYPE_CHECKING:
    # Only needed for the string-quoted annotations used by Worker's
    # methods, so these imports are skipped at runtime.
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.core.sched.output import SchedulerOutput


class Worker(WorkerBase):
    """GPU worker that drives one model runner on a single CUDA device.

    Responsibilities visible in this class:
      * binding to ``cuda:<local_rank>``, checking free memory and setting up
        the distributed environment (:meth:`init_device`);
      * loading the model and sizing / allocating the KV cache;
      * warming up / capturing the model and executing scheduled batches;
      * optional sleep / wake-up through ``CuMemAllocator`` memory pools
        (tags ``"weights"`` and ``"kv_cache"``) when
        ``model_config.enable_sleep_mode`` is set.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

        if self.model_config.trust_remote_code:
            # Local import: only needed when remote code is trusted.
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Model buffers copied to CPU by a level-2 sleep; restored (and the
        # dict cleared) by wake_up().
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        self.profiler: Optional[torch.profiler.profile] = None
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))

    def sleep(self, level: int = 1) -> None:
        """Release GPU memory held by this worker.

        Args:
            level: At level 1 the ``"weights"`` pool is passed to
                ``CuMemAllocator.sleep`` as an offload tag; at any other
                level no tags are offloaded. At level 2 the model's buffers
                are first copied to CPU so :meth:`wake_up` can restore them.
        """
        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

        # Save the buffers before level 2 sleep.
        if level == 2:
            model = self.model_runner.model
            self._sleep_saved_buffers = {
                name: buffer.cpu().clone()
                for name, buffer in model.named_buffers()
            }

        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights", ) if level == 1 else ())
        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, "
            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Re-map memory released by :meth:`sleep`.

        Args:
            tags: Memory-pool tags forwarded unchanged to
                ``CuMemAllocator.wake_up`` (presumably ``None`` wakes every
                pool — confirm against the allocator).
        """
        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

        # Restore the buffers saved by a level 2 sleep, then drop the copies.
        if self._sleep_saved_buffers:
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Record the KV-cache block counts on the cache config."""
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        """Bind to ``cuda:<local_rank>`` and build the model runner.

        Also snapshots device memory (``self.init_snapshot``), computes
        ``self.requested_memory`` from ``gpu_memory_utilization``,
        initializes the distributed environment and seeds the RNGs.

        Raises:
            RuntimeError: if the configured device type is not ``cuda``.
            ValueError: if free device memory is below the requested amount,
                or the model dtype is bfloat16 on a GPU that lacks support.
        """
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            gc.collect()
            torch.cuda.empty_cache()

            # Take current memory snapshot.
            self.init_snapshot = MemorySnapshot()
            self.requested_memory = (self.init_snapshot.total_memory *
                                     self.cache_config.gpu_memory_utilization)
            if self.init_snapshot.free_memory < self.requested_memory:
                GiB = lambda b: round(b / GiB_bytes, 2)
                raise ValueError(
                    "Free memory on device "
                    f"({GiB(self.init_snapshot.free_memory)}/"
                    f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
                    "is less than desired GPU memory utilization "
                    f"({self.cache_config.gpu_memory_utilization}, "
                    f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                    "utilization or reduce GPU memory used by other processes."
                )
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")

        # Initialize the distributed environment.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner; the two-batch-overlap variant is used
        # when VLLM_ENABLE_TBO is set.
        runner_cls = (TBO_GPUModelRunner
                      if envs.VLLM_ENABLE_TBO else GPUModelRunner)
        self.model_runner = runner_cls(self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load model weights, into the ``"weights"`` pool in sleep mode."""
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CuMemAllocator.get_instance()
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be "
                "used for one instance per process.")
            context = allocator.use_memory_pool(tag="weights")
        else:
            context = nullcontext()
        with context:
            self.model_runner.load_model()

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profile peak memory usage to size the KV cache without OOMs.

        The engine first profiles the existing memory usage with a dummy
        forward pass, then calculates how many bytes of GPU memory can be
        used for the KV cache.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            ``self.requested_memory`` (set in :meth:`init_device`) minus the
            profiled non-KV-cache memory, in bytes.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        GiB = lambda b: b / GiB_bytes

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
                self.init_snapshot,
                weights_memory=int(
                    self.model_runner.model_memory_usage)) as profile_result:
            self.model_runner.profile_run()

        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_snapshot.free_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
            f"current free memory {GiB(free_gpu_memory)} GiB. "
            "This happens when other processes sharing the same container "
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container.")
        available_kv_cache_memory = (self.requested_memory -
                                     profile_result.non_kv_cache_memory)

        logger.debug(
            "Initial free memory: %.2f GiB, free memory: %.2f GiB, "
            "requested GPU memory: %.2f GiB",
            GiB(self.init_snapshot.free_memory), GiB(free_gpu_memory),
            GiB(self.requested_memory))
        logger.debug(profile_result)
        logger.info("Available KV cache memory: %.2f GiB",
                    GiB(available_kv_cache_memory))
        gc.collect()

        return int(available_kv_cache_memory)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the model runner's KV-cache spec."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        In sleep mode the allocation goes into the ``"kv_cache"`` pool.
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CuMemAllocator.get_instance()
            context = allocator.use_memory_pool(tag="kv_cache")
        else:
            context = nullcontext()
        with context:
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Compile/warm up extra sizes, capture graphs, warm up the sampler.

        Finishes by re-seeding so that warm-up does not perturb the RNG
        state seen by real requests.
        """
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x for x in warmup_sizes if x not in
                self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size, skip_eplb=True)
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(self.scheduler_config.max_num_seqs,
                               self.scheduler_config.max_num_batched_tokens)

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = self.model_runner._dummy_run(
                num_tokens=max_num_reqs,
                skip_eplb=True,
            )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(
                    hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def get_model(self) -> nn.Module:
        """Return the underlying model from the model runner."""
        return self.model_runner.get_model()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        """Run one scheduled step, handling pipeline-parallel hand-off.

        Non-first PP ranks first receive intermediate tensors from the
        previous rank. Unless the executor backend is ``external_launcher``,
        non-last PP ranks send their intermediate output onward and return
        ``None``.

        Returns:
            The ``ModelRunnerOutput`` on the driver worker, otherwise
            ``None``.
        """
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)

        parallel_config = self.vllm_config.parallel_config
        if (parallel_config.distributed_executor_backend != "external_launcher"
                and not get_pp_group().is_last_rank):
            assert isinstance(output, IntermediateTensors)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            return None
        assert isinstance(output, ModelRunnerOutput)
        return output if self.is_driver_worker else None

    def profile(self, is_start: bool = True):
        """Start or stop the torch profiler.

        On stop, prints a key-averages table sorted by self CUDA time.

        Raises:
            RuntimeError: if ``VLLM_TORCH_PROFILER_DIR`` was not set at
                construction time.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()
            print(self.profiler.key_averages().table(
                sort_by="self_cuda_time_total"))

    def execute_dummy_batch(self) -> None:
        """Run a single-token dummy forward pass."""
        self.model_runner._dummy_run(1)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to the model runner's ``add_lora``."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner's ``remove_lora``."""
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Delegate to the model runner's ``list_loras``."""
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner's ``pin_lora``."""
        return self.model_runner.pin_lora(lora_id)

    def check_health(self) -> None:
        # worker will always be healthy as long as it's running.
        return

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save the model's weights to ``path`` via ``ShardedStateLoader``."""
        from vllm.model_executor.model_loader import ShardedStateLoader
        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Delegate tensorized model saving to the model runner."""
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config, )

def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment for this worker.

    Steps, in order:
      1. enable custom all-reduce unless
         ``parallel_config.disable_custom_all_reduce`` is set;
      2. initialize the global distributed environment;
      3. ensure the tensor- and pipeline-parallel groups exist;
      4. ensure KV transfer is initialized.

    Args:
        vllm_config: Engine config; its ``parallel_config`` supplies the world
            size and the tensor/pipeline parallel sizes.
        rank: Global rank of this worker.
        distributed_init_method: Forwarded to ``init_distributed_environment``.
        local_rank: Rank of this worker on its node; defaults to ``-1``.
        backend: ``torch.distributed`` backend name, ``"nccl"`` by default.
    """
    parallel_config = vllm_config.parallel_config
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank, backend)

    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)

    ensure_kv_transfer_initialized(vllm_config)

def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:  # noqa: SIM102
        if not current_platform.has_device_capability(80):
            capability = current_platform.get_device_capability()
            gpu_name = current_platform.get_device_name()

            if capability is None:
                compute_str = "does not have a compute capability"
            else:
                version_str = capability.as_version_str()
                compute_str = f"has compute capability {version_str}"

            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
412
                "You can use float16 instead by explicitly setting the "
413
                "`dtype` flag in CLI, for example: --dtype=half.")