worker.py 21.6 KB
Newer Older
1
"""A GPU worker class."""
2
import gc
3
import os
4
from typing import Dict, List, Optional, Set, Tuple, Type, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6

import torch
7
import torch.distributed
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
import vllm.envs as envs
10
11
12
from vllm.config import VllmConfig
from vllm.distributed import (ensure_kv_transfer_initialized,
                              ensure_model_parallel_initialized,
13
14
                              init_distributed_environment,
                              set_custom_all_reduce)
15
from vllm.logger import init_logger
16
from vllm.lora.request import LoRARequest
17
from vllm.model_executor import set_random_seed
18
from vllm.model_executor.layers.sampler import SamplerOutput
19
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
20
from vllm.platforms import current_platform
21
from vllm.prompt_adapter.request import PromptAdapterRequest
22
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
23
                           SequenceGroupMetadata, SequenceGroupMetadataDelta)
24
from vllm.utils import GiB_bytes, bind_kv_cache, memory_profiling
Woosuk Kwon's avatar
Woosuk Kwon committed
25
from vllm.worker.cache_engine import CacheEngine
26
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
27
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
28
from vllm.worker.pooling_model_runner import PoolingModelRunner
29
30
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                     WorkerInput)
Woosuk Kwon's avatar
Woosuk Kwon committed
31

32
33
# Module-level logger used by the worker code in this file.
logger = init_logger(__name__)

34

