"""A GPU worker class."""
import gc
3
import os
4
from typing import Dict, List, Optional, Set, Tuple, Type, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6

import torch
7
import torch.distributed
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
import vllm.envs as envs
10
from vllm.config import ParallelConfig, VllmConfig
11
from vllm.distributed import (ensure_model_parallel_initialized,
12
13
                              init_distributed_environment,
                              set_custom_all_reduce)
14
from vllm.logger import init_logger
15
from vllm.lora.request import LoRARequest
16
from vllm.model_executor import set_random_seed
17
from vllm.model_executor.layers.sampler import SamplerOutput
18
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
19
from vllm.platforms import current_platform
20
from vllm.prompt_adapter.request import PromptAdapterRequest
21
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
22
                           SequenceGroupMetadata, SequenceGroupMetadataDelta)
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.worker.cache_engine import CacheEngine
24
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
25
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
26
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
27
28
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                     WorkerInput)
Woosuk Kwon's avatar
Woosuk Kwon committed
29

30
31
# Module-level logger, named after this module.
logger = init_logger(__name__)

32

33
class Worker(LocalOrDistributedWorkerBase):
34
35
36
37
38
39
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """
Woosuk Kwon's avatar
Woosuk Kwon committed
40
41
42

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
        model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
    ) -> None:
        """Record rank information, build the model runner and set up the
        optional torch profiler.

        Args:
            vllm_config: Configuration forwarded to ``WorkerBase.__init__``
                and to the model runner.
            local_rank: Stored on the worker; used later to pick the CUDA
                device.
            rank: Rank of this worker; also written to
                ``parallel_config.rank``.
            distributed_init_method: Stored for the later distributed
                environment initialization.
            is_driver_worker: Whether this worker is the driver. A driver
                must have a rank divisible by the tensor parallel size.
            model_runner_cls: Optional model runner class that overrides the
                automatic selection.
        """
        WorkerBase.__init__(self, vllm_config)
        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        if is_driver_worker:
            assert rank % self.parallel_config.tensor_parallel_size == 0, \
                   "Driver worker should be rank 0 of tensor parallel group."
        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        speculative_config = self.speculative_config
        model_config = self.model_config

        # Ask the runner to return hidden states only when a speculative
        # config exists, its draft model differs from the target model, and
        # the draft model type is one of the listed speculator types.
        speculative_args = {}
        if speculative_config is not None:
            draft_config = speculative_config.draft_model_config
            if (draft_config.model != model_config.model
                    and draft_config.hf_config.model_type
                    in ["medusa", "mlp_speculator", "eagle"]):
                speculative_args = {"return_hidden_states": True}

        # Pick the runner class: an explicit override wins, then the
        # embedding task, then encoder-decoder models, else the default.
        ModelRunnerClass: Type[GPUModelRunnerBase]
        if model_runner_cls is not None:
            ModelRunnerClass = model_runner_cls
        elif model_config.task == "embedding":
            ModelRunnerClass = EmbeddingModelRunner
        elif self._is_encoder_decoder_model():
            ModelRunnerClass = EncoderDecoderModelRunner
        else:
            ModelRunnerClass = ModelRunner

        self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
            vllm_config=self.vllm_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
            **speculative_args,
        )

        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CacheEngine]
        # Initialize gpu_cache as embedding models don't initialize kv_caches
        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
        self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        self.profiler = None
        torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
        if torch_profiler_trace_dir:
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            trace_handler = torch.profiler.tensorboard_trace_handler(
                torch_profiler_trace_dir, use_gzip=True)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
                on_trace_ready=trace_handler)

    def start_profile(self):
        """Start the torch profiler.

        Raises:
            RuntimeError: If no profiler was configured for this worker.
        """
        profiler = self.profiler
        if profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        profiler.start()

    def stop_profile(self):
        """Stop the torch profiler.

        Raises:
            RuntimeError: If no profiler was configured for this worker.
        """
        profiler = self.profiler
        if profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        profiler.stop()

122
    def _is_encoder_decoder_model(self):
123
        return self.model_config.is_encoder_decoder_model
124

125
    def init_device(self) -> None:
        """Bind this worker to its CUDA device, then initialize the
        distributed environment and the random seed.

        Raises:
            RuntimeError: If the configured device type is not ``"cuda"``.
        """
        if self.device_config.device.type != "cuda":
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")

        # torch.distributed.all_reduce does not free the input tensor until
        # the synchronization point. This causes the memory usage to grow
        # as the number of all_reduce calls increases. This env var disables
        # this behavior.
        # Related issue:
        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

        self.device = torch.device(f"cuda:{self.local_rank}")
        torch.cuda.set_device(self.device)

        _check_if_gpu_supports_dtype(self.model_config.dtype)

        # Record the free device memory after cleanup; the profiling sanity
        # check later compares against this baseline.
        gc.collect()
        torch.cuda.empty_cache()
        free_bytes = torch.cuda.mem_get_info()[0]
        self.init_gpu_memory = free_bytes

        # Initialize the distributed environment.
        init_worker_distributed_environment(self.parallel_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)
153
154

    def load_model(self):
        """Ask the model runner to load the model."""
        runner = self.model_runner
        runner.load_model()
156

157
158
159
160
161
162
163
164
165
166
167
168
    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Forward a sharded-state save request to the model runner.

        Args:
            path: Passed positionally to the runner.
            pattern: Passed through as the ``pattern`` keyword.
            max_size: Passed through as the ``max_size`` keyword.
        """
        runner = self.model_runner
        runner.save_sharded_state(path, pattern=pattern, max_size=max_size)

