"""A GPU worker class."""
import os
from typing import Dict, List, Optional, Tuple

import torch
import torch.distributed

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.parallel_state import (
    initialize_model_parallel)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner


class Worker:
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        rank: Optional[int] = None,
        distributed_init_method: Optional[str] = None,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.rank = rank
        self.distributed_init_method = distributed_init_method

        self.model_runner = ModelRunner(model_config, parallel_config,
                                        scheduler_config)
        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None

    def init_model(self):
        # This env var, set by Ray, causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
        # Env vars will be set by Ray.
        self.rank = self.rank if self.rank is not None else int(
            os.getenv("RANK", "-1"))
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.device = torch.device(f"cuda:{local_rank}")
        if self.rank < 0:
            raise ValueError("Invalid or unspecified rank.")
        torch.cuda.set_device(self.device)

        _check_if_gpu_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment.
        _init_distributed_environment(self.parallel_config, self.rank,
                                      self.distributed_init_method)

        # Initialize the model.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        self.model_runner.load_model()

    @torch.inference_mode()
    def profile_num_available_blocks(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        cpu_swap_space: int,
    ) -> Tuple[int, int]:
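        """Profile the peak memory usage of a forward pass and return the
        number of GPU and CPU KV cache blocks that fit in the remaining
        memory budget, as a (num_gpu_blocks, num_cpu_blocks) tuple."""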
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        peak_memory = total_gpu_memory - free_gpu_memory

        cache_block_size = CacheEngine.get_cache_block_size(
            block_size, self.model_config, self.parallel_config)
        num_gpu_blocks = int(
            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
            cache_block_size)
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
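        # Illustrative arithmetic (hypothetical numbers): on an 80 GiB GPU with
        # gpu_memory_utilization=0.9 and a profiled peak of 60 GiB, the cache
        # budget is 0.9 * 80 - 60 = 12 GiB; with a ~2 MiB cache_block_size that
        # yields roughly 6k GPU blocks.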
        torch.cuda.empty_cache()

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
        return num_gpu_blocks, num_cpu_blocks

    def init_cache_engine(self, cache_config: CacheConfig) -> None:
        self.cache_config = cache_config
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache
        self.model_runner.set_block_size(self.cache_engine.block_size)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
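        """Execute one model step, issuing any pending cache operations first.

        The block-number mappings are produced by the scheduler:
        blocks_to_swap_in maps CPU blocks to GPU blocks, blocks_to_swap_out
        maps GPU blocks to CPU blocks, and blocks_to_copy maps a source GPU
        block to the destination GPU blocks it should be copied into.
        """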
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True

        cache_events = self.cache_events if issued_cache_op else None

        # If there is no input, we don't need to execute the model.
        if not seq_group_metadata_list:
            if cache_events is not None:
                for event in cache_events:
                    event.wait()
            return {}

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.gpu_cache, cache_events)
        return output


def _init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
) -> None:
    """Initialize the distributed environment."""
    if torch.distributed.is_initialized():
        torch_world_size = torch.distributed.get_world_size()
        if torch_world_size != parallel_config.world_size:
            raise RuntimeError(
                "torch.distributed is already initialized but the torch world "
                "size does not match parallel_config.world_size "
                f"({torch_world_size} vs. {parallel_config.world_size}).")
    elif not distributed_init_method:
        raise ValueError(
            "distributed_init_method must be set if torch.distributed "
            "is not already initialized")
    else:
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=parallel_config.world_size,
            rank=rank,
            init_method=distributed_init_method,
        )

    # A small all_reduce for warmup.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    initialize_model_parallel(parallel_config.tensor_parallel_size,
                              parallel_config.pipeline_parallel_size)


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                f"{compute_capability[0]}.{compute_capability[1]}.")