worker.py 13 KB
Newer Older
1
"""A GPU worker class."""
2
import gc
3
import os
4
from typing import Dict, List, Optional, Set, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6

import torch
7
import torch.distributed
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
10
                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
11
from vllm.lora.request import LoRARequest
12
from vllm.model_executor import set_random_seed
13
from vllm.model_executor.parallel_utils import pynccl_utils
14
from vllm.model_executor.parallel_utils.communication_op import (
15
    broadcast_tensor_dict)
16
from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
Woosuk Kwon's avatar
Woosuk Kwon committed
17
from vllm.model_executor.parallel_utils.parallel_state import (
18
    ensure_model_parallel_initialized)
19
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
Woosuk Kwon's avatar
Woosuk Kwon committed
20
from vllm.worker.cache_engine import CacheEngine
21
from vllm.worker.model_runner import ModelRunner
Woosuk Kwon's avatar
Woosuk Kwon committed
22

23

Woosuk Kwon's avatar
Woosuk Kwon committed
24
class Worker:
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.

    Ordering constraints visible in this class: ``profile_num_available_blocks``
    reads ``self.init_gpu_memory``, which is only set by ``init_device``; and
    ``warm_up_model`` / ``execute_model`` use ``self.gpu_cache``, which stays
    ``None`` until ``init_cache_engine`` has been called.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        vision_language_config: Optional[VisionLanguageConfig] = None,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
    ) -> None:
        """Store the configs and construct the ModelRunner.

        Selecting this worker's CUDA device and joining the distributed group
        happen later, in ``init_device``.

        Args:
            model_config: Model configuration.
            parallel_config: Parallelism configuration.
            scheduler_config: Scheduler configuration.
            device_config: Device configuration; ``init_device`` only
                accepts device type ``cuda``.
            local_rank: Local device index; ``init_device`` binds to
                ``cuda:{local_rank}``.
            rank: Global rank of this worker. The driver must be rank 0.
            distributed_init_method: Passed to
                ``init_distributed_environment`` by ``init_device``.
            lora_config: Optional LoRA configuration.
            vision_language_config: Optional vision-language configuration.
                Combining it with LoRA is rejected by an assertion.
            kv_cache_dtype: Forwarded to the ModelRunner.
            is_driver_worker: Whether this worker is the driver, i.e. the one
                that broadcasts inputs to the others in ``execute_model``.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        self.vision_language_config = vision_language_config
        if self.vision_language_config:
            assert not self.lora_config, (
                "To be tested: vision language model with LoRA settings.")

        self.model_runner = ModelRunner(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            lora_config=self.lora_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            vision_language_config=vision_language_config)
        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None
        self.cache_engine = None
        self.gpu_cache = None

    def init_device(self) -> None:
        """Bind this worker to its CUDA device and join the distributed group.

        Records the free GPU memory at this point in ``self.init_gpu_memory``,
        which ``profile_num_available_blocks`` uses as its baseline. Then
        initializes the distributed environment and sets the random seed.

        Raises:
            RuntimeError: If the configured device type is not ``cuda``.
            ValueError: If the model dtype is bfloat16 and the GPU's compute
                capability is below 8.0.
        """
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            # Release cached blocks first so the baseline reflects memory that
            # is really free on the device.
            torch.cuda.empty_cache()
            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_distributed_environment(self.parallel_config, self.rank,
                                     self.distributed_init_method,
                                     self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self) -> None:
        """Load the model weights through the ModelRunner."""
        self.model_runner.load_model()

    @torch.inference_mode()
    def profile_num_available_blocks(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        cpu_swap_space: int,
        cache_dtype: str,
    ) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model and returns the maximum
        number of GPU and CPU cache blocks that can be allocated.

        Requires ``init_device`` to have run, since it uses
        ``self.init_gpu_memory`` as the pre-model baseline.

        Args:
            block_size: The size of the cache block.
            gpu_memory_utilization: The fraction of the total GPU memory to use.
            cpu_swap_space: The size of the CPU swap space in bytes.
            cache_dtype: The KV cache data type, used to size one block.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``, each clamped to be >= 0.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        cache_block_size = self.get_cache_block_size_bytes(
            block_size, cache_dtype)
        num_gpu_blocks = int(
            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
            cache_block_size)
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        # Remove any LoRAs registered with the runner (presumably loaded by
        # profile_run) and release cached memory before returning.
        if self.model_runner.lora_manager:
            self.model_runner.remove_all_loras()
        gc.collect()
        torch.cuda.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def init_cache_engine(self, cache_config: CacheConfig) -> None:
        """Create the CacheEngine and hand its block size to the runner.

        Sets ``self.cache_config``, ``self.cache_engine`` and
        ``self.gpu_cache``.
        """
        self.cache_config = cache_config
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.gpu_cache = self.cache_engine.gpu_cache
        self.model_runner.set_block_size(self.cache_engine.block_size)

    def warm_up_model(self) -> None:
        """Capture the model (unless eager mode is enforced) and reset the seed.

        Must be called after ``init_cache_engine``, because capture uses
        ``self.gpu_cache``.
        """
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def cache_swap(
        self,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """Apply swap-in, swap-out and copy operations to the KV cache.

        Each operation is skipped when its mapping is empty. The mappings are
        presumably source block -> destination block(s) as produced by the
        scheduler; verify against CacheEngine.
        """
        # Issue cache operations.
        # TODO(woosuk): Profile swapping overhead and optimize if needed.
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
        blocks_to_swap_in: Optional[Dict[int, int]] = None,
        blocks_to_swap_out: Optional[Dict[int, int]] = None,
        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
    ) -> Optional[SamplerOutput]:
        """Run one model step, synchronising cache operations across workers.

        On the driver worker, all arguments must be given. The group count and
        block mappings are broadcast from rank 0. Non-driver workers receive
        them from that broadcast and leave ``seq_group_metadata_list`` as
        passed, normally ``None``.

        Returns:
            An empty list if there are no sequence groups. Otherwise, whatever
            ``ModelRunner.execute_model`` returns.
        """
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            num_seq_groups = len(seq_group_metadata_list)
            assert blocks_to_swap_in is not None
            assert blocks_to_swap_out is not None
            assert blocks_to_copy is not None
            data = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
            }
            broadcast_tensor_dict(data, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            num_seq_groups = data["num_seq_groups"]
            blocks_to_swap_in = data["blocks_to_swap_in"]
            blocks_to_swap_out = data["blocks_to_swap_out"]
            blocks_to_copy = data["blocks_to_copy"]

        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        # Return an empty list rather than ``{}``: a dict does not match the
        # declared ``Optional[SamplerOutput]`` return type, and an empty list
        # is equally empty/falsy for callers that iterate or test it.
        if num_seq_groups == 0:
            return []

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.gpu_cache)
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register a LoRA adapter with the ModelRunner."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter with ``lora_id`` from the ModelRunner."""
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids currently registered with the ModelRunner."""
        return self.model_runner.list_loras()

    @property
    def max_model_len(self) -> int:
        """Maximum model length from the model config."""
        return self.model_config.max_model_len

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the ModelRunner."""
        return self.model_runner.vocab_size

    def get_cache_block_size_bytes(self, block_size: int,
                                   cache_dtype: str) -> int:
        """Return the size, in bytes, of a single KV cache block.

        Delegates to ``CacheEngine.get_cache_block_size`` with this worker's
        model and parallel configs.
        """
        return CacheEngine.get_cache_block_size(block_size, cache_dtype,
                                                self.model_config,
                                                self.parallel_config)

