worker.py 14.3 KB
Newer Older
1
"""A GPU worker class."""
2
import gc
3
import os
4
from typing import Any, Dict, List, Optional, Set, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6

import torch
7
import torch.distributed
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
10
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
11
                         VisionLanguageConfig)
12
13
from vllm.distributed import (broadcast_tensor_dict,
                              ensure_model_parallel_initialized,
14
                              get_tensor_model_parallel_cpu_group,
15
16
17
18
                              init_distributed_environment)
from vllm.distributed.device_communicators import pynccl_utils
from vllm.distributed.device_communicators.custom_all_reduce import (
    init_custom_ar)
19
from vllm.lora.request import LoRARequest
20
from vllm.model_executor import set_random_seed
21
from vllm.sequence import ExecuteModelRequest, SamplerOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
22
from vllm.worker.cache_engine import CacheEngine
23
from vllm.worker.model_runner import ModelRunner
24
from vllm.worker.worker_base import WorkerBase
Woosuk Kwon's avatar
Woosuk Kwon committed
25

26

27
class Worker(WorkerBase):
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.

    Intended lifecycle (inferred from which attributes each method relies on;
    confirm against the executor that drives this class): ``init_device`` ->
    ``load_model`` -> ``determine_num_available_blocks`` ->
    ``initialize_cache`` -> ``execute_model``.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        vision_language_config: Optional[VisionLanguageConfig] = None,
        is_driver_worker: bool = False,
    ) -> None:
        """Store the configuration and construct the model runner.

        No CUDA device or distributed setup happens here; that is done by
        ``init_device``.

        Args:
            model_config: Model configuration (dtype, seed, max length,
                remote-code and eager-mode flags are read by this class).
            parallel_config: Parallelism configuration (world size, tensor-
                and pipeline-parallel sizes).
            scheduler_config: Scheduler configuration, forwarded to
                ``ModelRunner``.
            device_config: Device configuration; ``init_device`` only accepts
                the ``cuda`` device type.
            cache_config: KV cache configuration (block size, cache dtype,
                GPU memory utilization, swap space).
            load_config: Weight-loading configuration, forwarded to
                ``ModelRunner``.
            local_rank: Node-local rank; selects the ``cuda:{local_rank}``
                device in ``init_device``.
            rank: Global rank of this worker.
            distributed_init_method: Init method handed to the distributed
                environment setup.
            lora_config: Optional LoRA configuration.
            vision_language_config: Optional vision-language configuration;
                currently rejected when combined with ``lora_config``.
            is_driver_worker: Whether this worker is the driver. In
                ``execute_model`` the driver broadcasts per-step metadata to
                the other ranks.

        Raises:
            AssertionError: If a driver worker does not have rank 0, or if
                both ``vision_language_config`` and ``lora_config`` are set.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()
        self.vision_language_config = vision_language_config
        if self.vision_language_config:
            # The VLM + LoRA combination is not validated yet, so it is
            # rejected up front.
            assert not self.lora_config, (
                "To be tested: vision language model with LoRA settings.")

        self.model_runner = ModelRunner(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            load_config=load_config,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
            vision_language_config=vision_language_config,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: CacheEngine
        self.gpu_cache: List[torch.Tensor]

    def init_device(self) -> None:
        """Bind this worker to its GPU and join the distributed environment.

        Sets NCCL-related environment variables, selects
        ``cuda:{local_rank}``, validates the model dtype against the GPU, and
        records the free GPU memory as ``self.init_gpu_memory``. That value
        is the baseline used later by ``determine_num_available_blocks``.
        Then initializes the distributed environment and seeds the RNGs.

        Raises:
            RuntimeError: If the configured device type is not ``cuda``.
            ValueError: If the model dtype is unsupported by the GPU.
        """
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            # Release cached allocator blocks so the free-memory reading below
            # reflects memory that is actually available.
            torch.cuda.empty_cache()
            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.parallel_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self) -> None:
        """Load the model weights via the model runner."""
        self.model_runner.load_model()

    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Requires ``init_device`` to have run, because it reads
        ``self.init_gpu_memory``.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``. Both are clamped to be
            non-negative.

        Raises:
            AssertionError: If the measured peak memory is not positive. This
                usually indicates GPU memory was not cleaned up before start.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        cache_block_size = self.get_cache_block_size_bytes()
        # GPU budget = utilization fraction of total memory minus what the
        # profiled forward pass consumed.
        num_gpu_blocks = int(
            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory) // cache_block_size)
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        # Drop any LoRA adapters left loaded by profiling before freeing memory.
        if self.model_runner.lora_manager:
            self.model_runner.remove_all_loras()
        gc.collect()
        torch.cuda.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Allocate GPU and CPU KV cache with the specified number of blocks.

        This also warms up the model, which may record CUDA graphs.

        Raises:
            ValueError: If ``num_gpu_blocks`` is not positive, or the cache
                cannot hold ``max_model_len`` tokens (see
                ``raise_if_cache_size_invalid``).
        """
        raise_if_cache_size_invalid(num_gpu_blocks,
                                    self.cache_config.block_size,
                                    self.model_config.max_model_len)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._init_cache_engine()
        self._warm_up_model()

    def _init_cache_engine(self) -> None:
        """Build the CacheEngine and share its block size with the runner."""
        assert self.cache_config.num_gpu_blocks is not None
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.gpu_cache = self.cache_engine.gpu_cache
        self.model_runner.set_block_size(self.cache_engine.block_size)

    def _warm_up_model(self) -> None:
        """Capture the model (unless eager mode is enforced) and reseed RNGs."""
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def cache_swap(
        self,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """Apply pending swap-in, swap-out and copy operations to the cache.

        Empty mappings are skipped. The operations run in the order shown:
        swap-in, swap-out, then copy.

        Args:
            blocks_to_swap_in: Block-number mapping passed to
                ``CacheEngine.swap_in`` (presumably source -> destination;
                confirm in CacheEngine).
            blocks_to_swap_out: Block-number mapping passed to
                ``CacheEngine.swap_out``.
            blocks_to_copy: Mapping from a block number to a list of block
                numbers, passed to ``CacheEngine.copy`` (presumably
                source -> destinations; confirm in CacheEngine).
        """
        # Issue cache operations.
        # TODO(woosuk): Profile swapping overhead and optimize if needed.
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run one model step on this worker.

        The driver worker must receive ``execute_model_req``. It broadcasts
        the number of sequence groups and the swap/copy maps from rank 0.
        Non-driver workers receive those values from the same broadcast.
        Every worker then applies the cache operations. If there are no
        sequence groups, the model is not run.

        Returns:
            ``[]`` when there are no sequence groups. Otherwise a one-element
            list holding the runner's output, because only single-step
            execution is supported.
        """

        if execute_model_req is None:
            seq_group_metadata_list = None
        else:
            seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            assert execute_model_req is not None
            num_seq_groups = len(seq_group_metadata_list)
            blocks_to_swap_in = execute_model_req.blocks_to_swap_in
            blocks_to_swap_out = execute_model_req.blocks_to_swap_out
            blocks_to_copy = execute_model_req.blocks_to_copy
            data: Dict[str, Any] = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
            }
            broadcast_tensor_dict(data, src=0)
        else:
            # Non-driver ranks learn the step's metadata from the driver.
            data = broadcast_tensor_dict(src=0)
            num_seq_groups = data["num_seq_groups"]
            blocks_to_swap_in = data["blocks_to_swap_in"]
            blocks_to_swap_out = data["blocks_to_swap_out"]
            blocks_to_copy = data["blocks_to_copy"]

        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return []

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.gpu_cache)

        # Worker only supports single-step execution. Wrap the output in a list
        # to conform to interface.
        return [output]

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate LoRA adapter registration to the model runner."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate LoRA adapter removal to the model runner."""
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the model runner."""
        return self.model_runner.list_loras()

    @property
    def max_model_len(self) -> int:
        """Maximum sequence length from the model config."""
        return self.model_config.max_model_len

    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by the model runner."""
        return self.model_runner.vocab_size

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single KV cache block.

        Delegates to ``CacheEngine.get_cache_block_size`` using this worker's
        cache, model and parallel configs.
        """
        return CacheEngine.get_cache_block_size(self.cache_config,
                                                self.model_config,
                                                self.parallel_config)

Woosuk Kwon's avatar
Woosuk Kwon committed
280

281
def init_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment.

    Steps, in order:
      1. Set up ``torch.distributed`` and the model-parallel groups.
      2. Set up (or validate an existing) pynccl process group.
      3. Optionally enable the custom all-reduce.
      4. Issue one tiny warm-up all-reduce per active backend.

    Args:
        parallel_config: Supplies the world size, tensor/pipeline parallel
            sizes and the custom all-reduce switch.
        rank: Global rank of the calling worker.
        distributed_init_method: Init method forwarded to
            ``init_distributed_environment``.
        local_rank: Node-local rank forwarded to
            ``init_distributed_environment``.

    Raises:
        RuntimeError: If pynccl was already initialized with a world size
            different from ``parallel_config.world_size``.
    """
    world_size = parallel_config.world_size

    init_distributed_environment(world_size, rank, distributed_init_method,
                                 local_rank)
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)

    if not pynccl_utils.is_initialized():
        # A single-rank run never creates a pynccl group. Otherwise, by
        # default, pynccl is bound to the tensor-parallel CPU group.
        if world_size > 1:
            pynccl_utils.init_process_group(
                group=get_tensor_model_parallel_cpu_group())
    else:
        existing_size = pynccl_utils.get_world_size()
        if existing_size != world_size:
            raise RuntimeError(
                "pynccl is already initialized but the pynccl world "
                "size does not match parallel_config.world_size "
                f"({existing_size} vs. {world_size}).")

    # The custom fast all-reduce is opt-out via the parallel config.
    if not parallel_config.disable_custom_all_reduce:
        init_custom_ar()

    # Warm up each communication path once with a one-element tensor.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    if pynccl_utils.is_initialized():
        pynccl_utils.all_reduce(torch.zeros(1).cuda())

