worker.py 14.8 KB
Newer Older
1
"""A GPU worker class."""
2
import gc
3
import os
4
from typing import List, Optional, Set, Tuple, Type
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6

import torch
7
import torch.distributed
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
10
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
11
                         SpeculativeConfig, VisionLanguageConfig)
12
from vllm.distributed import (ensure_model_parallel_initialized,
13
14
                              init_distributed_environment,
                              set_custom_all_reduce)
15
from vllm.lora.request import LoRARequest
16
from vllm.model_executor import set_random_seed
17
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
18
from vllm.sequence import ExecuteModelRequest
19
from vllm.utils import get_device_capability_stateless
Woosuk Kwon's avatar
Woosuk Kwon committed
20
from vllm.worker.cache_engine import CacheEngine
21
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
22
23
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
Woosuk Kwon's avatar
Woosuk Kwon committed
24

25

26
class Worker(LocalOrDistributedWorkerBase):
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        vision_language_config: Optional[VisionLanguageConfig] = None,
        speculative_config: Optional[SpeculativeConfig] = None,
        is_driver_worker: bool = False,
        model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
    ) -> None:
        """Store the configuration and construct the model runner.

        The GPU itself is selected later in ``init_device``; model loading
        is delegated to the runner by ``load_model``.

        Args:
            model_config: Model configuration (dtype, seed, max length, ...).
            parallel_config: Tensor/pipeline parallel layout.
            scheduler_config: Scheduler configuration, forwarded to the runner.
            device_config: Device configuration; only ``cuda`` is supported
                by ``init_device``.
            cache_config: KV cache configuration.
            load_config: Model loading configuration, forwarded to the runner.
            local_rank: Index of this worker's CUDA device on the node; the
                device used is ``cuda:{local_rank}``.
            rank: Global rank of this worker. Must be 0 for the driver worker.
            distributed_init_method: Passed through to the distributed
                environment initialization in ``init_device``.
            lora_config: Optional LoRA configuration.
            vision_language_config: Optional vision-language configuration.
                Currently rejected (via assert) in combination with LoRA.
            speculative_config: Optional speculative decoding configuration.
                Only used to decide whether the runner must return hidden
                states.
            is_driver_worker: Whether this worker is the driver.
            model_runner_cls: Explicit model runner class. When None,
                ``EmbeddingModelRunner`` is used for embedding models and
                ``ModelRunner`` otherwise.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        self.vision_language_config = vision_language_config
        if self.vision_language_config:
            assert not self.lora_config, (
                "To be tested: vision language model with LoRA settings.")

        # Ask the runner to return hidden states from the target model only
        # when the draft model is an mlp_speculator distinct from the target.
        speculative_args = {}
        if speculative_config is not None:
            draft_model_config = speculative_config.draft_model_config
            if (draft_model_config.model != model_config.model
                    and draft_model_config.hf_config.model_type ==
                    "mlp_speculator"):
                speculative_args = {"return_hidden_states": True}

        ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
        if model_runner_cls is not None:
            ModelRunnerClass = model_runner_cls
        elif self.model_config.embedding_mode:
            ModelRunnerClass = EmbeddingModelRunner
        self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config=load_config,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
            vision_language_config=vision_language_config,
            **speculative_args,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: CacheEngine
        # Initialize gpu_cache as embedding models don't initialize kv_caches
        self.gpu_cache: Optional[List[torch.Tensor]] = None

    def init_device(self) -> None:
        """Bind this worker to its GPU and set up distributed state.

        Adjusts NCCL-related environment variables, selects
        ``cuda:{local_rank}``, checks that the GPU supports the model dtype,
        and records the free GPU memory. That value is the baseline for
        ``determine_num_available_blocks``. It then initializes the
        distributed environment and sets the random seed.

        Raises:
            RuntimeError: If the configured device type is not ``cuda``.
        """
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            torch.cuda.empty_cache()
            # Free memory before the model is loaded; used as the reference
            # point when profiling peak memory usage.
            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.parallel_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        """Load the model by delegating to the model runner."""
        self.model_runner.load_model()

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save the model state in shards via the model runner.

        Args:
            path: Destination passed to the runner.
            pattern: Optional pattern passed to the runner.
            max_size: Optional maximum size passed to the runner.
        """
        self.model_runner.save_sharded_state(
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Serialize the model with tensorizer via the model runner."""
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config)

    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``, each clamped to be >= 0.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        cache_block_size = self.get_cache_block_size_bytes()
        num_gpu_blocks = int(
            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory) // cache_block_size)
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        # Drop any LoRA adapters loaded during profiling and release memory
        # before the KV cache is allocated.
        if self.model_runner.lora_manager:
            self.model_runner.remove_all_loras()
        gc.collect()
        torch.cuda.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Allocate GPU and CPU KV cache with the specified number of blocks.

        This also warms up the model, which may record CUDA graphs.

        Raises:
            ValueError: If ``num_gpu_blocks`` is not positive, or the cache
                cannot hold ``max_model_len`` tokens.
        """
        raise_if_cache_size_invalid(num_gpu_blocks,
                                    self.cache_config.block_size,
                                    self.model_config.max_model_len)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._init_cache_engine()
        self._warm_up_model()

    def _init_cache_engine(self):
        """Create the cache engine and expose its GPU cache as ``gpu_cache``."""
        assert self.cache_config.num_gpu_blocks is not None
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config,
                                        self.device_config)
        self.gpu_cache = self.cache_engine.gpu_cache

    def _warm_up_model(self) -> None:
        """Capture the model via the runner unless eager mode is enforced,
        then reset the random seed."""
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    @property
    def do_metadata_broadcast(self) -> bool:
        """True when the tensor parallel size is greater than one."""
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[torch.Tensor]]:
        """The GPU KV cache, or None before ``initialize_cache`` runs."""
        return self.gpu_cache

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """Convert the block operations of a request into tensors.

        Each block list is turned into an int64 tensor reshaped to
        ``(num_pairs, 2)``.

        Args:
            execute_model_req: The request whose block operations are read.

        Returns:
            A ``WorkerInput`` with the sequence group count and the swap-in,
            swap-out and copy tensors.
        """
        num_seq_groups = len(execute_model_req.seq_group_metadata_list)
        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
        # they contain parameters to launch cudamemcpyasync.
        blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
                                         device="cpu",
                                         dtype=torch.int64).view(-1, 2)
        blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
                                          device="cpu",
                                          dtype=torch.int64).view(-1, 2)
        # `blocks_to_copy` is a gpu tensor. The src and tgt of
        # blocks to copy are in the same device, and `blocks_to_copy`
        # can be used directly within cuda kernels.
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device=self.device,
                                      dtype=torch.int64).view(-1, 2)

        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
        )

    @torch.inference_mode()
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """Issue the cache swap-in, swap-out and copy operations.

        Each operation is skipped when its tensor is None or empty.
        """
        # Issue cache operations.
        if (worker_input.blocks_to_swap_in is not None
                and worker_input.blocks_to_swap_in.numel() > 0):
            self.cache_engine.swap_in(worker_input.blocks_to_swap_in)
        if (worker_input.blocks_to_swap_out is not None
                and worker_input.blocks_to_swap_out.numel() > 0):
            self.cache_engine.swap_out(worker_input.blocks_to_swap_out)
        if (worker_input.blocks_to_copy is not None
                and worker_input.blocks_to_copy.numel() > 0):
            self.cache_engine.copy(worker_input.blocks_to_copy)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add a LoRA adapter via the model runner."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter via the model runner."""
        return self.model_runner.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter via the model runner."""
        return self.model_runner.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the model runner."""
        return self.model_runner.list_loras()

    @property
    def max_model_len(self) -> int:
        """Maximum model length from the model configuration."""
        return self.model_config.max_model_len

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the model runner."""
        return self.model_runner.vocab_size

    def get_cache_block_size_bytes(self) -> int:
        """Get the size of the KV cache block size in bytes.
        """
        return CacheEngine.get_cache_block_size(self.cache_config,
                                                self.model_config,
                                                self.parallel_config)

Woosuk Kwon's avatar
Woosuk Kwon committed
306

307
def init_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment.

    Runs three steps in order:
    1. Configures custom all-reduce from the parallel config.
    2. Initializes the distributed environment for ``world_size`` processes.
    3. Ensures the tensor/pipeline model-parallel groups exist.

    Args:
        parallel_config: Source of world size, parallel sizes and the
            custom all-reduce toggle.
        rank: Global rank of this process.
        distributed_init_method: Init method forwarded to
            ``init_distributed_environment``; may be None.
        local_rank: Local rank forwarded to ``init_distributed_environment``.
    """
    enable_custom_all_reduce = not parallel_config.disable_custom_all_reduce
    set_custom_all_reduce(enable_custom_all_reduce)

    world_size = parallel_config.world_size
    init_distributed_environment(
        world_size,
        rank,
        distributed_init_method,
        local_rank,
    )

    tp_size = parallel_config.tensor_parallel_size
    pp_size = parallel_config.pipeline_parallel_size
    ensure_model_parallel_initialized(tp_size, pp_size)

