"""A GPU worker class."""
import gc
import os
from typing import Dict, List, Optional, Set, Tuple, Type, Union

import torch
import torch.distributed

import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig,
                         SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                           SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput

# Module-level logger, named after this module.
logger = init_logger(__name__)


class Worker(LocalOrDistributedWorkerBase):
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        speculative_config: Optional[SpeculativeConfig] = None,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        is_driver_worker: bool = False,
        model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
        observability_config: Optional[ObservabilityConfig] = None,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.parallel_config.rank = rank
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.load_config = load_config
        self.prompt_adapter_config = prompt_adapter_config
        self.is_driver_worker = is_driver_worker
        if parallel_config and is_driver_worker:
            assert rank % parallel_config.tensor_parallel_size == 0, \
                   "Driver worker should be rank 0 of tensor parallel group."
        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()
        self.observability_config = observability_config

        # Ask the target model runner to return hidden states when a distinct
        # draft model of type medusa, mlp_speculator or eagle is configured.
        speculative_args: Dict[str, bool] = {}
        if speculative_config is not None:
            draft_model_config = speculative_config.draft_model_config
            if (draft_model_config.model != model_config.model
                    and draft_model_config.hf_config.model_type
                    in ("medusa", "mlp_speculator", "eagle")):
                speculative_args["return_hidden_states"] = True

        # An explicitly supplied runner class wins; otherwise pick one based
        # on the model kind.
        ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
        if model_runner_cls is not None:
            ModelRunnerClass = model_runner_cls
        elif self._is_embedding_model():
            ModelRunnerClass = EmbeddingModelRunner
        elif self._is_encoder_decoder_model():
            ModelRunnerClass = EncoderDecoderModelRunner
        self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config=load_config,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
            prompt_adapter_config=prompt_adapter_config,
            observability_config=observability_config,
            **speculative_args,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CacheEngine]
        # Initialize gpu_cache as embedding models don't initialize kv_caches
        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
        # Per-request metadata cache, keyed by request id; the scheduler may
        # send only deltas after the first snapshot of a request.
        self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None

    def start_profile(self):
        """Start the torch profiler.

        Raises:
            RuntimeError: if VLLM_TORCH_PROFILER_DIR was not set at init.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self):
        """Stop the torch profiler.

        Raises:
            RuntimeError: if VLLM_TORCH_PROFILER_DIR was not set at init.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()

    def _is_encoder_decoder_model(self):
        return self.model_config.is_encoder_decoder_model

    def _is_embedding_model(self):
        return self.model_config.is_embedding_model

    def init_device(self) -> None:
        """Bind this worker to its CUDA device and join the process group.

        Records the free GPU memory before any model weights are loaded
        (used later by `determine_num_available_blocks`), initializes the
        distributed environment, and seeds the RNGs.

        Raises:
            RuntimeError: if the configured device type is not CUDA.
        """
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            gc.collect()
            torch.cuda.empty_cache()
            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.parallel_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        """Load the model weights through the model runner."""
        self.model_runner.load_model()

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Delegate saving this worker's sharded model state to the runner."""
        self.model_runner.save_sharded_state(
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Delegate serializing the model with tensorizer to the runner."""
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config, )

    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``, each clamped to be >= 0.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        cache_block_size = self.get_cache_block_size_bytes()
        num_gpu_blocks = int(
            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory) // cache_block_size)
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)

        # Drop any LoRAs loaded during profiling before freeing memory.
        if self.model_runner.lora_manager:
            self.model_runner.remove_all_loras()
        gc.collect()
        torch.cuda.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Allocate GPU and CPU KV cache with the specified number of blocks.

        This also warms up the model, which may record CUDA graphs.

        Raises:
            ValueError: if ``num_gpu_blocks`` cannot hold one sequence of
                ``max_model_len`` tokens (see `raise_if_cache_size_invalid`).
        """
        raise_if_cache_size_invalid(num_gpu_blocks,
                                    self.cache_config.block_size,
                                    self.model_config.max_model_len)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._init_cache_engine()
        self._warm_up_model()

    def _init_cache_engine(self):
        # One CacheEngine (and its GPU cache) per pipeline-parallel virtual
        # engine; `execute_worker` indexes these by `virtual_engine`.
        assert self.cache_config.num_gpu_blocks is not None
        self.cache_engine = [
            CacheEngine(self.cache_config, self.model_config,
                        self.parallel_config, self.device_config)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.gpu_cache = [
            self.cache_engine[ve].gpu_cache
            for ve in range(self.parallel_config.pipeline_parallel_size)
        ]

    def _warm_up_model(self) -> None:
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    @property
    def do_metadata_broadcast(self) -> bool:
        """True when the tensor-parallel group has more than one rank."""
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """Per-virtual-engine GPU KV caches; None until initialize_cache."""
        return self.gpu_cache

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """Convert the request's block operations into tensors.

        Each block list is turned into an int64 tensor of shape ``(-1, 2)``.
        """
        virtual_engine = execute_model_req.virtual_engine
        num_steps = execute_model_req.num_steps
        num_seq_groups = len(execute_model_req.seq_group_metadata_list)
        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
        # they contain parameters to launch cudamemcpyasync.
        blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
                                         device="cpu",
                                         dtype=torch.int64).view(-1, 2)
        blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
                                          device="cpu",
                                          dtype=torch.int64).view(-1, 2)
        # `blocks_to_copy` is a gpu tensor. The src and tgt of
        # blocks to copy are in the same device, and `blocks_to_copy`
        # can be used directly within cuda kernels.
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device=self.device,
                                      dtype=torch.int64).view(-1, 2)

        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
            num_steps=num_steps,
        )

    @torch.inference_mode()
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """Apply swap-in, swap-out and copy operations on the KV cache.

        Each operation is skipped when its tensor is None or empty.
        """
        virtual_engine = worker_input.virtual_engine
        # Issue cache operations.
        if (worker_input.blocks_to_swap_in is not None
                and worker_input.blocks_to_swap_in.numel() > 0):
            self.cache_engine[virtual_engine].swap_in(
                worker_input.blocks_to_swap_in)
        if (worker_input.blocks_to_swap_out is not None
                and worker_input.blocks_to_swap_out.numel() > 0):
            self.cache_engine[virtual_engine].swap_out(
                worker_input.blocks_to_swap_out)
        if (worker_input.blocks_to_copy is not None
                and worker_input.blocks_to_copy.numel() > 0):
            self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)

    def _get_cached_seq_group_metadata(
            self,
            seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                                SequenceGroupMetadataDelta]],
            finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
        """Return a list of cached Sequence Group Metadata after updating its
        state.

        It is used because scheduler only sends delta to workers to reduce
        the data payload size. The function also cleans up cache based on
        a given `finished_request_ids`; ids that were never cached here
        (e.g. requests aborted before being scheduled) are ignored.
        """
        new_seq_group_metadata_list = []
        for metadata_or_delta in seq_group_metadata_list:
            request_id = metadata_or_delta.request_id
            if request_id not in self._seq_group_metadata_cache:
                # The first prefill.
                assert isinstance(metadata_or_delta, SequenceGroupMetadata)
                self._seq_group_metadata_cache[request_id] = metadata_or_delta
            else:
                # The first prefill is already cached.
                if isinstance(metadata_or_delta, SequenceGroupMetadataDelta):
                    self._seq_group_metadata_cache[request_id].apply_delta(
                        metadata_or_delta)
                else:
                    # If metadata snapshot is sent again, it is
                    # preempted. Reset the cache because we need to start
                    # from scratch.
                    assert isinstance(metadata_or_delta, SequenceGroupMetadata)
                    self._seq_group_metadata_cache[
                        request_id] = metadata_or_delta

            new_seq_group_metadata_list.append(
                self._seq_group_metadata_cache[request_id])

        # Clean up finished ids. A finished id may never have reached this
        # worker, so tolerate missing entries instead of raising KeyError.
        for finished_id in finished_request_ids:
            self._seq_group_metadata_cache.pop(finished_id, None)

        return new_seq_group_metadata_list

    def _execute_model_spmd(
        self,
        execute_model_req: Optional[ExecuteModelRequest],
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Expand metadata deltas from the cache, then run the base SPMD path.

        ``execute_model_req`` may be None; it is then passed through as-is.
        """
        if execute_model_req is not None:
            new_seq_group_metadata_list = self._get_cached_seq_group_metadata(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.finished_requests_ids)

            execute_model_req.seq_group_metadata_list = (
                new_seq_group_metadata_list)
        output = super()._execute_model_spmd(execute_model_req,
                                             intermediate_tensors)
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.model_runner.list_loras()

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        return self.model_runner.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        # Previously this called `remove_lora`, which removed a LoRA with the
        # same id instead of the prompt adapter.
        return self.model_runner.remove_prompt_adapter(prompt_adapter_id)

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        return self.model_runner.pin_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> Set[int]:
        return self.model_runner.list_prompt_adapters()

    @property
    def max_model_len(self) -> int:
        return self.model_config.max_model_len

    @property
    def vocab_size(self) -> int:
        return self.model_runner.vocab_size

    def get_cache_block_size_bytes(self) -> int:
        """Get the size of one KV cache block in bytes.
        """
        return CacheEngine.get_cache_block_size(self.cache_config,
                                                self.model_config,
                                                self.parallel_config)


def init_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment.

    Steps, in order:
    1. Enable custom all-reduce unless
       ``parallel_config.disable_custom_all_reduce`` is set.
    2. Initialize the global process group for
       ``parallel_config.world_size`` ranks.
    3. Ensure the tensor- and pipeline-parallel groups exist.
    """
    use_custom_all_reduce = not parallel_config.disable_custom_all_reduce
    set_custom_all_reduce(use_custom_all_reduce)

    world_size = parallel_config.world_size
    init_distributed_environment(world_size, rank, distributed_init_method,
                                 local_rank)

    tp_size = parallel_config.tensor_parallel_size
    pp_size = parallel_config.pipeline_parallel_size
    ensure_model_parallel_initialized(tp_size, pp_size)


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
457
458
459
    if torch_dtype == torch.bfloat16:  # noqa: SIM102
        if not current_platform.has_device_capability(80):
            capability = current_platform.get_device_capability()
460
            gpu_name = current_platform.get_device_name()
461
462
463
464
465
466
467

            if capability is None:
                compute_str = "does not have a compute capability"
            else:
                version_str = capability.as_version_str()
                compute_str = f"has compute capability {version_str}"

468
469
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
470
                f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
Woosuk Kwon's avatar
Woosuk Kwon committed
471
472
                "You can use float16 instead by explicitly setting the"
                "`dtype` flag in CLI, for example: --dtype=half.")


def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
                                max_model_len) -> None:
    """Ensure the GPU KV cache can hold at least one max-length sequence.

    Raises:
        ValueError: if ``num_gpu_blocks`` is not positive, or if the cache
            capacity ``block_size * num_gpu_blocks`` (in tokens) is smaller
            than ``max_model_len``.
    """
    if num_gpu_blocks <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")

    capacity_tokens = block_size * num_gpu_blocks
    if capacity_tokens < max_model_len:
        raise ValueError(
            f"The model's max seq len ({max_model_len}) "
            "is larger than the maximum number of tokens that can be "
            f"stored in KV cache ({capacity_tokens}). Try increasing "
            "`gpu_memory_utilization` or decreasing `max_model_len` when "
            "initializing the engine.")