# openvino_worker.py
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
6
7
"""An OpenVINO worker class."""
from typing import Any, Dict, List, Optional, Tuple

import openvino as ov
import torch
import torch.distributed
8
import torch.nn as nn
9

10
import vllm.envs as envs
11
from vllm.attention import get_attn_backend
12
13
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, VllmConfig)
14
15
16
from vllm.distributed import (broadcast_tensor_dict,
                              ensure_model_parallel_initialized,
                              init_distributed_environment)
17
from vllm.inputs import INPUT_REGISTRY
18
19
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
20
from vllm.model_executor.layers.sampler import SamplerOutput
21
from vllm.multimodal import MULTIMODAL_REGISTRY
22
from vllm.platforms import current_platform
23
24
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
25
from vllm.utils import bind_kv_cache
26
from vllm.worker.openvino_model_runner import OpenVINOModelRunner
27
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45

# Logger for this module, created via vllm.logger.init_logger.
logger = init_logger(__name__)


class OpenVINOCacheEngine:
    """Manages the KV cache for the OpenVINO backend.

    Allocates one (key, value) tensor pair per layer. When the OpenVINO
    target is the CPU, the device cache consists of host ``ov.Tensor``
    objects and no swap cache is allowed. Otherwise the device cache is
    allocated through the target's default remote context, and an optional
    host-side swap cache is allocated as well. The class also provides the
    block-level swap and copy operations used by the worker.
    """

    # u8 caches store the scale and zero point together with the quantized
    # data. The layout for per token per head:
    # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501
    # so head_size has to be extended by 8: sizeof(float) for the scale
    # plus sizeof(float) for the zero point.
    _U8_SCALE_ZP_BYTES = 8

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        device_config: DeviceConfig,
        ov_core: ov.Core,
        ov_device: str,
    ) -> None:
        """Allocate the device KV cache and, if configured, the swap cache.

        Args:
            cache_config: Supplies the block size, the cache dtype, the number
                of device blocks (``num_gpu_blocks``) and the number of swap
                blocks (``num_cpu_blocks``).
            model_config: Supplies the head size, layer count and KV-head
                count.
            parallel_config: Used to compute the per-rank layer and KV-head
                counts.
            device_config: Must have ``device_type == "openvino"``.
            ov_core: OpenVINO core used to obtain the remote context for
                non-CPU targets.
            ov_device: Name of the OpenVINO device that holds the cache.
        """
        assert device_config.device_type == "openvino"
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        if device_config.device.type == "cpu" and \
                cache_config.cache_dtype == ov.Type.u8:
            # Make room for the per-head scale and zero point
            # (see _U8_SCALE_ZP_BYTES).
            # NOTE(review): get_cache_block_size() pads u8 caches without
            # checking the device type; keep the two conditions in sync.
            self.head_size += self._U8_SCALE_ZP_BYTES
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        # Note: in CacheConfig, num_gpu_blocks actually holds the number of
        # CPU blocks for the OpenVINO backend with a CPU target device,
        # because we want to reuse KV cache management in the scheduler.
        self.num_device_blocks = cache_config.num_gpu_blocks
        self.num_swap_blocks = cache_config.num_cpu_blocks

        # Get attention backend. Note that the (possibly u8-padded) head
        # size is the one passed here.
        self.attn_backend = get_attn_backend(
            self.head_size,
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        )

        # Initialize the cache.
        self.kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = \
            self._allocate_kv_cache(self.num_device_blocks, ov_core,
                                    ov_device)

        # Initialize the swap.
        self.swap_cache: List[Tuple[ov.Tensor, ov.Tensor]] = \
            self._allocate_swap_cache(self.num_swap_blocks, ov_device)

    @staticmethod
    def _transposed_key_shape(value_block_shape) -> Tuple:
        """Return the key-cache shape used on non-CPU targets.

        It is the value-block shape with its last two dimensions swapped.
        """
        return (value_block_shape[0], value_block_shape[1],
                value_block_shape[3], value_block_shape[2])

    def _allocate_kv_cache(
        self,
        num_blocks: int,
        ov_core: ov.Core,
        ov_device: str,
    ) -> List[Tuple[ov.Tensor, ov.Tensor]]:
        """Allocates KV cache.

        Returns one (key, value) tensor pair per layer, each holding
        ``num_blocks`` blocks. On a CPU target, both tensors are host
        ``ov.Tensor`` objects with the backend's block shape. Otherwise they
        are created via the default context of ``ov_device``, and the key
        tensor uses the transposed layout from ``_transposed_key_shape``.
        """
        v_block_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:]
        cache_dtype = self.cache_config.cache_dtype

        if current_platform.is_openvino_cpu():
            return [(ov.Tensor(cache_dtype, v_block_shape),
                     ov.Tensor(cache_dtype, v_block_shape))
                    for _ in range(self.num_layers)]

        k_block_shape = self._transposed_key_shape(v_block_shape)
        remote_context = ov_core.get_default_context(ov_device)

        kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = []
        for _ in range(self.num_layers):
            key_blocks = remote_context.create_tensor(
                cache_dtype, ov.Shape(k_block_shape), {})
            value_blocks = remote_context.create_tensor(
                cache_dtype, ov.Shape(v_block_shape), {})
            kv_cache.append((key_blocks, value_blocks))
        return kv_cache

    def _allocate_swap_cache(
        self,
        num_blocks: int,
        ov_device: str,
    ) -> List[Tuple[ov.Tensor, ov.Tensor]]:
        """Allocates swap cache.

        Returns an empty list when ``num_blocks`` is 0. Otherwise it returns
        one host (key, value) ``ov.Tensor`` pair per layer, with the key
        tensor in the transposed layout. A non-empty swap cache is rejected
        on CPU targets.
        """
        if num_blocks == 0:
            return []

        assert not current_platform.is_openvino_cpu(), \
            "CPU device isn't supposed to have swap cache"

        v_block_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:]
        k_block_shape = self._transposed_key_shape(v_block_shape)
        cache_dtype = self.cache_config.cache_dtype

        return [(ov.Tensor(cache_dtype, k_block_shape),
                 ov.Tensor(cache_dtype, v_block_shape))
                for _ in range(self.num_layers)]

    def swap_in(self, src_to_dst: List[Tuple[int, int]]) -> None:
        """Copy blocks from the swap cache into the device cache."""
        for i in range(self.num_layers):
            for swap_tensor, kv_tensor in zip(self.swap_cache[i],
                                              self.kv_cache[i]):
                self.attn_backend.swap_blocks(swap_tensor, kv_tensor,
                                              src_to_dst)

    def swap_out(self, src_to_dst: List[Tuple[int, int]]) -> None:
        """Copy blocks from the device cache into the swap cache."""
        for i in range(self.num_layers):
            for swap_tensor, kv_tensor in zip(self.swap_cache[i],
                                              self.kv_cache[i]):
                self.attn_backend.swap_blocks(kv_tensor, swap_tensor,
                                              src_to_dst)

    def copy(self, src_to_dsts: List[Tuple[int, int]]) -> None:
        """Copy blocks within the device cache; no-op for an empty list."""
        if src_to_dsts:
            self.attn_backend.copy_blocks(self.kv_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        cache_dtype: ov.Type,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size in bytes of one KV cache block across all layers.

        For u8 caches, each head is padded by ``_U8_SCALE_ZP_BYTES`` to hold
        its scale and zero point.
        """
        head_size = model_config.get_head_size()
        num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        if cache_dtype == ov.Type.u8:
            head_size += OpenVINOCacheEngine._U8_SCALE_ZP_BYTES

        key_cache_block = block_size * num_kv_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype_size = cache_dtype.size
        return dtype_size * total


206
class OpenVINOWorker(LoRANotSupportedWorkerBase):
    """A worker class that executes the model on OpenVINO backend.

    Each worker is associated with a single OpenVINO device. The worker is
    responsible for maintaining the KV cache and executing the model on the
    OpenVINO backend.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ) -> None:
        WorkerBase.__init__(self, vllm_config)
        self.ov_core = ov.Core()
        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules

            init_cached_hf_modules()
        self.model_runner = OpenVINOModelRunner(
            self.ov_core,
            vllm_config=self.vllm_config,
            kv_cache_dtype=self.vllm_config.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: OpenVINOCacheEngine
        self.kv_cache: List[Tuple[ov.Tensor, ov.Tensor]]

    def init_device(self) -> None:
        """Set up the gloo process group and seed the RNGs."""
        self.init_distributed_environment()
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of blocks available for the KV cache.

        This determines how many KV blocks can fit into the configured
        KV cache space.

        Returns:
            ``(num_device_blocks, num_swap_blocks)``. On a CPU target, the
            device blocks are derived from ``openvino_kvcache_space_bytes``
            and there are no swap blocks. On other targets, an explicit
            ``openvino_kvcache_space_bytes`` takes precedence; otherwise a
            profiling run estimates the free device memory.
        """
        # For OpenVINO backend, in case of CPU device, the block number will be
        # calculated based on the openvino_kvcache_space_bytes.
        cache_block_size = self.get_cache_block_size_bytes()
        kvcache_space_bytes = self.cache_config.openvino_kvcache_space_bytes

        if current_platform.is_openvino_cpu():
            num_device_blocks = int(kvcache_space_bytes // cache_block_size)
            num_swap_blocks = 0
        else:
            if kvcache_space_bytes > 0:
                logger.info("KV_CACHE size was explicitly configured via "
                            "VLLM_OPENVINO_KVCACHE_SPACE environment "
                            "variable, ignoring profiling run.")
                kv_cache_size = kvcache_space_bytes
            else:
                try:
                    kv_cache_size = self.profile_run()
                except Exception as err:
                    raise RuntimeError(
                        "The error occurred during profile run. This might be "
                        "due to insufficient GPU memory. Consider decreasing "
                        "`max_model_len` to limit the maximum simultaneously "
                        "processed tokens.") from err

            num_device_blocks = int(kv_cache_size // cache_block_size)
            num_swap_blocks = int(self.cache_config.swap_space_bytes //
                                  cache_block_size)

        return num_device_blocks, num_swap_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache. Swappable CPU memory is only
        supported on GPU.

        For CPU, we use the num_gpu_blocks to
        determine how many non-swappable CPU blocks to allocate.
        """
        num_device_blocks = num_gpu_blocks
        num_swap_blocks = num_cpu_blocks

        if current_platform.is_openvino_cpu():
            assert (num_swap_blocks == 0
                    ), f"{type(self)} does not support swappable cache for CPU"

        self._validate_num_blocks(num_device_blocks)
        self.cache_config.num_gpu_blocks = num_device_blocks
        self.cache_config.num_cpu_blocks = num_swap_blocks

        # Initialize the cache.
        self._init_cache_engine()

    def _validate_num_blocks(self, num_blocks: int) -> None:
        """Raise errors if the num_blocks is invalid.

        Raises:
            ValueError: if there are no blocks, or if the blocks cannot hold
                ``max_model_len`` tokens.
        """
        if num_blocks <= 0:
            raise ValueError(
                "No available memory for the cache blocks. "
                "Try increasing `VLLM_OPENVINO_KVCACHE_SPACE` when "
                "initializing the engine.")

        max_seq_len = self.cache_config.block_size * num_blocks
        if self.model_config.max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({self.model_config.max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`VLLM_OPENVINO_KVCACHE_SPACE` or decreasing `max_model_len` "
                "when initializing the engine.")

    def _init_cache_engine(self) -> None:
        """Build the cache engine and bind its KV cache to the model."""
        ov_device = envs.VLLM_OPENVINO_DEVICE
        self.cache_engine = OpenVINOCacheEngine(
            self.cache_config,
            self.model_config,
            self.parallel_config,
            self.device_config,
            self.ov_core,
            ov_device,
        )
        self.kv_cache = self.cache_engine.kv_cache
        # Outer list = one entry per virtual engine (a single one here).
        bind_kv_cache(self.compilation_config.static_forward_context,
                      [self.kv_cache])
        self.model_runner.block_size = self.cache_engine.block_size

        assert self.kv_cache is not None

        # Populate the cache to warmup the memory
        if current_platform.is_openvino_cpu():
            for key_cache, value_cache in self.kv_cache:
                key_cache.data[:] = 0
                value_cache.data[:] = 0

    def cache_swap_in(self, src_to_dst: List[Tuple[int, int]]) -> None:
        self.cache_engine.swap_in(src_to_dst)

    def cache_swap_out(self, src_to_dst: List[Tuple[int, int]]) -> None:
        self.cache_engine.swap_out(src_to_dst)

    def cache_copy(
        self,
        blocks_to_copy: List[Tuple[int, int]],
    ) -> None:
        self.cache_engine.copy(blocks_to_copy)  # type: ignore

    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:
        """Apply pending cache operations and run one model step.

        The driver worker (rank 0) broadcasts the number of sequence groups
        and the block copy/swap lists taken from ``execute_model_req``.
        Non-driver workers receive them via ``broadcast_tensor_dict`` and may
        be called with ``execute_model_req=None``.

        Returns:
            An empty list if there are no sequence groups, otherwise a
            single-element list holding the model runner's output.
        """
        if execute_model_req is None:
            seq_group_metadata_list = None
        else:
            seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            num_seq_groups: int = len(seq_group_metadata_list)
            assert execute_model_req is not None
            blocks_to_copy = execute_model_req.blocks_to_copy
            blocks_to_swap_in = execute_model_req.blocks_to_swap_in
            blocks_to_swap_out = execute_model_req.blocks_to_swap_out
            data: Dict[str, Any] = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_copy": blocks_to_copy,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
            }
            broadcast_tensor_dict(data, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            num_seq_groups = data["num_seq_groups"]
            blocks_to_copy = data["blocks_to_copy"]
            blocks_to_swap_in = data["blocks_to_swap_in"]
            blocks_to_swap_out = data["blocks_to_swap_out"]

        if current_platform.is_openvino_cpu():
            # CPU targets have no swap cache. Check the local (possibly
            # broadcast) lists: execute_model_req may be None on non-driver
            # workers.
            assert len(blocks_to_swap_in) == 0
            assert len(blocks_to_swap_out) == 0
        else:
            self.cache_swap_in(blocks_to_swap_in)
            self.cache_swap_out(blocks_to_swap_out)

        self.cache_copy(blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return []

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.kv_cache)

        # OpenVINO worker only supports single-step execution.
        return [output]

    def init_distributed_environment(self) -> None:
        """Initialize the distributed environment."""

        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            backend="gloo",
        )

        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cpu())

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size,
        )

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single KV cache block."""
        return OpenVINOCacheEngine.get_cache_block_size(
            self.cache_config.block_size,
            self.cache_config.cache_dtype,
            self.model_config,
            self.parallel_config,
        )

    @staticmethod
    def _format_memory_size(size) -> str:
        """Render a byte count using 1024-based units (B, KB, MB, GB)."""
        units = ["B", "KB", "MB", "GB"]
        unit_index = 0

        while size >= 1024 and unit_index < len(units) - 1:
            size /= 1024
            unit_index += 1

        return f"{size:.2f} {units[unit_index]}"

    def _profile_forward_pass(self, ov_device: str) -> None:
        """Run one prefill step with dummy inputs on a 1-block scratch cache.

        The batch consists of ``max_num_seqs`` prompt sequences whose lengths
        add up to ``max_num_batched_tokens``, and every block-table entry
        points at block 0. The scratch cache is unbound from the forward
        context and released afterwards, even if the forward pass fails.
        """
        model_config = self.model_config
        cache_config = self.cache_config
        mm_registry = MULTIMODAL_REGISTRY
        mm_registry.init_mm_limits_per_prompt(model_config)

        top_k = model_config.get_vocab_size() - 1
        sampling_params = SamplingParams(top_p=0.99, top_k=top_k)

        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs
        # NOTE(review): the third CacheConfig argument is passed a byte count;
        # it is unused here since num_cpu_blocks is 0 — confirm the units.
        tmp_cache_config = CacheConfig(cache_config.block_size,
                                       cache_config.gpu_memory_utilization,
                                       cache_config.swap_space_bytes,
                                       "auto")
        tmp_cache_config.num_gpu_blocks = 1
        tmp_cache_config.num_cpu_blocks = 0
        tmp_cache_config.cache_dtype = cache_config.cache_dtype

        profiling_cache_engine = OpenVINOCacheEngine(
            tmp_cache_config, model_config, self.parallel_config,
            self.device_config, self.ov_core, ov_device)

        # Profile memory usage with max_num_sequences sequences and the
        # total # number of tokens equal to max_num_batched_tokens.
        block_size = cache_config.block_size
        seqs: List[SequenceGroupMetadata] = []
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_num_blocks = (seq_len + block_size - 1) // block_size

            dummy_data = INPUT_REGISTRY.dummy_data_for_profiling(
                model_config, seq_len, mm_registry)

            seqs.append(
                SequenceGroupMetadata(
                    request_id=str(group_id),
                    is_prompt=True,
                    seq_data={group_id: dummy_data.seq_data},
                    sampling_params=sampling_params,
                    # One block table per sequence, keyed like seq_data;
                    # the scratch cache only has block 0.
                    block_tables={group_id: [0] * seq_num_blocks},
                    lora_request=None,
                    multi_modal_data=dummy_data.multi_modal_data))

        self.model_runner.block_size = tmp_cache_config.block_size

        forward_context = self.compilation_config.static_forward_context
        # Outer list = one virtual engine, matching _init_cache_engine.
        bind_kv_cache(forward_context, [profiling_cache_engine.kv_cache])
        try:
            # Run the model with the dummy inputs.
            self.model_runner.execute_model(seqs,
                                            profiling_cache_engine.kv_cache)
        finally:
            # Explicitly revert bind_kv_cache and delete temporary KV cache
            # manager to free KV cache when real inputs will be passed to OV
            bind_kv_cache(forward_context, [[
                torch.tensor([])
                for _ in range(len(profiling_cache_engine.kv_cache))
            ]])
            del profiling_cache_engine

    def profile_run(self) -> int:
        """Estimate the device memory left for the KV cache, in bytes.

        Runs a dummy forward pass and then reads the device memory
        statistics. The used memory (cl_mem + usm_device, plus usm_host on
        integrated GPUs) is padded by 10%. The result is
        ``total_device_memory * gpu_memory_utilization - used_device_mem``,
        computed with float arithmetic despite the ``int`` annotation.

        Raises:
            RuntimeError: if the padded used memory is at least the total
                device memory.
        """
        ov_device = envs.VLLM_OPENVINO_DEVICE

        assert not current_platform.is_openvino_cpu(), \
            "CPU device isn't supposed to use profile run."

        import openvino.properties.device as device
        import openvino.properties.intel_gpu as intel_gpu

        ov_core = self.ov_core
        memory_utilization = self.cache_config.gpu_memory_utilization

        logger.info(
            "Start profiling run with dummy inputs to evaluate "
            "memory usage for %s. It might take a while.", ov_device)
        self._profile_forward_pass(ov_device)

        gpu_device_type = ov_core.get_property(ov_device, device.type)
        memory_statistics = \
            ov_core.get_property(ov_device, intel_gpu.memory_statistics)

        if gpu_device_type == device.Type.INTEGRATED and \
                memory_utilization >= 0.9:
            logger.warning(
                "iGPU is used with high gpu_memory_utilization=%f "
                "value. This may cause low performance due to "
                "occupying the majority of available system "
                "memory. Please consider decreasing "
                "gpu_memory_utilization or explicitly setting "
                "`VLLM_OPENVINO_KVCACHE_SPACE` (GB) environment "
                "variable.", memory_utilization)

        # sum up all used device memory
        device_memory_types = ["cl_mem", "usm_device"]
        used_device_mem = \
            sum(memory_statistics.get(key, 0) for key in device_memory_types)

        if gpu_device_type == device.Type.INTEGRATED:
            used_device_mem += memory_statistics.get("usm_host", 0)

        # there could be unaccounted extra memory reserved by kernels, kept
        # in memory pools, etc
        # therefore, add a threshold to account for this
        used_memory_threshold = 1.1
        used_device_mem *= used_memory_threshold

        total_device_memory = \
            ov_core.get_property(ov_device, intel_gpu.device_total_mem_size)

        total_device_memory_str = self._format_memory_size(total_device_memory)
        used_device_memory_str = self._format_memory_size(used_device_mem)

        logger.info(
            "Total %s memory: %s. "
            "Amount of memory required to run the model with "
            "max_num_batched_tokens=%d: %s.", ov_device,
            total_device_memory_str,
            self.scheduler_config.max_num_batched_tokens,
            used_device_memory_str)

        if used_device_mem >= total_device_memory:
            raise RuntimeError(
                f"The required memory size {used_device_memory_str} for model "
                "is higher than the total available device "
                f"memory {total_device_memory_str}. Please consider to "
                "decrease `max_num_batched_tokens` or increase "
                "`gpu_memory_utilization`")

        return total_device_memory * memory_utilization - used_device_mem