322

323
324
325
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
326
        compute_capability = get_device_capability_stateless()
327
328
329
330
331
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
Woosuk Kwon's avatar
Woosuk Kwon committed
332
333
334
                f"{compute_capability[0]}.{compute_capability[1]}. "
                "You can use float16 instead by explicitly setting the"
                "`dtype` flag in CLI, for example: --dtype=half.")
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350


def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
                                max_model_len) -> None:
    """Validate that the GPU KV cache can hold a maximum-length sequence.

    Args:
        num_gpu_blocks: Number of GPU cache blocks that will be allocated.
        block_size: Number of tokens per cache block.
        max_model_len: The model's maximum sequence length.

    Raises:
        ValueError: If ``num_gpu_blocks`` is not positive, or if
            ``block_size * num_gpu_blocks`` is smaller than ``max_model_len``.
    """
    if num_gpu_blocks <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")

    # Total number of tokens the cache can store across all blocks.
    token_capacity = block_size * num_gpu_blocks
    if max_model_len > token_capacity:
        raise ValueError(
            f"The model's max seq len ({max_model_len}) "
            "is larger than the maximum number of tokens that can be "
            f"stored in KV cache ({token_capacity}). Try increasing "
            "`gpu_memory_utilization` or decreasing `max_model_len` when "
            "initializing the engine.")