169
170
171
172
173
174
175
    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Forward a tensorized-model save request to the model runner.

        Args:
            tensorizer_config: Passed through as the ``tensorizer_config``
                keyword.
        """
        runner = self.model_runner
        runner.save_tensorized_model(tensorizer_config=tensorizer_config)

176
    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        Returns:
            A ``(num_gpu_blocks, num_cpu_blocks)`` tuple. Both values are
            clamped to be non-negative, and both are 0 when the cache block
            size is 0.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()
        torch.cuda.synchronize()

        self._assert_memory_footprint_increased_during_profiling()

        # Get the peak memory allocation recorded by torch
        peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]

        # Check for any memory left around that may have been allocated on the
        # gpu outside of `torch`. NCCL operations, for example, can use a few
        # GB during a forward pass
        torch.cuda.empty_cache()
        torch_allocated_bytes = torch.cuda.memory_stats(
        )["allocated_bytes.all.current"]
        # Read free and total memory from a single query so that both numbers
        # describe the same snapshot of the device.
        free_memory_post_profile, total_memory_post_profile = (
            torch.cuda.mem_get_info())
        total_allocated_bytes = (total_memory_post_profile -
                                 free_memory_post_profile)
        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
        if non_torch_allocations > 0:
            peak_memory += non_torch_allocations

        available_kv_cache_memory = (
            total_gpu_memory * self.cache_config.gpu_memory_utilization -
            peak_memory)

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        cache_block_size = self.get_cache_block_size_bytes()
        if cache_block_size == 0:
            num_gpu_blocks = 0
            num_cpu_blocks = 0
        else:
            num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
            num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                                 cache_block_size)
        # A negative budget means no blocks, not a negative count.
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)

        logger.info(
            "Memory profiling results: total_gpu_memory=%.2fGiB"
            " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
            " memory_usage_post_profile=%.2fGiB"
            " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
            " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
            (total_gpu_memory - free_memory_pre_profile) / (1024**3),
            (peak_memory - non_torch_allocations) / (1024**3),
            total_allocated_bytes / (1024**3),
            non_torch_allocations / (1024**3),
            available_kv_cache_memory / (1024**3),
            self.cache_config.gpu_memory_utilization)

        # Final cleanup
        if self.model_runner.lora_manager:
            self.model_runner.remove_all_loras()
        gc.collect()

        return num_gpu_blocks, num_cpu_blocks

255
256
257
258
259
260
261
262
263
264
    def _assert_memory_footprint_increased_during_profiling(self):
        """Assert that less device memory is free now than the baseline
        stored in ``self.init_gpu_memory``.
        """
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        free_gpu_memory = torch.cuda.mem_get_info()[0]
        consumed_bytes = self.init_gpu_memory - free_gpu_memory
        assert consumed_bytes > 0, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

265
266
267
268
269
270
271
272
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Allocate GPU and CPU KV cache with the specified number of blocks.

        This also warms up the model, which may record CUDA graphs.

        Args:
            num_gpu_blocks: Number of GPU blocks; validated before use and
                stored on the cache config.
            num_cpu_blocks: Number of CPU blocks; stored on the cache config.
        """
        cache_config = self.cache_config
        raise_if_cache_size_invalid(num_gpu_blocks, cache_config.block_size,
                                    cache_config.is_attention_free,
                                    self.model_config.max_model_len)

        cache_config.num_gpu_blocks = num_gpu_blocks
        cache_config.num_cpu_blocks = num_cpu_blocks

        self._init_cache_engine()
        self._warm_up_model()

    def _init_cache_engine(self):
        """Create one cache engine per pipeline-parallel slot and expose
        their GPU caches as ``self.gpu_cache``.
        """
        assert self.cache_config.num_gpu_blocks is not None
        num_engines = self.parallel_config.pipeline_parallel_size
        engines = []
        for _ in range(num_engines):
            engines.append(
                CacheEngine(self.cache_config, self.model_config,
                            self.parallel_config, self.device_config))
        self.cache_engine = engines
        # One gpu_cache entry per engine, in the same order.
        self.gpu_cache = [engine.gpu_cache for engine in engines]
