worker.py 16 KB
Newer Older
1
"""A GPU worker class."""
2
import gc
3
import os
4
from typing import List, Optional, Set, Tuple, Type
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6

import torch
7
import torch.distributed
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
10
                         ModelConfig, MultiModalConfig, ParallelConfig,
11
12
                         PromptAdapterConfig, SchedulerConfig,
                         SpeculativeConfig)
13
from vllm.distributed import (ensure_model_parallel_initialized,
14
15
                              init_distributed_environment,
                              set_custom_all_reduce)
16
from vllm.lora.request import LoRARequest
17
from vllm.model_executor import set_random_seed
18
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
19
from vllm.platforms import current_platform
20
from vllm.prompt_adapter.request import PromptAdapterRequest
21
from vllm.sequence import ExecuteModelRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
22
from vllm.worker.cache_engine import CacheEngine
23
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
24
25
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
Woosuk Kwon's avatar
Woosuk Kwon committed
26

27

28
class Worker(LocalOrDistributedWorkerBase):
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        multimodal_config: Optional[MultiModalConfig] = None,
        speculative_config: Optional[SpeculativeConfig] = None,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        is_driver_worker: bool = False,
        model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
    ) -> None:
        """Store the configuration and construct the model runner.

        Args:
            local_rank: Index of the CUDA device this worker binds to
                (used as ``cuda:{local_rank}`` in ``init_device``).
            rank: Global rank; also written into ``parallel_config.rank``.
            distributed_init_method: Passed through to the distributed
                environment initialization in ``init_device``.
            speculative_config: When the draft model is a different
                ``medusa`` / ``mlp_speculator`` model, the model runner is
                asked to return hidden states.
            is_driver_worker: A driver worker must be rank 0 of its tensor
                parallel group (asserted below).
            model_runner_cls: Overrides the model runner class; otherwise
                ``EmbeddingModelRunner`` is used in embedding mode and
                ``ModelRunner`` in all other cases.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.parallel_config.rank = rank
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.load_config = load_config
        self.prompt_adapter_config = prompt_adapter_config
        self.is_driver_worker = is_driver_worker
        if parallel_config and is_driver_worker:
            assert rank % parallel_config.tensor_parallel_size == 0, \
                   "Driver worker should be rank 0 of tensor parallel group."
        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()
        self.multimodal_config = multimodal_config

        # Return hidden states from the target model only when a separate
        # draft model of type medusa / mlp_speculator is configured.
        return_hidden_states = (
            speculative_config is not None
            and speculative_config.draft_model_config.model !=
            model_config.model
            and speculative_config.draft_model_config.hf_config.model_type
            in ("medusa", "mlp_speculator"))
        speculative_args = ({"return_hidden_states": True}
                            if return_hidden_states else {})

        ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
        if model_runner_cls is not None:
            ModelRunnerClass = model_runner_cls
        elif self.model_config.embedding_mode:
            ModelRunnerClass = EmbeddingModelRunner
        self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config=load_config,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
            prompt_adapter_config=prompt_adapter_config,
            multimodal_config=multimodal_config,
            **speculative_args,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CacheEngine]
        # Initialize gpu_cache as embedding models don't initialize kv_caches
        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None

    def init_device(self) -> None:
        """Bind to this worker's GPU, set up distributed state, seed RNGs.

        Raises:
            RuntimeError: If the configured device type is not ``cuda``.
            ValueError: If the GPU does not support the model dtype.
        """
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            torch.cuda.empty_cache()
            # Free memory before the model is loaded; used as the baseline
            # for peak-memory profiling in determine_num_available_blocks.
            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.parallel_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        """Load the model weights via the model runner."""
        self.model_runner.load_model()

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Delegate sharded-state saving to the model runner."""
        self.model_runner.save_sharded_state(
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Delegate tensorizer serialization to the model runner."""
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config, )

    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``, each clamped to >= 0.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        cache_block_size = self.get_cache_block_size_bytes()
        num_gpu_blocks = int(
            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory) // cache_block_size)
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        # Drop any LoRAs loaded during profiling and release cached memory.
        if self.model_runner.lora_manager:
            self.model_runner.remove_all_loras()
        gc.collect()
        torch.cuda.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Allocate GPU and CPU KV cache with the specified number of blocks.

        This also warms up the model, which may record CUDA graphs.

        Raises:
            ValueError: If the GPU cache cannot hold one max-length sequence.
        """
        raise_if_cache_size_invalid(num_gpu_blocks,
                                    self.cache_config.block_size,
                                    self.model_config.max_model_len)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._init_cache_engine()
        self._warm_up_model()

    def _init_cache_engine(self):
        """Create one CacheEngine per pipeline stage (virtual engine)."""
        assert self.cache_config.num_gpu_blocks is not None
        self.cache_engine = [
            CacheEngine(self.cache_config, self.model_config,
                        self.parallel_config, self.device_config)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.gpu_cache = [engine.gpu_cache for engine in self.cache_engine]

    def _warm_up_model(self) -> None:
        """Capture CUDA graphs unless eager mode is enforced, then reseed."""
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    @property
    def do_metadata_broadcast(self) -> bool:
        """Broadcast metadata only when tensor parallelism is in use."""
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """The per-virtual-engine GPU KV cache, or None before allocation."""
        return self.gpu_cache

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """Convert block swap/copy lists from the request into tensors."""
        virtual_engine = execute_model_req.virtual_engine
        num_seq_groups = len(execute_model_req.seq_group_metadata_list)
        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
        # they contain parameters to launch cudamemcpyasync.
        blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
                                         device="cpu",
                                         dtype=torch.int64).view(-1, 2)
        blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
                                          device="cpu",
                                          dtype=torch.int64).view(-1, 2)
        # `blocks_to_copy` is a gpu tensor. The src and tgt of
        # blocks to copy are in the same device, and `blocks_to_copy`
        # can be used directly within cuda kernels.
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device=self.device,
                                      dtype=torch.int64).view(-1, 2)

        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
        )

    @torch.inference_mode()
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """Apply pending swap-in, swap-out and copy cache operations."""
        virtual_engine = worker_input.virtual_engine
        # Issue cache operations.
        if (worker_input.blocks_to_swap_in is not None
                and worker_input.blocks_to_swap_in.numel() > 0):
            self.cache_engine[virtual_engine].swap_in(
                worker_input.blocks_to_swap_in)
        if (worker_input.blocks_to_swap_out is not None
                and worker_input.blocks_to_swap_out.numel() > 0):
            self.cache_engine[virtual_engine].swap_out(
                worker_input.blocks_to_swap_out)
        if (worker_input.blocks_to_copy is not None
                and worker_input.blocks_to_copy.numel() > 0):
            self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.model_runner.list_loras()

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        return self.model_runner.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        # Must go to the prompt-adapter API; calling remove_lora here would
        # evict an unrelated LoRA that happens to share the same id.
        return self.model_runner.remove_prompt_adapter(prompt_adapter_id)

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        return self.model_runner.pin_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> Set[int]:
        return self.model_runner.list_prompt_adapters()

    @property
    def max_model_len(self) -> int:
        return self.model_config.max_model_len

    @property
    def vocab_size(self) -> int:
        return self.model_runner.vocab_size

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single KV cache block."""
        return CacheEngine.get_cache_block_size(self.cache_config,
                                                self.model_config,
                                                self.parallel_config)


def init_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment for one worker.

    Steps, in order:
      1. enable/disable custom all-reduce per
         ``parallel_config.disable_custom_all_reduce``;
      2. initialize the distributed environment with the world size, this
         worker's rank, the init method and the local rank;
      3. ensure the tensor- and pipeline-parallel groups exist.
    """
    use_custom_all_reduce = not parallel_config.disable_custom_all_reduce
    set_custom_all_reduce(use_custom_all_reduce)

    world_size = parallel_config.world_size
    init_distributed_environment(world_size, rank, distributed_init_method,
                                 local_rank)

    tp_size = parallel_config.tensor_parallel_size
    pp_size = parallel_config.pipeline_parallel_size
    ensure_model_parallel_initialized(tp_size, pp_size)


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
353
        compute_capability = current_platform.get_device_capability()
