"""A GPU worker class."""
import os
from typing import Dict, List, Tuple, Optional

import torch
import torch.distributed

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.model_executor import get_model, InputMetadata, set_random_seed
from vllm.model_executor.parallel_utils.parallel_state import (
    initialize_model_parallel)
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes


class Worker:
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        rank: Optional[int] = None,
        distributed_init_method: Optional[str] = None,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.rank = rank
        self.distributed_init_method = distributed_init_method

        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None
        self.block_size = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None

    def init_model(self):
        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
        # Env vars will be set by Ray.
        self.rank = self.rank if self.rank is not None else int(
            os.getenv("RANK", "-1"))
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.device = torch.device(f"cuda:{local_rank}")
        if self.rank < 0:
            raise ValueError("Invalid or unspecified rank.")
        torch.cuda.set_device(self.device)

        # Initialize the distributed environment.
        _init_distributed_environment(self.parallel_config, self.rank,
                                      self.distributed_init_method)

        # Initialize the model.
        set_random_seed(self.model_config.seed)
        self.model = get_model(self.model_config)

    @torch.inference_mode()
    def profile_num_available_blocks(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        cpu_swap_space: int,
    ) -> Tuple[int, int]:
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.

        # Enable top-k sampling to reflect the accurate memory usage.
        vocab_size = self.model.config.vocab_size
        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs
        seqs = []
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
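            # For example (illustrative numbers), with
            # max_num_batched_tokens=2560 and max_num_seqs=256 every dummy
            # sequence gets 10 tokens; when the division is uneven, the first
            # (max_num_batched_tokens % max_num_seqs) sequences get one extra
            # token so the lengths still sum to max_num_batched_tokens.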
            seq_data = SequenceData([0] * seq_len)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
            )
            seqs.append(seq)

        input_tokens, input_positions, input_metadata = self._prepare_inputs(
            seqs)

        # Execute the model.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
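        # The (None, None) entries stand in for the per-layer KV caches,
        # which do not exist yet during profiling; the attention layers skip
        # the cache writes when no cache is given, so this forward pass only
        # measures weight and activation memory.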
        self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=[(None, None)] * num_layers,
            input_metadata=input_metadata,
            cache_events=None,
        )

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        peak_memory = torch.cuda.max_memory_allocated()
        total_gpu_memory = get_gpu_memory()
        cache_block_size = CacheEngine.get_cache_block_size(
            block_size, self.model_config, self.parallel_config)
        num_gpu_blocks = int(
            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
            cache_block_size)
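        # Illustrative arithmetic: on a 40 GiB GPU with
        # gpu_memory_utilization=0.9, a profiled peak of 26 GiB and a 2 MiB
        # cache block size give (40 * 0.9 - 26) GiB / 2 MiB = 5120 GPU blocks.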
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        torch.cuda.empty_cache()

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
        return num_gpu_blocks, num_cpu_blocks

    def init_cache_engine(self, cache_config: CacheConfig) -> None:
        self.cache_config = cache_config
        self.block_size = cache_config.block_size

        _check_if_can_support_max_seq_len(self.scheduler_config.max_model_len,
                                          self.block_size)

        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache

    def _prepare_inputs(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []

        # Add prompt tokens.
        prompt_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            if not seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            # Use any sequence in the group.
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(len(prompt_tokens)))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.extend([0] * prompt_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
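            # Example: with block_size=16, prompt position 35 maps to
            # block_table[35 // 16] == block_table[2], offset 35 % 16 == 3,
            # i.e. slot = block_table[2] * 16 + 3.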
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Add generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
        context_lens: List[int] = []
        generation_block_tables: List[List[int]] = []
        for seq_group_metadata in seq_group_metadata_list:
            if seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                context_len = seq_data.get_len()
                position = context_len - 1
                input_positions.append(position)

                block_table = seq_group_metadata.block_tables[seq_id]
                generation_block_tables.append(block_table)

                max_context_len = max(max_context_len, context_len)
                max_num_blocks_per_seq = max(max_num_blocks_per_seq,
                                             len(block_table))
                context_lens.append(context_len)

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)
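                # Only the newly generated position needs a slot here; the
                # earlier positions of the sequence are already resident in
                # the paged KV cache.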

        # Optimization: Pad the input length to be a multiple of 8.
        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
        input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
        input_positions = _pad_to_alignment(input_positions, multiple_of=8)
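        # e.g. 13 tokens are padded to 16. The pad entries (token id 0 at
        # position 0) are not referenced by any sequence group, so they
        # should not affect the results.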

        # Convert to tensors.
        tokens_tensor = torch.tensor(input_tokens,
                                     dtype=torch.long,
                                     device="cuda")
        positions_tensor = torch.tensor(input_positions,
                                        dtype=torch.long,
                                        device="cuda")
        slot_mapping_tensor = torch.tensor(slot_mapping,
                                           dtype=torch.int,
                                           device="cuda")
        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device="cuda")
        padded_block_tables = [
            _pad_to_max(block_table, max_num_blocks_per_seq)
            for block_table in generation_block_tables
        ]
        block_tables_tensor = torch.tensor(padded_block_tables,
                                           dtype=torch.int,
                                           device="cuda")

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        input_metadata = InputMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
            max_context_len=max_context_len,
            block_tables=block_tables_tensor,
        )
        return tokens_tensor, positions_tensor, input_metadata

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
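        # blocks_to_swap_in maps CPU block number -> GPU block number,
        # blocks_to_swap_out maps GPU block number -> CPU block number, and
        # blocks_to_copy maps a source GPU block to the destination GPU
        # blocks it should be copied into, as produced by the scheduler.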
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True

        if issued_cache_op:
            cache_events = self.cache_events
        else:
            cache_events = None
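        # The cache engine records one CUDA event per layer; passing them to
        # the model lets each layer wait for its own swap/copy to finish, so
        # cache movement overlaps with the computation of earlier layers.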

        # If there is no input, we don't need to execute the model.
        if not seq_group_metadata_list:
            if cache_events is not None:
                for event in cache_events:
                    event.wait()
            return {}

        # Prepare input tensors.
        input_tokens, input_positions, input_metadata = self._prepare_inputs(
            seq_group_metadata_list)

        # Execute the model.
        output = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=self.gpu_cache,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
        return output


def _init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
) -> None:
    """Initialize the distributed environment."""
    if torch.distributed.is_initialized():
        torch_world_size = torch.distributed.get_world_size()
        if torch_world_size != parallel_config.world_size:
            raise RuntimeError(
                "torch.distributed is already initialized but the torch world "
                "size does not match parallel_config.world_size "
                f"({torch_world_size} vs. {parallel_config.world_size}).")
    elif not distributed_init_method:
        raise ValueError(
            "distributed_init_method must be set if torch.distributed "
            "is not already initialized")
    else:
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=parallel_config.world_size,
            rank=rank,
            init_method=distributed_init_method,
        )

    # A small all_reduce for warmup.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    initialize_model_parallel(parallel_config.tensor_parallel_size,
                              parallel_config.pipeline_parallel_size)


def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
    """Pad x with zeros on the right so its length is a multiple of
    multiple_of."""
    return x + [0] * ((-len(x)) % multiple_of)


def _pad_to_max(x: List[int], max_len: int) -> List[int]:
    """Pad x with zeros on the right up to length max_len."""
    return x + [0] * (max_len - len(x))


def _check_if_can_support_max_seq_len(max_seq_len: int,
                                      block_size: int) -> None:
    # Follows the logic in
    # attention_kernels.cu::single_query_cached_kv_attention_launcher
    max_shared_mem = get_max_shared_memory_bytes()
    float32_bytes = torch.finfo(torch.float).bits // 8
    padded_max_seq_len = (
        (max_seq_len + block_size - 1) // block_size) * block_size
    # padded_max_seq_len + extra buffer
    required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
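    # Illustrative arithmetic: with block_size=16 and max_seq_len=2048, the
    # padded length is 2048 and the kernel needs about
    # (2048 + 512) * 4 = 10240 bytes of shared memory.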
    if padded_max_seq_len * float32_bytes > max_shared_mem:
        raise RuntimeError(
            f"vLLM cannot currently support max_model_len={max_seq_len} "
            f"with block_size={block_size} on GPU with compute "
            f"capability {torch.cuda.get_device_capability()} "
            f"(required shared memory {required_shared_mem} > "
            f"available shared memory {max_shared_mem}). "
            "This will be fixed in a future release.")