# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""
import gc
import os
from contextlib import nullcontext
from typing import Dict, List, Optional, Set, Tuple, Type, Union

import torch
import torch.distributed

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.device_allocator.cumem import CuMemAllocator
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                           SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
                        memory_profiling)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.pooling_model_runner import PoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                     WorkerInput)

logger = init_logger(__name__)

38

39
class Worker(LocalOrDistributedWorkerBase):
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
        model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
    ) -> None:
        """Create the worker and its model runner.

        Args:
            vllm_config: Engine configuration forwarded to `WorkerBase` and
                to the model runner.
            local_rank: Index of the CUDA device used by `init_device`.
            rank: Rank of this worker; also written to
                `parallel_config.rank`.
            distributed_init_method: Passed to the distributed environment
                initialization in `init_device`.
            is_driver_worker: Forwarded to the model runner.
            model_runner_cls: Optional wrapper class; when given, it is
                called with the freshly built model runner and its result
                replaces `self.model_runner`.
        """
        WorkerBase.__init__(self, vllm_config)
        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Return hidden states from target model if the draft model is one
        # of the speculator types listed below and differs from the target
        # model type.
        speculative_config = self.speculative_config
        model_config = self.model_config
        speculative_args = {} if speculative_config is None \
            or (speculative_config.draft_model_config.hf_config.model_type ==
                model_config.hf_config.model_type) \
            or (speculative_config.draft_model_config.hf_config.model_type
                not in ("medusa",
                        "mlp_speculator",
                        "eagle",
                        "deepseek_mtp",
                        "mimo_mtp")) \
            else {"return_hidden_states": True}

        # Pick the runner implementation: pooling takes precedence over
        # encoder-decoder, and the plain ModelRunner is the fallback.
        ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
        if model_config.runner_type == "pooling":
            ModelRunnerClass = PoolingModelRunner
        elif self.model_config.is_encoder_decoder:
            ModelRunnerClass = EncoderDecoderModelRunner
        self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
            vllm_config=self.vllm_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
            **speculative_args,
        )
        if model_runner_cls is not None:
            self.model_runner = model_runner_cls(self.model_runner)

        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CacheEngine]
        # Initialize gpu_cache as pooling models don't initialize kv_caches
        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
        self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}

        # Buffers saved before sleep
        self._sleep_saved_buffers: Dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None

    def start_profile(self):
        """Start the torch profiler.

        Raises:
            RuntimeError: If profiling was not enabled at construction time.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self):
        """Stop the torch profiler and print its key-averages table.

        Raises:
            RuntimeError: If profiling was not enabled at construction time.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()
        print(
            self.profiler.key_averages().table(sort_by="self_cuda_time_total"))

    def sleep(self, level: int = 1) -> None:
        """Put the allocator to sleep and log how much memory was freed.

        With `level == 1` the allocator is asked to offload the "weights"
        tag; for any other level no tags are offloaded. With `level == 2`
        CPU copies of the model's named buffers are kept so that `wake_up`
        can restore them.
        """
        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

        # Save the buffers before level 2 sleep
        if level == 2:
            model = self.model_runner.model
            self._sleep_saved_buffers = {
                name: buffer.cpu().clone()
                for name, buffer in model.named_buffers()
            }

        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, "
            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake the allocator up and restore buffers saved by `sleep`.

        Args:
            tags: Forwarded unchanged to the allocator's `wake_up`.
        """
        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags=tags)

        # Restore the buffers after level 2 sleep
        if self._sleep_saved_buffers:
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

    def init_device(self) -> None:
        """Select the CUDA device, snapshot memory and set up distribution.

        Raises:
            RuntimeError: If the configured device type is not "cuda".
        """
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            # Baseline used later by determine_num_available_blocks.
            self.baseline_snapshot = MemorySnapshot()
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        """Load the model through the model runner.

        When sleep mode is enabled, loading happens inside the allocator's
        memory pool tagged "weights".
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CuMemAllocator.get_instance()
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be "
                "used for one instance per process.")
            context = allocator.use_memory_pool(tag="weights")
        else:
            context = nullcontext()
        with context:
            self.model_runner.load_model()

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Delegate to the model runner's `save_sharded_state`."""
        self.model_runner.save_sharded_state(
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Delegate to the model runner's `save_tensorized_model`."""
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config, )

    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            A `(num_gpu_blocks, num_cpu_blocks)` tuple; both values are
            clamped to be non-negative.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Only the device total is needed here; the free amount is unused.
        _, total_gpu_memory = torch.cuda.mem_get_info()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
                self.baseline_snapshot,
                weights_memory=self.model_runner.model_memory_usage) as result:
            self.model_runner.profile_run()

        self._assert_memory_footprint_increased_during_profiling()

        memory_for_current_instance = total_gpu_memory * \
            self.cache_config.gpu_memory_utilization
        available_kv_cache_memory = (memory_for_current_instance -
                                     result.non_kv_cache_memory)

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        cache_block_size = self.get_cache_block_size_bytes()
        if cache_block_size == 0:
            num_gpu_blocks = 0
            num_cpu_blocks = 0
        else:
            num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
            num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                                 cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)

        msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
               "the current vLLM instance can use "
               "total_gpu_memory "
               f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
               " x gpu_memory_utilization "
               f"({self.cache_config.gpu_memory_utilization:.2f})"
               f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
               "model weights take "
               f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
               " non_torch_memory takes "
               f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
               " PyTorch activation peak memory takes "
               f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
               " the rest of the memory reserved for KV Cache is "
               f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")

        logger.info(msg)
        # Final cleanup
        gc.collect()

        return num_gpu_blocks, num_cpu_blocks

    def _assert_memory_footprint_increased_during_profiling(self):
        """Assert that used CUDA memory exceeds the baseline snapshot."""
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        free_gpu_memory, total = torch.cuda.mem_get_info()
        cuda_memory = total - free_gpu_memory
        assert self.baseline_snapshot.cuda_memory < cuda_memory, (
            "Error in memory profiling. "
            f"Initial used memory {self.baseline_snapshot.cuda_memory}, "
            f"currently used memory {cuda_memory}. "
            f"This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Allocate GPU and CPU KV cache with the specified number of blocks.

        This also warms up the model, which may record CUDA graphs.
        """
        raise_if_cache_size_invalid(
            num_gpu_blocks, self.cache_config.block_size,
            self.cache_config.is_attention_free,
            self.model_config.max_model_len,
            self.parallel_config.pipeline_parallel_size)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        # With sleep mode, the cache is created inside the allocator's
        # memory pool tagged "kv_cache".
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CuMemAllocator.get_instance()
            context = allocator.use_memory_pool(tag="kv_cache")
        else:
            context = nullcontext()
        with context:
            self._init_cache_engine()
        self._warm_up_model()

    def _init_cache_engine(self):
        """Build one cache engine per pipeline stage and bind the KV cache."""
        assert self.cache_config.num_gpu_blocks is not None
        self.cache_engine = [
            CacheEngine(self.cache_config, self.model_config,
                        self.parallel_config, self.device_config)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.gpu_cache = [
            self.cache_engine[ve].gpu_cache
            for ve in range(self.parallel_config.pipeline_parallel_size)
        ]
        bind_kv_cache(self.compilation_config.static_forward_context,
                      self.gpu_cache)

    def _warm_up_model(self) -> None:
        """Run dummy batches for the compile sizes and capture the model."""
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x for x in warmup_sizes if x not in
                self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size)
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    @property
    def do_metadata_broadcast(self) -> bool:
        """True when the tensor-parallel size is greater than one."""
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """The GPU KV cache, or None if it has not been initialized."""
        return self.gpu_cache

    @property
    def cache_engines(self) -> Optional[List[CacheEngine]]:
        """The cache engines, or None before `initialize_cache` has run."""
        # `cache_engine` is only annotated in __init__, so it does not exist
        # as an attribute until _init_cache_engine assigns it.
        return getattr(self, "cache_engine", None)

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """Convert an execute-model request into a `WorkerInput`."""
        virtual_engine = execute_model_req.virtual_engine
        num_steps = execute_model_req.num_steps
        num_seq_groups = len(execute_model_req.seq_group_metadata_list)
        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
        # they contain parameters to launch cudamemcpyasync.
        blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
                                         device="cpu",
                                         dtype=torch.int64).view(-1, 2)
        blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
                                          device="cpu",
                                          dtype=torch.int64).view(-1, 2)
        # `blocks_to_copy` is a gpu tensor. The src and tgt of
        # blocks to copy are in the same device, and `blocks_to_copy`
        # can be used directly within cuda kernels.
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device=self.device,
                                      dtype=torch.int64).view(-1, 2)

        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
            num_steps=num_steps,
            kvcache_slot_to_be_moved=(
                execute_model_req.kvcache_slot_to_be_moved),
        )

    @torch.inference_mode()
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """Issue the cache operations described by `worker_input`."""
        virtual_engine = worker_input.virtual_engine
        engine = self.cache_engine[virtual_engine]
        # Issue cache operations.
        if (worker_input.blocks_to_swap_in is not None
                and worker_input.blocks_to_swap_in.numel() > 0):
            engine.swap_in(worker_input.blocks_to_swap_in)
        if (worker_input.blocks_to_swap_out is not None
                and worker_input.blocks_to_swap_out.numel() > 0):
            engine.swap_out(worker_input.blocks_to_swap_out)
        if (worker_input.blocks_to_copy is not None
                and worker_input.blocks_to_copy.numel() > 0):
            engine.copy(worker_input.blocks_to_copy)

        # Tree-style generation needs to move the KV cache to the correct
        # position.
        # NOTE(review): this indexes `self.kv_cache`, which is None until
        # the cache has been initialized - confirm callers guarantee that.
        if worker_input.kvcache_slot_to_be_moved is not None:
            engine.move_caches(self.kv_cache[virtual_engine],
                               worker_input.kvcache_slot_to_be_moved)

    def _get_cached_seq_group_metadata(
            self,
            seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                                SequenceGroupMetadataDelta]],
            finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
        """Return a list of cached Sequence Group Metadata after updating its
        state.

        It is used because scheduler only sends delta to workers to reduce
        the data payload size. The function also cleans up cache based on
        a given `finished_request_ids`.
        """
        new_seq_group_metadata_list = []
        for metadata_or_delta in seq_group_metadata_list:
            request_id = metadata_or_delta.request_id
            if request_id not in self._seq_group_metadata_cache:
                # The first prefill.
                assert isinstance(metadata_or_delta, SequenceGroupMetadata)
                self._seq_group_metadata_cache[request_id] = metadata_or_delta
            else:
                # The first prefill is already cached.
                if isinstance(metadata_or_delta, SequenceGroupMetadataDelta):
                    self._seq_group_metadata_cache[request_id].apply_delta(
                        metadata_or_delta)
                else:
                    # If metadata snapshot is sent again, it is
                    # preempted. Reset the cache because we need to start
                    # from scratch.
                    assert isinstance(metadata_or_delta, SequenceGroupMetadata)
                    self._seq_group_metadata_cache[
                        request_id] = metadata_or_delta

            new_seq_group_metadata_list.append(
                self._seq_group_metadata_cache[request_id])

        # Clean up finished ids
        for finished_id in finished_request_ids:
            del self._seq_group_metadata_cache[finished_id]

        return new_seq_group_metadata_list

    def _execute_model_spmd(
        self,
        execute_model_req: ExecuteModelRequest,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Expand cached metadata deltas, then defer to the base class."""
        if execute_model_req is not None:
            new_seq_group_metadata_list = self._get_cached_seq_group_metadata(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.finished_requests_ids)

            execute_model_req.seq_group_metadata_list = (
                new_seq_group_metadata_list)
        output = super()._execute_model_spmd(execute_model_req,
                                             intermediate_tensors)
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to the model runner's `add_lora`."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner's `remove_lora`."""
        return self.model_runner.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner's `pin_lora`."""
        return self.model_runner.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Delegate to the model runner's `list_loras`."""
        return self.model_runner.list_loras()

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Delegate to the model runner's `add_prompt_adapter`."""
        return self.model_runner.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Delegate to the model runner's `remove_prompt_adapter`."""
        # Previously this called `remove_lora`, which would have removed a
        # LoRA with the same numeric id instead of the prompt adapter.
        return self.model_runner.remove_prompt_adapter(prompt_adapter_id)

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Delegate to the model runner's `pin_prompt_adapter`."""
        return self.model_runner.pin_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> Set[int]:
        """Delegate to the model runner's `list_prompt_adapters`."""
        return self.model_runner.list_prompt_adapters()

    @property
    def max_model_len(self) -> int:
        """The `max_model_len` value from the model config."""
        return self.model_config.max_model_len

    @property
    def vocab_size(self) -> int:
        """The vocabulary size reported by the model runner."""
        return self.model_runner.vocab_size

    def get_cache_block_size_bytes(self) -> int:
        """Get the size of the KV cache block size in bytes.
        """
        return CacheEngine.get_cache_block_size(self.cache_config,
                                                self.model_config,
                                                self.parallel_config)


532
def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment.

    Args:
        vllm_config: Engine configuration; its `parallel_config` supplies
            the world size and the tensor/pipeline parallel sizes.
        rank: Rank of the calling worker.
        distributed_init_method: Forwarded to
            `init_distributed_environment`.
        local_rank: Forwarded to `init_distributed_environment`.
    """
    parallel_config = vllm_config.parallel_config
    # Configure custom all-reduce before the process groups are created.
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank)
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)

    ensure_kv_transfer_initialized(vllm_config)

549

550
551
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype) -> None:
    """Check that the current GPU supports `torch_dtype`.

    Only `torch.bfloat16` is checked: it requires the current platform to
    report device capability 80 or higher.

    Raises:
        ValueError: If bfloat16 is requested on a device that does not meet
            that capability.
    """
    if torch_dtype != torch.bfloat16:
        return
    if current_platform.has_device_capability(80):
        return

    capability = current_platform.get_device_capability()
    gpu_name = current_platform.get_device_name()

    if capability is None:
        compute_str = "does not have a compute capability"
    else:
        version_str = capability.as_version_str()
        compute_str = f"has compute capability {version_str}"

    raise ValueError(
        "Bfloat16 is only supported on GPUs with compute capability "
        f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
        "You can use float16 instead by explicitly setting the "
        "`dtype` flag in CLI, for example: --dtype=half.")
568
569


570
def raise_if_cache_size_invalid(num_gpu_blocks: int, block_size: int,
                                is_attention_free: bool, max_model_len: int,
                                pipeline_parallel_size: int) -> None:
    """Validate the number of GPU cache blocks against the model settings.

    Args:
        num_gpu_blocks: Number of GPU KV cache blocks to allocate.
        block_size: Number of tokens stored per block.
        is_attention_free: Whether the model is attention-free; such a model
            must be given exactly zero blocks.
        max_model_len: Maximum sequence length of the model.
        pipeline_parallel_size: The blocks are divided (floor division) by
            this value before computing the token capacity.

    Raises:
        ValueError: If an attention-free model is given blocks, if a model
            with attention is given no blocks, or if `max_model_len` exceeds
            the token capacity of the cache.
    """
    if is_attention_free:
        if num_gpu_blocks != 0:
            raise ValueError(
                "No memory should be allocated for the cache blocks "
                f"for an attention-free model, but {num_gpu_blocks} "
                "blocks are allocated.")
        # The remaining checks only apply to models that use attention.
        return

    if num_gpu_blocks <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")

    max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size)
    if max_model_len > max_seq_len:
        raise ValueError(
            f"The model's max seq len ({max_model_len}) "
            "is larger than the maximum number of tokens that can be "
            f"stored in KV cache ({max_seq_len}). Try increasing "
            "`gpu_memory_utilization` or decreasing `max_model_len` when "
            "initializing the engine.")