317

318
319
320
321
322
323
324
325
326
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype) -> None:
    """Raise if the current CUDA device cannot run the requested dtype.

    Only ``torch.bfloat16`` is checked: it needs a GPU whose compute
    capability major version is at least 8. Any other dtype returns
    immediately without querying the device.

    Args:
        torch_dtype: The dtype the model is configured to run in.

    Raises:
        ValueError: If ``torch_dtype`` is bfloat16 and the current device's
            compute capability is below 8.0.
    """
    if torch_dtype != torch.bfloat16:
        return
    major, minor = torch.cuda.get_device_capability()
    if major >= 8:
        return
    gpu_name = torch.cuda.get_device_name()
    # The adjacent literals are implicitly concatenated. Every fragment except
    # the last must end with a space, otherwise words run together in the
    # message (previously "setting the`dtype` flag").
    raise ValueError(
        "Bfloat16 is only supported on GPUs with compute capability "
        f"of at least 8.0. Your {gpu_name} GPU has compute capability "
        f"{major}.{minor}. "
        "You can use float16 instead by explicitly setting the "
        "`dtype` flag in CLI, for example: --dtype=half.")
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345


def raise_if_cache_size_invalid(num_gpu_blocks: int, block_size: int,
                                max_model_len: int) -> None:
    """Check that the GPU KV cache can hold one full-length sequence.

    Args:
        num_gpu_blocks: Number of KV cache blocks allocated on the GPU.
        block_size: Number of tokens stored per cache block.
        max_model_len: Maximum sequence length the model is configured for.

    Raises:
        ValueError: If there are no GPU blocks, or if the cache's token
            capacity (``block_size * num_gpu_blocks``) is smaller than
            ``max_model_len``.
    """
    if num_gpu_blocks <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")

    token_capacity = block_size * num_gpu_blocks
    if max_model_len <= token_capacity:
        return

    raise ValueError(
        f"The model's max seq len ({max_model_len}) is larger than the "
        "maximum number of tokens that can be stored in KV cache "
        f"({token_capacity}). Try increasing `gpu_memory_utilization` or "
        "decreasing `max_model_len` when initializing the engine.")