"""A GPU worker class."""
import os
from typing import Dict, List, Tuple, Optional

import torch
import torch.distributed

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.model_executor import get_model, InputMetadata, set_random_seed
from vllm.model_executor.parallel_utils.parallel_state import (
    initialize_model_parallel)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes


class Worker:
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        rank: Optional[int] = None,
        distributed_init_method: Optional[str] = None,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.rank = rank
        self.distributed_init_method = distributed_init_method

        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None
        self.block_size = None
        self.sliding_window = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None

    def init_model(self):
        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
        # Env vars will be set by Ray.
        self.rank = self.rank if self.rank is not None else int(
            os.getenv("RANK", "-1"))
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.device = torch.device(f"cuda:{local_rank}")
        if self.rank < 0:
            raise ValueError("Invalid or unspecified rank.")
        torch.cuda.set_device(self.device)

        _check_if_gpu_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment.
        _init_distributed_environment(self.parallel_config, self.rank,
                                      self.distributed_init_method)

        # Initialize the model.
        set_random_seed(self.model_config.seed)
        self.model = get_model(self.model_config)

    @torch.inference_mode()
    def profile_num_available_blocks(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        cpu_swap_space: int,
    ) -> Tuple[int, int]:
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.

        # Enable top-k sampling to reflect the accurate memory usage.
        vocab_size = self.model.config.vocab_size
        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs
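        # The dummy prompts below split max_num_batched_tokens as evenly as
        # possible across max_num_seqs sequences. Illustrative example: with
        # max_num_batched_tokens=2560 and max_num_seqs=256, every dummy prompt
        # gets 10 tokens; a nonzero remainder r would give the first r prompts
        # one extra token each.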
        seqs = []
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data = SequenceData([0] * seq_len)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
            )
            seqs.append(seq)

        input_tokens, input_positions, input_metadata = self._prepare_inputs(
            seqs)

        # Execute the model.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=[(None, None)] * num_layers,
            input_metadata=input_metadata,
            cache_events=None,
        )

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        peak_memory = torch.cuda.max_memory_allocated()
        total_gpu_memory = get_gpu_memory()
        cache_block_size = CacheEngine.get_cache_block_size(
            block_size, self.model_config, self.parallel_config)
        num_gpu_blocks = int(
            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
            cache_block_size)
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
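        # Illustrative arithmetic (made-up numbers): with 16 GiB of total GPU
        # memory, gpu_memory_utilization=0.9, a 6 GiB profiled peak, and a
        # 2 MiB cache block, num_gpu_blocks = (16 * 0.9 - 6) GiB / 2 MiB
        # ≈ 4300 blocks; a 4 GiB cpu_swap_space likewise yields 2048 CPU blocks.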
        torch.cuda.empty_cache()

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
        return num_gpu_blocks, num_cpu_blocks

    def init_cache_engine(self, cache_config: CacheConfig) -> None:
        self.cache_config = cache_config
        self.block_size = cache_config.block_size
        self.sliding_window = cache_config.sliding_window

        if self.sliding_window is None:
            max_seq_len = self.scheduler_config.max_model_len
        else:
            max_seq_len = min(self.scheduler_config.max_model_len,
                              self.sliding_window)
        _check_if_can_support_max_seq_len(max_seq_len, self.block_size)

        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache

    def _prepare_inputs(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        selected_token_indices: List[int] = []
        selected_token_start_idx = 0
        categorized_sample_indices = {t: [] for t in SamplingType}
        categorized_sample_indices_start_idx = 0
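        # Bookkeeping for sampling (rough sketch of the downstream contract):
        # selected_token_indices picks, out of the flattened model output, the
        # positions whose logits are actually needed, while
        # categorized_sample_indices groups the resulting sample positions by
        # SamplingType so the sampler can handle each category in one batch.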

        # Add prompt tokens.
        prompt_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            if not seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            # Use any sequence in the group.
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            if sampling_params.prompt_logprobs is not None:
                # NOTE: prompt positions are selected for prompt logprobs but
                # are not sampled from, so skip them here.
                categorized_sample_indices_start_idx += prompt_len - 1

            categorized_sample_indices[sampling_params.sampling_type].append(
                categorized_sample_indices_start_idx)
            categorized_sample_indices_start_idx += 1

            input_tokens.append(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.append(list(range(prompt_len)))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.append([0] * prompt_len)
                continue

            # Compute the slot mapping.
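            # Each token's slot is its physical position in the paged KV
            # cache: block_number * block_size + offset. Illustrative example
            # with block_size=16: token i=35 lives in block_table[35 // 16]
            # = block_table[2] at offset 35 % 16 = 3.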
            slot_mapping.append([])
            block_table = seq_group_metadata.block_tables[seq_id]
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)

        # Add generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
        context_lens: List[int] = []
        generation_block_tables: List[List[int]] = []
        max_seq_len = max(prompt_lens) if prompt_lens else 1
        for seq_group_metadata in seq_group_metadata_list:
            if seq_group_metadata.is_prompt:
                # We handle prompt groups in this loop too because the padded
                # prompt length (max_seq_len) is only known at this point.
                seq_ids = list(seq_group_metadata.seq_data.keys())
                assert len(
                    seq_ids) == 1, "Prompt input should have only one seq."
                prompt_len = len(
                    seq_group_metadata.seq_data[seq_ids[0]].get_token_ids())
                sampling_params = seq_group_metadata.sampling_params
                if sampling_params.prompt_logprobs is not None:
                    selected_token_indices.extend(
                        range(selected_token_start_idx,
                              selected_token_start_idx + prompt_len - 1))
                selected_token_indices.append(selected_token_start_idx +
                                              prompt_len - 1)
                selected_token_start_idx += max_seq_len
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            num_seqs = len(seq_ids)
            selected_token_indices.extend(
                range(selected_token_start_idx,
                      selected_token_start_idx + num_seqs))
            selected_token_start_idx += num_seqs

            categorized_sample_indices[sampling_params.sampling_type].extend(
                range(categorized_sample_indices_start_idx,
                      categorized_sample_indices_start_idx + num_seqs))
            categorized_sample_indices_start_idx += num_seqs

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                context_len = seq_data.get_len()
                position = context_len - 1
                if self.sliding_window is not None:
                    context_len = min(context_len, self.sliding_window)
                input_positions.append([position])

                block_table = seq_group_metadata.block_tables[seq_id]

                max_context_len = max(max_context_len, context_len)
                max_num_blocks_per_seq = max(max_num_blocks_per_seq,
                                             len(block_table))
                context_lens.append(context_len)

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

                if self.sliding_window is not None:
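                    # With sliding-window attention only the most recent
                    # window of tokens is attended to, so only the tail of the
                    # block table is kept (e.g. sliding_window=4096 and
                    # block_size=16 keep the last 256 blocks).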
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                generation_block_tables.append(block_table)

        padded_input_tokens = [
            _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
        ]
        padded_input_positions = [
            _pad_to_max(positions, max_seq_len, pad=0)
            for positions in input_positions
        ]
        padded_slot_mapping = [
            _pad_to_max(mapping, max_seq_len, pad=-1)
            for mapping in slot_mapping
        ]
        padded_block_tables = [
            _pad_to_max(block_table, max_num_blocks_per_seq, pad=0)
            for block_table in generation_block_tables
        ]

        # Convert to tensors.
        tokens_tensor = torch.tensor(padded_input_tokens,
                                     dtype=torch.long,
                                     device="cuda")
        positions_tensor = torch.tensor(padded_input_positions,
                                        dtype=torch.long,
                                        device="cuda")
        slot_mapping_tensor = torch.tensor(padded_slot_mapping,
                                           dtype=torch.int,
                                           device="cuda")
        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device="cuda")
        selected_token_indices = torch.tensor(selected_token_indices,
                                              dtype=torch.long,
                                              device="cuda")
        categorized_sample_indices = {
            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
            for t, seq_ids in categorized_sample_indices.items()
        }
        block_tables_tensor = torch.tensor(padded_block_tables,
                                           dtype=torch.int,
                                           device="cuda")

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        input_metadata = InputMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
            max_context_len=max_context_len,
            block_tables=block_tables_tensor,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            sliding_window=self.sliding_window,
        )
        return tokens_tensor, positions_tensor, input_metadata

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
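        # The block maps describe KV-cache movement for this step: the swap
        # dicts map a source block number to a destination block number
        # (CPU <-> GPU), and blocks_to_copy maps a source GPU block to the
        # list of GPU blocks it should be copied into.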
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True

        if issued_cache_op:
            cache_events = self.cache_events
        else:
            cache_events = None

        # If there is no input, we don't need to execute the model.
        if not seq_group_metadata_list:
            if cache_events is not None:
                for event in cache_events:
                    event.wait()
            return {}

        # Prepare input tensors.
        input_tokens, input_positions, input_metadata = self._prepare_inputs(
            seq_group_metadata_list)

        # Execute the model.
        output = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=self.gpu_cache,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
        return output


def _init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
) -> None:
    """Initialize the distributed environment."""
    if torch.distributed.is_initialized():
        torch_world_size = torch.distributed.get_world_size()
        if torch_world_size != parallel_config.world_size:
            raise RuntimeError(
                "torch.distributed is already initialized but the torch world "
                "size does not match parallel_config.world_size "
                f"({torch_world_size} vs. {parallel_config.world_size}).")
    elif not distributed_init_method:
        raise ValueError(
            "distributed_init_method must be set if torch.distributed "
            "is not already initialized")
    else:
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=parallel_config.world_size,
            rank=rank,
            init_method=distributed_init_method,
        )

    # A small all_reduce for warmup.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    initialize_model_parallel(parallel_config.tensor_parallel_size,
                              parallel_config.pipeline_parallel_size)


def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]:
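    # Pads to the next multiple of `multiple_of`, e.g.
    # _pad_to_alignment([1, 2, 3], multiple_of=8, pad=0)
    # -> [1, 2, 3, 0, 0, 0, 0, 0].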
    return x + [pad] * ((-len(x)) % multiple_of)


def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
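    # Pads to exactly `max_len` elements, e.g.
    # _pad_to_max([1, 2, 3], max_len=5, pad=-1) -> [1, 2, 3, -1, -1].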
    return x + [pad] * (max_len - len(x))


def _check_if_can_support_max_seq_len(max_seq_len: int,
                                      block_size: int) -> None:
    # Follows the logic in
    # attention_kernels.cu::single_query_cached_kv_attention_launcher
    max_shared_mem = get_max_shared_memory_bytes()
    float32_bytes = torch.finfo(torch.float).bits // 8
    padded_max_seq_len = (
        (max_seq_len + block_size - 1) // block_size) * block_size
    # padded_max_seq_len + extra buffer
    required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
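    # Illustrative check: with block_size=16 and max_seq_len=8192, the padded
    # length is 8192, so required_shared_mem = (8192 + 512) * 4 = 34,816 bytes.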
    if padded_max_seq_len * float32_bytes > max_shared_mem:
        raise RuntimeError(
            f"vLLM cannot currently support max_model_len={max_seq_len} "
            f"with block_size={block_size} on GPU with compute "
            f"capability {torch.cuda.get_device_capability()} "
            f"(required shared memory {required_shared_mem} > "
            f"available shared memory {max_shared_mem}). "
            "This will be fixed in a future release.")


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                f"{compute_capability[0]}.{compute_capability[1]}.")