worker.py 9.79 KB
Newer Older
1
"""A GPU worker class."""
2
import os
3
from typing import Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
4
5

import torch
6
import torch.distributed
Woosuk Kwon's avatar
Woosuk Kwon committed
7

Woosuk Kwon's avatar
Woosuk Kwon committed
8
9
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
10
from vllm.model_executor import set_random_seed
11
from vllm.model_executor.parallel_utils.communication_op import (
12
    broadcast_tensor_dict)
Woosuk Kwon's avatar
Woosuk Kwon committed
13
from vllm.model_executor.parallel_utils.parallel_state import (
Zhuohan Li's avatar
Zhuohan Li committed
14
    initialize_model_parallel)
15
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.worker.cache_engine import CacheEngine
17
from vllm.worker.model_runner import ModelRunner
Woosuk Kwon's avatar
Woosuk Kwon committed
18

19

Woosuk Kwon's avatar
Woosuk Kwon committed
20
class Worker:
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ) -> None:
        """Store the configuration and build the model runner.

        The CUDA device, the distributed environment and the KV cache are not
        set up here; that is done by ``init_model``, ``load_model`` and
        ``init_cache_engine``.

        Args:
            model_config: Model configuration.
            parallel_config: Parallelism configuration.
            scheduler_config: Scheduler configuration.
            local_rank: Index of the local CUDA device (``cuda:{local_rank}``).
            rank: Global rank of this worker in the distributed group.
            distributed_init_method: ``init_method`` handed to
                ``torch.distributed.init_process_group`` when the process
                group is not yet initialized.
            is_driver_worker: Whether this worker is the driver. The driver
                must have rank 0 because it broadcasts from ``src=0``.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        self.model_runner = ModelRunner(model_config, parallel_config,
                                        scheduler_config, is_driver_worker)
        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None

    def init_model(self) -> None:
        """Bind this worker to its GPU and set up the distributed environment.

        Raises:
            ValueError: If the GPU does not support the model dtype, or if
                ``torch.distributed`` is uninitialized and no
                ``distributed_init_method`` was given.
            RuntimeError: If ``torch.distributed`` is already initialized
                with a world size different from the parallel config.
        """
        # torch.distributed.all_reduce does not free the input tensor until
        # the synchronization point. This causes the memory usage to grow
        # as the number of all_reduce calls increases. This env var disables
        # this behavior.
        # Related issue:
        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
        self.device = torch.device(f"cuda:{self.local_rank}")
        torch.cuda.set_device(self.device)

        _check_if_gpu_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment.
        _init_distributed_environment(self.parallel_config, self.rank,
                                      self.distributed_init_method)

        # Seed the RNGs before the model is loaded (see load_model()).
        set_random_seed(self.model_config.seed)

    def load_model(self):
        """Load the model weights via the model runner."""
        self.model_runner.load_model()

    @torch.inference_mode()
    def profile_num_available_blocks(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        cpu_swap_space: int,
    ) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model and returns the maximum
        number of GPU and CPU cache blocks that can be allocated.

        Args:
            block_size: The size of the cache block.
            gpu_memory_utilization: The fraction of the total GPU memory to use.
            cpu_swap_space: The size of the CPU swap space in bytes.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``, each clamped to be >= 0.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory. mem_get_info reports device-wide numbers, so
        # memory held by other processes on this GPU is counted as used too.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        peak_memory = total_gpu_memory - free_gpu_memory

        cache_block_size = CacheEngine.get_cache_block_size(
            block_size, self.model_config, self.parallel_config)
        num_gpu_blocks = int(
            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
            cache_block_size)
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        torch.cuda.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def init_cache_engine(self, cache_config: CacheConfig) -> None:
        """Create the KV cache engine and hand its block size to the runner."""
        self.cache_config = cache_config
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache
        self.model_runner.set_block_size(self.cache_engine.block_size)

    def warm_up_model(self) -> None:
        """Capture the model (unless eager mode is enforced) and reseed RNGs.

        ``model_runner.capture_model`` is presumably CUDA-graph capture; it is
        skipped when ``model_config.enforce_eager`` is set.
        """
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def cache_swap(
        self,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """Issue KV-cache swap-in/swap-out/copy ops and wait for them.

        Each mapping is applied only when non-empty. The cache events are
        waited on only if at least one operation was issued.
        """
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True

        cache_events = self.cache_events if issued_cache_op else None

        # Wait for cache operations to finish.
        # TODO(woosuk): Profile swapping overhead and optimize if needed.
        if cache_events is not None:
            for event in cache_events:
                event.wait()

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
        blocks_to_swap_in: Optional[Dict[int, int]] = None,
        blocks_to_swap_out: Optional[Dict[int, int]] = None,
        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
    ) -> Optional[SamplerOutput]:
        """Run cache operations and, if there is any input, the model.

        The driver worker broadcasts the number of sequence groups and the
        cache-operation mappings from rank 0. Non-driver workers receive them
        from that broadcast; any swap/copy arguments they were passed are
        overwritten.

        Returns:
            The model runner's output, or an empty list when there are no
            sequence groups to execute.
        """
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            num_seq_groups = len(seq_group_metadata_list)
            assert blocks_to_swap_in is not None
            assert blocks_to_swap_out is not None
            assert blocks_to_copy is not None
            data = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
            }
            broadcast_tensor_dict(data, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            num_seq_groups = data["num_seq_groups"]
            blocks_to_swap_in = data["blocks_to_swap_in"]
            blocks_to_swap_out = data["blocks_to_swap_out"]
            blocks_to_copy = data["blocks_to_copy"]

        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            # An empty list (not `{}`) keeps the result consistent with the
            # declared SamplerOutput return type, which is a list of outputs.
            return []

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.gpu_cache)
        return output


203
204
205
def _init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
) -> None:
    """Initialize the distributed environment.

    If ``torch.distributed`` has no process group yet, an NCCL group is
    created from ``distributed_init_method``. Otherwise the existing group's
    world size is checked against ``parallel_config.world_size``. In both
    cases a one-element all-reduce is issued as a warm-up and model
    parallelism is initialized.

    Args:
        parallel_config: Supplies ``world_size``, ``tensor_parallel_size``
            and ``pipeline_parallel_size``.
        rank: Global rank of this process.
        distributed_init_method: ``init_method`` for
            ``torch.distributed.init_process_group``. It is required only when
            no process group exists yet.

    Raises:
        ValueError: If no process group exists and
            ``distributed_init_method`` is empty or ``None``.
        RuntimeError: If an existing process group has a different world size.
    """
    if not torch.distributed.is_initialized():
        if not distributed_init_method:
            raise ValueError(
                "distributed_init_method must be set if torch.distributed "
                "is not already initialized")
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=parallel_config.world_size,
            rank=rank,
            init_method=distributed_init_method,
        )
    else:
        existing_world_size = torch.distributed.get_world_size()
        if existing_world_size != parallel_config.world_size:
            raise RuntimeError(
                "torch.distributed is already initialized but the torch world "
                "size does not match parallel_config.world_size "
                f"({existing_world_size} vs. {parallel_config.world_size}).")

    # A small all_reduce for warmup (on the current CUDA device).
    warmup_tensor = torch.zeros(1).cuda()
    torch.distributed.all_reduce(warmup_tensor)

    initialize_model_parallel(parallel_config.tensor_parallel_size,
                              parallel_config.pipeline_parallel_size)


234
235
236
237
238
239
240
241
242
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
Woosuk Kwon's avatar
Woosuk Kwon committed
243
244
245
                f"{compute_capability[0]}.{compute_capability[1]}. "
                "You can use float16 instead by explicitly setting the"
                "`dtype` flag in CLI, for example: --dtype=half.")