354
355
356
357
358
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
Woosuk Kwon's avatar
Woosuk Kwon committed
359
360
361
                f"{compute_capability[0]}.{compute_capability[1]}. "
                "You can use float16 instead by explicitly setting the"
                "`dtype` flag in CLI, for example: --dtype=half.")
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377


def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
                                max_model_len) -> None:
    """Validate that the GPU KV cache can hold one maximum-length sequence.

    Args:
        num_gpu_blocks: Number of KV cache blocks allocated on the GPU.
        block_size: Number of tokens stored in each block.
        max_model_len: Longest sequence length the model is configured for.

    Raises:
        ValueError: If there are no GPU blocks, or if the cache's token
            capacity (``block_size * num_gpu_blocks``) is smaller than
            ``max_model_len``.
    """
    if num_gpu_blocks <= 0:
        no_memory_msg = ("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")
        raise ValueError(no_memory_msg)

    token_capacity = block_size * num_gpu_blocks
    if max_model_len > token_capacity:
        too_small_msg = (
            f"The model's max seq len ({max_model_len}) "
            "is larger than the maximum number of tokens that can be "
            f"stored in KV cache ({token_capacity}). Try increasing "
            "`gpu_memory_utilization` or decreasing `max_model_len` when "
            "initializing the engine.")
        raise ValueError(too_small_msg)