Woosuk Kwon's avatar
Woosuk Kwon committed
250

251
def init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment.

    Performed in this order:
      1. torch.distributed: reuse an existing process group, whose world size
         must equal ``parallel_config.world_size``, or create an NCCL group
         from ``distributed_init_method``.
      2. pynccl: if already initialized, apply the same world-size check.
         Otherwise create a pynccl group, but only when the world size is
         greater than 1.
      3. Warm-up all-reduces on a one-element CUDA tensor (pynccl only if it
         is initialized).
      4. Model-parallel groups via ``ensure_model_parallel_initialized``.
      5. The custom all-reduce, unless ``disable_custom_all_reduce`` is set.

    Args:
        parallel_config: Supplies the world size, tensor/pipeline parallel
            sizes, and the custom all-reduce switch.
        rank: Global rank of this process.
        distributed_init_method: ``init_method`` for both process-group
            initializers. Required when torch.distributed is not already
            initialized.
        local_rank: Forwarded to ``pynccl_utils.init_process_group``.

    Raises:
        RuntimeError: An already-initialized torch.distributed or pynccl
            group has a world size different from
            ``parallel_config.world_size``.
        ValueError: torch.distributed is not initialized and no
            ``distributed_init_method`` was provided.
    """
    # --- torch.distributed -------------------------------------------------
    if not torch.distributed.is_initialized():
        if not distributed_init_method:
            raise ValueError(
                "distributed_init_method must be set if torch.distributed "
                "is not already initialized")
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=parallel_config.world_size,
            rank=rank,
            init_method=distributed_init_method,
        )
    else:
        existing_torch_size = torch.distributed.get_world_size()
        if existing_torch_size != parallel_config.world_size:
            raise RuntimeError(
                "torch.distributed is already initialized but the torch world "
                "size does not match parallel_config.world_size "
                f"({existing_torch_size} vs. {parallel_config.world_size}).")

    # --- pynccl ------------------------------------------------------------
    if not pynccl_utils.is_initialized():
        # NOTE(woosuk): We don't initialize pynccl process group when world
        # size is 1.
        # NOTE(review): if torch.distributed was already initialized,
        # distributed_init_method may be None here; confirm pynccl accepts it.
        if parallel_config.world_size > 1:
            pynccl_utils.init_process_group(
                world_size=parallel_config.world_size,
                local_rank=local_rank,
                rank=rank,
                init_method=distributed_init_method,
            )
    else:
        existing_pynccl_size = pynccl_utils.get_world_size()
        if existing_pynccl_size != parallel_config.world_size:
            raise RuntimeError(
                "pynccl is already initialized but the pynccl world "
                "size does not match parallel_config.world_size "
                f"({existing_pynccl_size} vs. {parallel_config.world_size}).")

    # --- warm-up: one tiny all-reduce per active backend -----------------
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    if pynccl_utils.is_initialized():
        pynccl_utils.all_reduce(torch.zeros(1).cuda())

    # --- model-parallel groups and optional custom all-reduce ------------
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)
    if not parallel_config.disable_custom_all_reduce:
        init_custom_ar()

305

306
307
308
309
310
311
312
313
314
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
Woosuk Kwon's avatar
Woosuk Kwon committed
315
316
317
                f"{compute_capability[0]}.{compute_capability[1]}. "
                "You can use float16 instead by explicitly setting the"
                "`dtype` flag in CLI, for example: --dtype=half.")