Woosuk Kwon's avatar
Woosuk Kwon committed
293

294
    def _warm_up_model(self) -> None:
        """Run model capture unless eager mode is enforced, then reseed."""
        eager_only = self.model_config.enforce_eager
        if not eager_only:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        seed = self.model_config.seed
        set_random_seed(seed)

301
302
303
304
305
    @property
    def do_metadata_broadcast(self) -> bool:
        """True when the tensor parallel size is greater than one."""
        tp_size = self.parallel_config.tensor_parallel_size
        return tp_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """The worker's ``gpu_cache`` attribute, returned unchanged."""
        cache = self.gpu_cache
        return cache
308
309

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """Build a ``WorkerInput`` from an execute-model request.

        The three block lists on the request are converted to int64 tensors
        viewed as two-column arrays.
        """

        def _as_block_pairs(blocks, device) -> torch.Tensor:
            # int64 tensor on `device`, viewed as rows of two entries.
            return torch.tensor(blocks, device=device,
                                dtype=torch.int64).view(-1, 2)

        virtual_engine = execute_model_req.virtual_engine
        num_steps = execute_model_req.num_steps
        num_seq_groups = len(execute_model_req.seq_group_metadata_list)

        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
        # they contain parameters to launch cudamemcpyasync.
        swap_in_pairs = _as_block_pairs(execute_model_req.blocks_to_swap_in,
                                        "cpu")
        swap_out_pairs = _as_block_pairs(execute_model_req.blocks_to_swap_out,
                                         "cpu")
        # `blocks_to_copy` is a gpu tensor. The src and tgt of
        # blocks to copy are in the same device, and `blocks_to_copy`
        # can be used directly within cuda kernels.
        copy_pairs = _as_block_pairs(execute_model_req.blocks_to_copy,
                                     self.device)

        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_swap_in=swap_in_pairs,
            blocks_to_swap_out=swap_out_pairs,
            blocks_to_copy=copy_pairs,
            virtual_engine=virtual_engine,
            num_steps=num_steps,
        )
Woosuk Kwon's avatar
Woosuk Kwon committed
338

