worker.py 16.5 KB
Newer Older
1
"""A GPU worker class."""
2
import gc
3
import os
4
from typing import Any, Dict, List, Optional, Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6

import torch
7
import torch.distributed
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
10
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
11
                         SpeculativeConfig, VisionLanguageConfig)
12
13
from vllm.distributed import (broadcast_tensor_dict,
                              ensure_model_parallel_initialized,
14
15
                              init_distributed_environment,
                              set_custom_all_reduce)
16
from vllm.lora.request import LoRARequest
17
from vllm.model_executor import set_random_seed
18
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
19
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
20
from vllm.worker.cache_engine import CacheEngine
21
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
22
from vllm.worker.model_runner import ModelRunner
23
from vllm.worker.worker_base import WorkerBase
Woosuk Kwon's avatar
Woosuk Kwon committed
24

25

26
class Worker(WorkerBase):
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        vision_language_config: Optional[VisionLanguageConfig] = None,
        speculative_config: Optional[SpeculativeConfig] = None,
        is_driver_worker: bool = False,
    ) -> None:
        """Store the configs and build the model runner.

        No device work happens here; see ``init_device`` and ``load_model``.

        Args:
            local_rank: Index of the CUDA device this worker binds to.
            rank: Global rank of this worker in the distributed group. The
                driver worker must have rank 0.
            distributed_init_method: Init method passed to the distributed
                environment setup.
            speculative_config: When set and its draft model is an
                ``mlp_speculator`` distinct from the target model, the model
                runner is asked to return hidden states.
            is_driver_worker: Whether this worker drives execution and
                broadcasts per-step inputs to the other workers.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        self.vision_language_config = vision_language_config
        if self.vision_language_config:
            assert not self.lora_config, (
                "To be tested: vision language model with LoRA settings.")

        # Return hidden states from the target model if the draft model is an
        # mlp_speculator (and is not simply the target model itself).
        speculative_args: Dict[str, Any] = {}
        if (speculative_config is not None
                and speculative_config.draft_model_config.model !=
                model_config.model
                and speculative_config.draft_model_config.hf_config.model_type
                == "mlp_speculator"):
            speculative_args["return_hidden_states"] = True

        ModelRunnerClass = (EmbeddingModelRunner if
                            self.model_config.embedding_mode else ModelRunner)
        self.model_runner = ModelRunnerClass(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config=load_config,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
            vision_language_config=vision_language_config,
            **speculative_args,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: CacheEngine
        # Initialize gpu_cache as embedding models don't initialize kv_caches
        self.gpu_cache: Optional[List[torch.Tensor]] = None

    def init_device(self) -> None:
        """Bind this worker to its CUDA device and join the process group.

        Also records the free GPU memory at this point in
        ``self.init_gpu_memory``, which ``determine_num_available_blocks``
        uses as its profiling baseline. Finally, it seeds the RNGs.

        Raises:
            RuntimeError: If the configured device type is not ``cuda``.
        """
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            torch.cuda.empty_cache()
            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.parallel_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        """Load the model weights through the model runner."""
        self.model_runner.load_model()

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Delegate saving this worker's model shard to the model runner."""
        self.model_runner.save_sharded_state(
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Delegate tensorizer serialization of the model to the runner."""
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config, )

    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        Requires ``init_device`` to have run (it reads
        ``self.init_gpu_memory``).

        Returns:
            A ``(num_gpu_blocks, num_cpu_blocks)`` tuple, each clamped to be
            non-negative.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        cache_block_size = self.get_cache_block_size_bytes()
        num_gpu_blocks = int(
            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory) // cache_block_size)
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        # Drop any LoRA adapters loaded during profiling before freeing memory.
        if self.model_runner.lora_manager:
            self.model_runner.remove_all_loras()
        gc.collect()
        torch.cuda.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Allocate GPU and CPU KV cache with the specified number of blocks.

        This also warms up the model, which may record CUDA graphs.

        Raises:
            ValueError: If ``num_gpu_blocks`` is not positive or cannot hold
                ``max_model_len`` tokens.
        """
        raise_if_cache_size_invalid(num_gpu_blocks,
                                    self.cache_config.block_size,
                                    self.model_config.max_model_len)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._init_cache_engine()
        self._warm_up_model()

    def _init_cache_engine(self):
        """Create the cache engine and expose its GPU cache tensors."""
        assert self.cache_config.num_gpu_blocks is not None
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config,
                                        self.device_config)
        self.gpu_cache = self.cache_engine.gpu_cache

    def _warm_up_model(self) -> None:
        """Capture CUDA graphs (unless eager mode) and re-seed the RNGs."""
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def cache_swap(
        self,
        blocks_to_swap_in: torch.Tensor,
        blocks_to_swap_out: torch.Tensor,
        blocks_to_copy: torch.Tensor,
    ) -> None:
        """Apply one step's swap-in, swap-out and copy cache operations.

        The arguments are int64 tensors viewed as ``(-1, 2)``, as built in
        ``execute_model``. Empty tensors are skipped.
        """
        # Issue cache operations.
        if blocks_to_swap_in.numel() > 0:
            self.cache_engine.swap_in(blocks_to_swap_in)
        if blocks_to_swap_out.numel() > 0:
            self.cache_engine.swap_out(blocks_to_swap_out)
        if blocks_to_copy.numel() > 0:
            self.cache_engine.copy(blocks_to_copy)

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[Union[SamplerOutput, PoolerOutput]]:
        """Run one step of the model.

        On the driver, this broadcasts the step's cache operations to the
        other workers, applies them locally, and runs the model. Passing
        ``execute_model_req=None`` on the driver broadcasts an empty dict
        instead, which tells the parallel workers to leave their loop.
        On a non-driver worker, this runs one step of
        ``_execute_model_non_driver``.

        Returns:
            ``[output]`` for a driver step that executed the model, otherwise
            an empty list.
        """
        if not self.is_driver_worker:
            self._execute_model_non_driver()
            return []

        if execute_model_req is None:
            # This signals that there's no more requests to process for now.
            # All workers are running infinite loop with broadcast_tensor_dict,
            # and it stops the loop when the driver broadcasts an empty input.
            # Send an empty input to notify all other workers to stop their
            # execution loop.
            broadcast_tensor_dict({}, src=0)
            return []

        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
        num_seq_groups = len(seq_group_metadata_list)
        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
        # they contain parameters to launch cudamemcpyasync.
        blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
                                         device="cpu",
                                         dtype=torch.int64).view(-1, 2)
        blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
                                          device="cpu",
                                          dtype=torch.int64).view(-1, 2)
        # `blocks_to_copy` is a gpu tensor. The src and tgt of
        # blocks to copy are in the same device, and `blocks_to_copy`
        # can be used directly within cuda kernels.
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device=self.device,
                                      dtype=torch.int64).view(-1, 2)
        data: Dict[str, Any] = {
            "num_seq_groups": num_seq_groups,
            "blocks_to_swap_in": blocks_to_swap_in,
            "blocks_to_swap_out": blocks_to_swap_out,
            "blocks_to_copy": blocks_to_copy,
        }
        broadcast_tensor_dict(data, src=0)

        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return []

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.gpu_cache)

        # Worker only supports single-step execution. Wrap the output in a list
        # to conform to interface.
        return [output]

    @torch.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Execute model loop in parallel worker.

        You can stop the loop by executing a driver worker with an empty output.
        See `stop_remote_worker_execution_loop` for more details.
        """
        while self._execute_model_non_driver():
            pass

    def _execute_model_non_driver(self) -> bool:
        """Execute one model step in a parallel (non-driver) worker.

        Blocks on ``broadcast_tensor_dict`` until the driver sends the step's
        data, applies the requested cache operations, and runs the model if
        the step contains sequence groups.

        Returns:
            False iff the driver broadcast an empty dict (the stop signal);
            True otherwise, so that the execution loop keeps listening.
        """
        assert not self.is_driver_worker
        data = broadcast_tensor_dict(src=0)
        if not data:
            return False

        num_seq_groups = data.get("num_seq_groups", 0)
        blocks_to_swap_in = data.get("blocks_to_swap_in")
        blocks_to_swap_out = data.get("blocks_to_swap_out")
        blocks_to_copy = data.get("blocks_to_copy")
        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

        # A step with no sequence groups (e.g. one that only swaps or copies
        # KV blocks) has nothing to execute. The driver has not asked us to
        # stop, though, and it will broadcast again on its next step. Leaving
        # the loop here would make that collective hang, so keep listening.
        if num_seq_groups == 0:
            return True

        self.model_runner.execute_model(None, self.gpu_cache)
        return True

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to the model runner's ``add_lora``."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner's ``remove_lora``."""
        return self.model_runner.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner's ``pin_lora``."""
        return self.model_runner.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the model runner."""
        return self.model_runner.list_loras()

    @property
    def max_model_len(self) -> int:
        """Maximum model length from the model config."""
        return self.model_config.max_model_len

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the model runner."""
        return self.model_runner.vocab_size

    def get_cache_block_size_bytes(self) -> int:
        """Get the size of a single KV cache block in bytes.
        """
        return CacheEngine.get_cache_block_size(self.cache_config,
                                                self.model_config,
                                                self.parallel_config)

