"""A GPU worker class."""
import gc
import os
from typing import Dict, List, Tuple, Set, Optional

import torch
import torch.distributed

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, LoRAConfig)
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
from vllm.model_executor.parallel_utils.parallel_state import (
    ensure_model_parallel_initialized)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.lora.request import LoRARequest


class Worker:
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
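
    A typical call order, as driven by the engine (a sketch of the intended
    flow, not an exact transcript of the engine code):

        worker.init_model()
        worker.load_model()
        num_gpu_blocks, num_cpu_blocks = worker.profile_num_available_blocks(
            block_size, gpu_memory_utilization, cpu_swap_space)
        worker.init_cache_engine(cache_config)
        worker.warm_up_model()
        output = worker.execute_model(seq_group_metadata_list,
                                      blocks_to_swap_in, blocks_to_swap_out,
                                      blocks_to_copy)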
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        is_driver_worker: bool = False,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        self.model_runner = ModelRunner(model_config,
                                        parallel_config,
                                        scheduler_config,
                                        lora_config=self.lora_config,
                                        is_driver_worker=is_driver_worker)
        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None

    def init_model(self) -> None:
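        """Initialize the device, the distributed environment, and the random
        seed for this worker."""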
        # torch.distributed.all_reduce does not free the input tensor until
        # the synchronization point. This causes the memory usage to grow
        # as the number of all_reduce calls increases. This env var disables
        # this behavior.
        # Related issue:
        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

        # This env var, which is set by Ray, causes exceptions during graph
        # building, so remove it if present.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
        self.device = torch.device(f"cuda:{self.local_rank}")
        torch.cuda.set_device(self.device)

        _check_if_gpu_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment.
        init_distributed_environment(self.parallel_config, self.rank,
                                     self.distributed_init_method)
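        # Initialize the custom (faster) all-reduce implementation, unless it
        # is disabled in the parallel config.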
        if not self.parallel_config.disable_custom_all_reduce:
            init_custom_ar()
        # Initialize the model.
        set_random_seed(self.model_config.seed)

    def load_model(self):
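        """Load the model weights via the model runner."""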
        self.model_runner.load_model()

    @torch.inference_mode()
    def profile_num_available_blocks(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        cpu_swap_space: int,
    ) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model and returns the maximum
        number of GPU and CPU cache blocks that can be allocated.

        Args:
            block_size: The size of the cache block.
            gpu_memory_utilization: The fraction of the total GPU memory to use.
            cpu_swap_space: The size of the CPU swap space in bytes.
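
        Returns:
            A tuple (num_gpu_blocks, num_cpu_blocks): the number of KV cache
            blocks that fit in the GPU memory budget and in the CPU swap
            space, respectively.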
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        peak_memory = total_gpu_memory - free_gpu_memory

        cache_block_size = CacheEngine.get_cache_block_size(
            block_size, self.model_config, self.parallel_config)
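        # For example (hypothetical numbers): with 80 GiB of total GPU memory,
        # gpu_memory_utilization=0.9, and a 60 GiB profiled peak, the KV cache
        # budget is 80 * 0.9 - 60 = 12 GiB, which is then divided into blocks
        # of cache_block_size bytes.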
        num_gpu_blocks = int(
            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
            cache_block_size)
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        if self.model_runner.lora_manager:
            self.model_runner.remove_all_loras()
        gc.collect()
        torch.cuda.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def init_cache_engine(self, cache_config: CacheConfig) -> None:
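        """Allocate the GPU and CPU KV caches and propagate the block size to
        the model runner."""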
        self.cache_config = cache_config
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache
        self.model_runner.set_block_size(self.cache_engine.block_size)

    def warm_up_model(self) -> None:
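        """Capture CUDA graphs for the model, unless eager execution is
        enforced."""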
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def cache_swap(
        self,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
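        """Issue the requested swap-in, swap-out, and copy operations on the
        cache engine and wait for them to complete."""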
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True

        cache_events = self.cache_events if issued_cache_op else None

        # Wait for cache operations to finish.
        # TODO(woosuk): Profile swapping overhead and optimize if needed.
        if cache_events is not None:
            for event in cache_events:
                event.wait()

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
        blocks_to_swap_in: Optional[Dict[int, int]] = None,
        blocks_to_swap_out: Optional[Dict[int, int]] = None,
        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
    ) -> Optional[SamplerOutput]:
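        """Execute one model step.

        The driver worker (rank 0) broadcasts the number of sequence groups
        and the cache operation metadata to the other workers; every worker
        then performs the cache operations and runs the model.
        """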
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            num_seq_groups = len(seq_group_metadata_list)
            assert blocks_to_swap_in is not None
            assert blocks_to_swap_out is not None
            assert blocks_to_copy is not None
            data = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
            }
            broadcast_tensor_dict(data, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            num_seq_groups = data["num_seq_groups"]
            blocks_to_swap_in = data["blocks_to_swap_in"]
            blocks_to_swap_out = data["blocks_to_swap_out"]
            blocks_to_copy = data["blocks_to_copy"]

        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return None

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.gpu_cache)
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.model_runner.list_loras()


def init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
) -> None:
    """Initialize the distributed environment."""
    if torch.distributed.is_initialized():
        torch_world_size = torch.distributed.get_world_size()
        if torch_world_size != parallel_config.world_size:
            raise RuntimeError(
                "torch.distributed is already initialized but the torch world "
                "size does not match parallel_config.world_size "
                f"({torch_world_size} vs. {parallel_config.world_size}).")
    elif not distributed_init_method:
        raise ValueError(
            "distributed_init_method must be set if torch.distributed "
            "is not already initialized")
    else:
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=parallel_config.world_size,
            rank=rank,
            init_method=distributed_init_method,
        )

    # A small all_reduce for warmup.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                f"{compute_capability[0]}.{compute_capability[1]}. "
                "You can use float16 instead by explicitly setting the"
                "`dtype` flag in CLI, for example: --dtype=half.")