339
    @torch.inference_mode()
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """Issue the swap-in, swap-out and copy cache operations requested by
        ``worker_input``, skipping any that are missing or empty.
        """

        def _has_work(blocks) -> bool:
            return blocks is not None and blocks.numel() > 0

        virtual_engine = worker_input.virtual_engine

        # Issue cache operations.
        # The cache engine is only looked up when there is work to do.
        swap_in = worker_input.blocks_to_swap_in
        if _has_work(swap_in):
            self.cache_engine[virtual_engine].swap_in(swap_in)

        swap_out = worker_input.blocks_to_swap_out
        if _has_work(swap_out):
            self.cache_engine[virtual_engine].swap_out(swap_out)

        to_copy = worker_input.blocks_to_copy
        if _has_work(to_copy):
            self.cache_engine[virtual_engine].copy(to_copy)
354

355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
    def _get_cached_seq_group_metadata(
            self,
            seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                                SequenceGroupMetadataDelta]],
            finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
        """Return a list of cached Sequence Group Metadata after updating its
        state.

        It is used because scheduler only sends delta to workers to reduce
        the data payload size. The function also cleans up cache based on
        a given `finished_request_ids`.

        Args:
            seq_group_metadata_list: Full metadata snapshots or deltas, one
                per request.
            finished_request_ids: Request ids whose cache entries are
                dropped. Ids with no cache entry are ignored.

        Returns:
            The cached metadata objects, in the order of
            ``seq_group_metadata_list``.
        """
        cache = self._seq_group_metadata_cache
        new_seq_group_metadata_list = []
        for metadata_or_delta in seq_group_metadata_list:
            request_id = metadata_or_delta.request_id
            if (request_id in cache and isinstance(
                    metadata_or_delta, SequenceGroupMetadataDelta)):
                # The first prefill is already cached; apply the update.
                cache[request_id].apply_delta(metadata_or_delta)
            else:
                # Either the first prefill, or a snapshot sent again
                # because the request was preempted. In both cases the
                # snapshot replaces whatever was cached, so that we start
                # from scratch.
                assert isinstance(metadata_or_delta, SequenceGroupMetadata)
                cache[request_id] = metadata_or_delta

            new_seq_group_metadata_list.append(cache[request_id])

        # Clean up finished ids. Use pop() rather than del so that an id
        # which never reached this cache does not raise KeyError.
        # NOTE(review): presumably a request can finish (e.g. be aborted)
        # before its first snapshot is sent here -- confirm against caller.
        for finished_id in finished_request_ids:
            cache.pop(finished_id, None)

        return new_seq_group_metadata_list

    def _execute_model_spmd(
        self,
        execute_model_req: ExecuteModelRequest,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Replace the request's metadata list with the cached, fully
        updated metadata, then delegate to the parent implementation.
        """
        if execute_model_req is not None:
            # Resolve deltas against the cache and write the result back
            # onto the request.
            execute_model_req.seq_group_metadata_list = (
                self._get_cached_seq_group_metadata(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.finished_requests_ids))
        return super()._execute_model_spmd(execute_model_req,
                                           intermediate_tensors)

412
413
414
415
416
417
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward the LoRA request to the model runner's ``add_lora``."""
        runner = self.model_runner
        return runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to the model runner's ``remove_lora``."""
        runner = self.model_runner
        return runner.remove_lora(lora_id)

418
419
420
    def pin_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to the model runner's ``pin_lora``."""
        runner = self.model_runner
        return runner.pin_lora(lora_id)

421
422
423
    def list_loras(self) -> Set[int]:
        """Return whatever the model runner's ``list_loras`` reports."""
        runner = self.model_runner
        return runner.list_loras()

424
425
426
427
428
429
430
431
432
433
434
435
436
    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Forward the request to the model runner's ``add_prompt_adapter``."""
        runner = self.model_runner
        return runner.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Forward ``prompt_adapter_id`` to the model runner's
        ``remove_prompt_adapter``.

        This must target the prompt-adapter API, not ``remove_lora``, to
        match the add/pin/list prompt-adapter methods on this worker.
        """
        return self.model_runner.remove_prompt_adapter(prompt_adapter_id)

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Forward the id to the model runner's ``pin_prompt_adapter``."""
        runner = self.model_runner
        return runner.pin_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> Set[int]:
        """Return whatever the model runner's ``list_prompt_adapters``
        reports.
        """
        runner = self.model_runner
        return runner.list_prompt_adapters()

437
438
439
440
441
442
443
444
    @property
    def max_model_len(self) -> int:
        """The ``max_model_len`` value of the model config."""
        model_config = self.model_config
        return model_config.max_model_len

    @property
    def vocab_size(self) -> int:
        """The ``vocab_size`` value reported by the model runner."""
        runner = self.model_runner
        return runner.vocab_size

445
    def get_cache_block_size_bytes(self) -> int:
        """Return the size of one KV cache block in bytes, as computed by
        ``CacheEngine.get_cache_block_size`` from this worker's configs.
        """
        block_size_bytes = CacheEngine.get_cache_block_size(
            self.cache_config, self.model_config, self.parallel_config)
        return block_size_bytes

Woosuk Kwon's avatar
Woosuk Kwon committed
452

453
def init_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment.

    Configures custom all-reduce from the parallel config, initializes the
    distributed environment for this rank, then ensures the model-parallel
    groups exist for the configured tensor and pipeline parallel sizes.
    """
    use_custom_all_reduce = not parallel_config.disable_custom_all_reduce
    set_custom_all_reduce(use_custom_all_reduce)

    world_size = parallel_config.world_size
    init_distributed_environment(world_size, rank, distributed_init_method,
                                 local_rank)

    tp_size = parallel_config.tensor_parallel_size
    pp_size = parallel_config.pipeline_parallel_size
    ensure_model_parallel_initialized(tp_size, pp_size)

468

469
470
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
471
472
473
    if torch_dtype == torch.bfloat16:  # noqa: SIM102
        if not current_platform.has_device_capability(80):
            capability = current_platform.get_device_capability()
474
            gpu_name = current_platform.get_device_name()
475
476
477
478
479
480
481

            if capability is None:
                compute_str = "does not have a compute capability"
            else:
                version_str = capability.as_version_str()
                compute_str = f"has compute capability {version_str}"

482
483
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
484
                f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
Woosuk Kwon's avatar
Woosuk Kwon committed
485
486
                "You can use float16 instead by explicitly setting the"
                "`dtype` flag in CLI, for example: --dtype=half.")
487
488


489
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
                                max_model_len) -> None:
    """Validate the number of GPU cache blocks against the model settings.

    Args:
        num_gpu_blocks: Number of GPU cache blocks to be allocated.
        block_size: Number of tokens stored per cache block.
        is_attention_free: Whether the model is attention-free; such a model
            must be given exactly zero blocks.
        max_model_len: Maximum sequence length of the model.

    Raises:
        ValueError: If an attention-free model is given a non-zero number of
            blocks, or if a model with attention is given no blocks or too
            few blocks to hold ``max_model_len`` tokens.
    """
    if is_attention_free and num_gpu_blocks != 0:
        # Adjacent literals are concatenated as-is, so the separating space
        # after the block count has to be written out.
        raise ValueError("No memory should be allocated for the cache blocks "
                         f"for an attention-free model, but {num_gpu_blocks} "
                         "blocks are allocated.")
    if not is_attention_free and num_gpu_blocks <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")
    # Total number of tokens the allocated blocks can hold.
    max_seq_len = block_size * num_gpu_blocks
    if not is_attention_free and max_model_len > max_seq_len:
        raise ValueError(
            f"The model's max seq len ({max_model_len}) "
            "is larger than the maximum number of tokens that can be "
            f"stored in KV cache ({max_seq_len}). Try increasing "
            "`gpu_memory_utilization` or decreasing `max_model_len` when "
            "initializing the engine.")