Woosuk Kwon's avatar
Woosuk Kwon committed
357

358
def init_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment.

    The steps run in this order:

    1. Configure whether custom all-reduce is enabled.
    2. Set up the global process group.
    3. Ensure the tensor/pipeline model-parallel groups exist.

    Args:
        parallel_config: Source of the world size, parallel sizes and the
            custom all-reduce switch.
        rank: Global rank of this process.
        distributed_init_method: Init method forwarded to
            ``init_distributed_environment``.
        local_rank: Local rank forwarded to ``init_distributed_environment``.
    """
    use_custom_all_reduce = not parallel_config.disable_custom_all_reduce
    set_custom_all_reduce(use_custom_all_reduce)

    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank)

    tensor_parallel_size = parallel_config.tensor_parallel_size
    pipeline_parallel_size = parallel_config.pipeline_parallel_size
    ensure_model_parallel_initialized(tensor_parallel_size,
                                      pipeline_parallel_size)

373

374
375
376
377
378
379
380
381
382
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
Woosuk Kwon's avatar
Woosuk Kwon committed
383
384
385
                f"{compute_capability[0]}.{compute_capability[1]}. "
                "You can use float16 instead by explicitly setting the"
                "`dtype` flag in CLI, for example: --dtype=half.")
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401


def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
                                max_model_len) -> None:
    """Check that the GPU KV cache can hold at least one full-length sequence.

    Args:
        num_gpu_blocks: Number of KV cache blocks available on the GPU.
        block_size: Number of tokens stored per cache block.
        max_model_len: The model's maximum sequence length, in tokens.

    Raises:
        ValueError: If there are no GPU blocks, or if the cache's total token
            capacity (``block_size * num_gpu_blocks``) is less than
            ``max_model_len``.
    """
    if num_gpu_blocks <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")

    token_capacity = block_size * num_gpu_blocks
    if not max_model_len > token_capacity:
        return

    raise ValueError(
        f"The model's max seq len ({max_model_len}) "
        "is larger than the maximum number of tokens that can be "
        f"stored in KV cache ({token_capacity}). Try increasing "
        "`gpu_memory_utilization` or decreasing `max_model_len` when "
        "initializing the engine.")