35
class Worker(LocalOrDistributedWorkerBase):
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
        model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
    ) -> None:
        """Create the model runner and, if configured, the torch profiler.

        Args:
            vllm_config: Engine configuration. The ``*_config`` attributes
                read below are presumably populated from it by
                ``WorkerBase.__init__``.
            local_rank: Index of this worker's GPU on the node; used as
                ``cuda:<local_rank>`` in :meth:`init_device`.
            rank: Global rank; also written to ``parallel_config.rank``.
            distributed_init_method: Passed on to the distributed
                initialization in :meth:`init_device`.
            is_driver_worker: Forwarded to the model runner.
            model_runner_cls: Optional wrapper; when given it is called with
                the default runner as its single argument and its result
                replaces ``self.model_runner``.
        """
        WorkerBase.__init__(self, vllm_config)
        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Ask the target model's runner to return hidden states when a
        # distinct draft model of type medusa / mlp_speculator / eagle is
        # configured; otherwise pass no extra runner arguments.
        speculative_config = self.speculative_config
        model_config = self.model_config
        speculative_args = {}
        if (speculative_config is not None
                and (speculative_config.draft_model_config.model
                     != model_config.model)
                and speculative_config.draft_model_config.hf_config.model_type
                in ("medusa", "mlp_speculator", "eagle")):
            speculative_args = {"return_hidden_states": True}

        ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
        if model_config.runner_type == "pooling":
            ModelRunnerClass = PoolingModelRunner
        elif self.model_config.is_encoder_decoder:
            ModelRunnerClass = EncoderDecoderModelRunner
        self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
            vllm_config=self.vllm_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
            **speculative_args,
        )
        if model_runner_cls is not None:
            self.model_runner = model_runner_cls(self.model_runner)

        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CacheEngine]
        # Initialize gpu_cache as pooling models don't initialize kv_caches
        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
        # request_id -> full SequenceGroupMetadata, kept up to date by
        # _get_cached_seq_group_metadata.
        self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None

    def start_profile(self):
        """Start the torch profiler.

        Raises:
            RuntimeError: If VLLM_TORCH_PROFILER_DIR was not set at init.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self):
        """Stop the torch profiler.

        Raises:
            RuntimeError: If VLLM_TORCH_PROFILER_DIR was not set at init.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()

    def init_device(self) -> None:
        """Bind this worker to its GPU and set up distributed state.

        Also records the free GPU memory (``self.init_gpu_memory``) that
        later memory profiling compares against, and seeds the RNGs.

        Raises:
            RuntimeError: If the configured device type is not ``cuda``.
        """
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            gc.collect()
            torch.cuda.empty_cache()
            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        """Load model weights via the model runner."""
        self.model_runner.load_model()

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Delegate sharded-state saving to the model runner."""
        self.model_runner.save_sharded_state(
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Delegate tensorizer serialization to the model runner."""
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config, )

    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``, each clamped to >= 0 and
            both 0 when the cache block size is 0.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        _, total_gpu_memory = torch.cuda.mem_get_info()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(baseline_memory_in_bytes=total_gpu_memory -
                              self.init_gpu_memory,
                              weights_memory_in_bytes=self.model_runner.
                              model_memory_usage) as result:
            self.model_runner.profile_run()

        self._assert_memory_footprint_increased_during_profiling()

        memory_for_current_instance = total_gpu_memory * \
            self.cache_config.gpu_memory_utilization
        available_kv_cache_memory = (memory_for_current_instance -
                                     result.non_kv_cache_memory_in_bytes)

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        cache_block_size = self.get_cache_block_size_bytes()
        if cache_block_size == 0:
            num_gpu_blocks = 0
            num_cpu_blocks = 0
        else:
            num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
            num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                                 cache_block_size)
        # Non-KV memory may exceed the budget; never report negative counts.
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)

        msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
               "the current vLLM instance can use "
               "total_gpu_memory "
               f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
               " x gpu_memory_utilization "
               f"({self.cache_config.gpu_memory_utilization:.2f})"
               f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
               "model weights take "
               f"{(result.weights_memory_in_bytes / GiB_bytes):.2f}GiB;"
               " non_torch_memory takes "
               f"{(result.non_torch_increase_in_bytes / GiB_bytes):.2f}GiB;"
               " PyTorch activation peak memory takes "
               f"{(result.torch_peak_increase_in_bytes / GiB_bytes):.2f}GiB;"
               " the rest of the memory reserved for KV Cache is "
               f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")

        logger.info(msg)

        # Final cleanup
        if self.model_runner.lora_manager:
            self.model_runner.remove_all_loras()
        gc.collect()

        return num_gpu_blocks, num_cpu_blocks

    def _assert_memory_footprint_increased_during_profiling(self):
        """Check that free GPU memory dropped since :meth:`init_device`.

        Raises:
            AssertionError: If free memory did not decrease.
        """
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        free_gpu_memory, _ = torch.cuda.mem_get_info()
        assert self.init_gpu_memory - free_gpu_memory > 0, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Allocate GPU and CPU KV cache with the specified number of blocks.

        This also warms up the model, which may record CUDA graphs.

        Raises:
            ValueError: If ``num_gpu_blocks`` is invalid for this model (see
                ``raise_if_cache_size_invalid``).
        """
        raise_if_cache_size_invalid(num_gpu_blocks,
                                    self.cache_config.block_size,
                                    self.cache_config.is_attention_free,
                                    self.model_config.max_model_len)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._init_cache_engine()
        self._warm_up_model()

    def _init_cache_engine(self):
        """Create one CacheEngine per pipeline virtual engine and bind its
        GPU cache into the static forward context."""
        assert self.cache_config.num_gpu_blocks is not None
        self.cache_engine = [
            CacheEngine(self.cache_config, self.model_config,
                        self.parallel_config, self.device_config)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.gpu_cache = [engine.gpu_cache for engine in self.cache_engine]
        bind_kv_cache(self.compilation_config.static_forward_context,
                      self.gpu_cache)

    def _warm_up_model(self) -> None:
        """Capture CUDA graphs (unless eager mode) and re-seed the RNGs."""
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    @property
    def do_metadata_broadcast(self) -> bool:
        """True when tensor parallelism spans more than one worker."""
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """Per-virtual-engine GPU KV cache, or None before allocation."""
        return self.gpu_cache

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """Convert the request's block operations into tensors.

        Swap in/out lists become CPU int64 tensors of shape (-1, 2); the
        copy list becomes an int64 tensor of shape (-1, 2) on this device.
        """
        virtual_engine = execute_model_req.virtual_engine
        num_steps = execute_model_req.num_steps
        num_seq_groups = len(execute_model_req.seq_group_metadata_list)
        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
        # they contain parameters to launch cudamemcpyasync.
        blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
                                         device="cpu",
                                         dtype=torch.int64).view(-1, 2)
        blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
                                          device="cpu",
                                          dtype=torch.int64).view(-1, 2)
        # `blocks_to_copy` is a gpu tensor. The src and tgt of
        # blocks to copy are in the same device, and `blocks_to_copy`
        # can be used directly within cuda kernels.
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device=self.device,
                                      dtype=torch.int64).view(-1, 2)

        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
            num_steps=num_steps,
        )

    @torch.inference_mode()
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """Apply non-empty swap-in, swap-out and copy operations, in that
        order, on the cache engine of the input's virtual engine."""
        virtual_engine = worker_input.virtual_engine
        # Issue cache operations.
        if (worker_input.blocks_to_swap_in is not None
                and worker_input.blocks_to_swap_in.numel() > 0):
            self.cache_engine[virtual_engine].swap_in(
                worker_input.blocks_to_swap_in)
        if (worker_input.blocks_to_swap_out is not None
                and worker_input.blocks_to_swap_out.numel() > 0):
            self.cache_engine[virtual_engine].swap_out(
                worker_input.blocks_to_swap_out)
        if (worker_input.blocks_to_copy is not None
                and worker_input.blocks_to_copy.numel() > 0):
            self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)

    def _get_cached_seq_group_metadata(
            self,
            seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                                SequenceGroupMetadataDelta]],
            finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
        """Return a list of cached Sequence Group Metadata after updating its
        state.

        It is used because scheduler only sends delta to workers to reduce
        the data payload size. The function also cleans up cache based on
        a given `finished_request_ids`.
        """
        new_seq_group_metadata_list = []
        for metadata_or_delta in seq_group_metadata_list:
            request_id = metadata_or_delta.request_id
            if request_id not in self._seq_group_metadata_cache:
                # The first prefill.
                assert isinstance(metadata_or_delta, SequenceGroupMetadata)
                self._seq_group_metadata_cache[request_id] = metadata_or_delta
            else:
                # The first prefill is already cached.
                if isinstance(metadata_or_delta, SequenceGroupMetadataDelta):
                    self._seq_group_metadata_cache[request_id].apply_delta(
                        metadata_or_delta)
                else:
                    # If metadata snapshot is sent again, it is
                    # preempted. Reset the cache because we need to start
                    # from scratch.
                    assert isinstance(metadata_or_delta, SequenceGroupMetadata)
                    self._seq_group_metadata_cache[
                        request_id] = metadata_or_delta

            new_seq_group_metadata_list.append(
                self._seq_group_metadata_cache[request_id])

        # Clean up finished ids. Tolerate ids that were never cached here
        # (presumably e.g. requests finished/aborted before reaching this
        # worker) instead of raising KeyError.
        for finished_id in finished_request_ids:
            self._seq_group_metadata_cache.pop(finished_id, None)

        return new_seq_group_metadata_list

    def _execute_model_spmd(
        self,
        execute_model_req: Optional[ExecuteModelRequest],
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Expand metadata deltas into full metadata, then run the base
        class's SPMD execution."""
        if execute_model_req is not None:
            new_seq_group_metadata_list = self._get_cached_seq_group_metadata(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.finished_requests_ids)

            execute_model_req.seq_group_metadata_list = (
                new_seq_group_metadata_list)
        output = super()._execute_model_spmd(execute_model_req,
                                             intermediate_tensors)
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.model_runner.list_loras()

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        return self.model_runner.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        # Route to the prompt-adapter API like add/pin/list do; calling
        # remove_lora here would target a LoRA with the same id instead.
        return self.model_runner.remove_prompt_adapter(prompt_adapter_id)

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        return self.model_runner.pin_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> Set[int]:
        return self.model_runner.list_prompt_adapters()

    @property
    def max_model_len(self) -> int:
        return self.model_config.max_model_len

    @property
    def vocab_size(self) -> int:
        return self.model_runner.vocab_size

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of one KV cache block."""
        return CacheEngine.get_cache_block_size(self.cache_config,
                                                self.model_config,
                                                self.parallel_config)

Woosuk Kwon's avatar
Woosuk Kwon committed
445

446
def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment.

    Steps, in order: configure custom all-reduce, initialize the
    process group, create the tensor/pipeline model-parallel groups, and
    set up KV transfer.

    Args:
        vllm_config: Engine configuration; its ``parallel_config`` supplies
            the world size and parallel sizes.
        rank: Global rank of the calling worker.
        distributed_init_method: Passed through to
            ``init_distributed_environment``.
        local_rank: Passed through to ``init_distributed_environment``.
    """
    parallel = vllm_config.parallel_config

    use_custom_all_reduce = not parallel.disable_custom_all_reduce
    set_custom_all_reduce(use_custom_all_reduce)

    init_distributed_environment(
        parallel.world_size,
        rank,
        distributed_init_method,
        local_rank,
    )
    ensure_model_parallel_initialized(
        parallel.tensor_parallel_size,
        parallel.pipeline_parallel_size,
    )

    ensure_kv_transfer_initialized(vllm_config)

463

464
465
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    """Raise if the current GPU cannot run the model in ``torch_dtype``.

    Only ``torch.bfloat16`` is checked: it requires device capability of
    at least 8.0 (``has_device_capability(80)``). Any other dtype passes.

    Raises:
        ValueError: If ``torch_dtype`` is bfloat16 and the device reports a
            lower compute capability, or none at all.
    """
    if torch_dtype != torch.bfloat16:
        return
    if current_platform.has_device_capability(80):
        return

    capability = current_platform.get_device_capability()
    gpu_name = current_platform.get_device_name()
    if capability is None:
        compute_str = "does not have a compute capability"
    else:
        version_str = capability.as_version_str()
        compute_str = f"has compute capability {version_str}"

    raise ValueError(
        "Bfloat16 is only supported on GPUs with compute capability "
        f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
        # Keep the trailing space: adjacent literals are joined verbatim,
        # so without it the message reads "setting the`dtype`".
        "You can use float16 instead by explicitly setting the "
        "`dtype` flag in CLI, for example: --dtype=half.")
482
483


484
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
                                max_model_len) -> None:
    """Validate the number of GPU KV cache blocks for a model.

    Args:
        num_gpu_blocks: Number of KV cache blocks to be allocated on the GPU.
        block_size: Tokens per cache block; ``block_size * num_gpu_blocks``
            is the number of tokens the cache can hold.
        is_attention_free: True when the model needs no KV cache at all.
        max_model_len: Longest sequence (in tokens) the model must support.

    Raises:
        ValueError: If an attention-free model is given a non-zero number of
            blocks, if a model with attention gets no blocks, or if the cache
            cannot hold ``max_model_len`` tokens.
    """
    if is_attention_free:
        if num_gpu_blocks != 0:
            raise ValueError(
                "No memory should be allocated for the cache blocks "
                f"for an attention-free model, but {num_gpu_blocks} "
                "blocks are allocated.")
        # Attention-free models need no capacity check.
        return

    if num_gpu_blocks <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")

    max_seq_len = block_size * num_gpu_blocks
    if max_model_len > max_seq_len:
        raise ValueError(
            f"The model's max seq len ({max_model_len}) "
            "is larger than the maximum number of tokens that can be "
            f"stored in KV cache ({max_seq_len}). Try increasing "
            "`gpu_memory_utilization` or decreasing `max_model_len` when "
            "initializing the engine.")