Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c907d221
Unverified
Commit
c907d221
authored
Jan 08, 2026
by
Ning Xie
Committed by
GitHub
Jan 07, 2026
Browse files
[refactor] refactor memory constants usage (#31865)
Signed-off-by:
Andy Xie
<
andy.xning@gmail.com
>
parent
f347ac6c
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
33 additions
and
30 deletions
+33
-30
tests/basic_correctness/test_cumem.py
tests/basic_correctness/test_cumem.py
+0
-1
vllm/config/cache.py
vllm/config/cache.py
+5
-4
vllm/multimodal/cache.py
vllm/multimodal/cache.py
+3
-2
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+3
-3
vllm/utils/mem_utils.py
vllm/utils/mem_utils.py
+7
-3
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+5
-5
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+3
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-4
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+4
-4
No files found.
tests/basic_correctness/test_cumem.py
View file @
c907d221
...
@@ -247,7 +247,6 @@ def test_deep_sleep_async():
...
@@ -247,7 +247,6 @@ def test_deep_sleep_async():
@
requires_fp8
@
requires_fp8
def
test_deep_sleep_fp8_kvcache
():
def
test_deep_sleep_fp8_kvcache
():
GiB_bytes
=
1
<<
30
model
=
"Qwen/Qwen2-0.5B"
model
=
"Qwen/Qwen2-0.5B"
used_bytes_baseline
=
current_platform
.
get_current_memory_usage
()
used_bytes_baseline
=
current_platform
.
get_current_memory_usage
()
...
...
vllm/config/cache.py
View file @
c907d221
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
dataclasses
import
field
from
dataclasses
import
field
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
...
@@ -10,7 +11,7 @@ from pydantic.dataclasses import dataclass
...
@@ -10,7 +11,7 @@ from pydantic.dataclasses import dataclass
from
vllm.config.utils
import
config
from
vllm.config.utils
import
config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils.mem_constants
import
GiB_bytes
from
vllm.utils.mem_constants
import
GiB_bytes
from
vllm.utils.mem_utils
import
get_cpu_memory
from
vllm.utils.mem_utils
import
format_gib
,
get_cpu_memory
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.config.parallel
import
ParallelConfig
from
vllm.config.parallel
import
ParallelConfig
...
@@ -214,7 +215,7 @@ class CacheConfig:
...
@@ -214,7 +215,7 @@ class CacheConfig:
self
,
self
,
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
)
->
None
:
)
->
None
:
swap_space_bytes
=
self
.
swap_space
*
GiB_bytes
swap_space_bytes
=
math
.
ceil
(
self
.
swap_space
*
GiB_bytes
)
total_cpu_memory
=
get_cpu_memory
()
total_cpu_memory
=
get_cpu_memory
()
# FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
# FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
# group are in the same node. However, the GPUs may span multiple nodes.
# group are in the same node. However, the GPUs may span multiple nodes.
...
@@ -222,8 +223,8 @@ class CacheConfig:
...
@@ -222,8 +223,8 @@ class CacheConfig:
cpu_memory_usage
=
swap_space_bytes
*
num_gpus_per_node
cpu_memory_usage
=
swap_space_bytes
*
num_gpus_per_node
msg
=
(
msg
=
(
f
"
{
cpu_memory_usage
/
GiB_bytes
:.
2
f
}
GiB out of the "
f
"
{
format_gib
(
cpu_memory_usage
)
}
GiB out of the "
f
"
{
total_cpu_memory
/
GiB_bytes
:.
2
f
}
GiB total CPU memory "
f
"
{
format_gib
(
total_cpu_memory
)
}
GiB total CPU memory "
"is allocated for the swap space."
"is allocated for the swap space."
)
)
if
cpu_memory_usage
>
0.7
*
total_cpu_memory
:
if
cpu_memory_usage
>
0.7
*
total_cpu_memory
:
...
...
vllm/multimodal/cache.py
View file @
c907d221
...
@@ -20,6 +20,7 @@ from vllm.logger import init_logger
...
@@ -20,6 +20,7 @@ from vllm.logger import init_logger
from
vllm.utils.cache
import
CacheInfo
,
LRUCache
from
vllm.utils.cache
import
CacheInfo
,
LRUCache
from
vllm.utils.jsontree
import
json_count_leaves
,
json_map_leaves
,
json_reduce_leaves
from
vllm.utils.jsontree
import
json_count_leaves
,
json_map_leaves
,
json_reduce_leaves
from
vllm.utils.mem_constants
import
GiB_bytes
,
MiB_bytes
from
vllm.utils.mem_constants
import
GiB_bytes
,
MiB_bytes
from
vllm.utils.mem_utils
import
format_gib
from
.inputs
import
(
from
.inputs
import
(
MultiModalBatchedField
,
MultiModalBatchedField
,
...
@@ -130,9 +131,9 @@ class MultiModalCache:
...
@@ -130,9 +131,9 @@ class MultiModalCache:
if
debug
:
if
debug
:
leaf_count
=
json_count_leaves
(
value
)
leaf_count
=
json_count_leaves
(
value
)
logger
.
debug
(
logger
.
debug
(
"Calculated size of %s to be %
.2f
GiB (%d leaves)"
,
"Calculated size of %s to be %
s
GiB (%d leaves)"
,
type
(
value
),
type
(
value
),
size
/
GiB_bytes
,
format_gib
(
size
)
,
leaf_count
,
leaf_count
,
)
)
...
...
vllm/platforms/cpu.py
View file @
c907d221
...
@@ -140,6 +140,7 @@ class CpuPlatform(Platform):
...
@@ -140,6 +140,7 @@ class CpuPlatform(Platform):
@
classmethod
@
classmethod
def
get_device_total_memory
(
cls
,
device_id
:
int
=
0
)
->
int
:
def
get_device_total_memory
(
cls
,
device_id
:
int
=
0
)
->
int
:
from
vllm.utils.mem_constants
import
GiB_bytes
from
vllm.utils.mem_constants
import
GiB_bytes
from
vllm.utils.mem_utils
import
format_gib
kv_cache_space
=
envs
.
VLLM_CPU_KVCACHE_SPACE
kv_cache_space
=
envs
.
VLLM_CPU_KVCACHE_SPACE
node_dir
=
"/sys/devices/system/node"
node_dir
=
"/sys/devices/system/node"
...
@@ -153,10 +154,9 @@ class CpuPlatform(Platform):
...
@@ -153,10 +154,9 @@ class CpuPlatform(Platform):
free_cpu_memory
=
psutil
.
virtual_memory
().
total
//
num_numa_nodes
free_cpu_memory
=
psutil
.
virtual_memory
().
total
//
num_numa_nodes
DEFAULT_CPU_MEM_UTILIZATION
=
0.5
DEFAULT_CPU_MEM_UTILIZATION
=
0.5
kv_cache_space
=
int
(
free_cpu_memory
*
DEFAULT_CPU_MEM_UTILIZATION
)
kv_cache_space
=
int
(
free_cpu_memory
*
DEFAULT_CPU_MEM_UTILIZATION
)
kv_cache_space_gib
=
kv_cache_space
/
GiB_bytes
logger
.
warning_once
(
logger
.
warning_once
(
"VLLM_CPU_KVCACHE_SPACE not set. Using
"
"VLLM_CPU_KVCACHE_SPACE not set. Using
%s GiB for KV cache."
,
f
"
{
kv_cache_space
_gib
:.
2
f
}
GiB for KV cache."
f
ormat_gib
(
kv_cache_space
),
)
)
else
:
else
:
kv_cache_space
*=
GiB_bytes
kv_cache_space
*=
GiB_bytes
...
...
vllm/utils/mem_utils.py
View file @
c907d221
...
@@ -11,11 +11,15 @@ import psutil
...
@@ -11,11 +11,15 @@ import psutil
import
torch
import
torch
import
torch.types
import
torch.types
from
.mem_constants
import
GiB_bytes
from
.mem_constants
import
GiB_bytes
,
MiB_bytes
def
format_gib
(
b
:
int
)
->
float
:
def
format_mib
(
b
:
int
)
->
str
:
return
round
(
b
/
GiB_bytes
,
2
)
return
f
"
{
round
(
b
/
MiB_bytes
,
2
)
}
"
def
format_gib
(
b
:
int
)
->
str
:
return
f
"
{
round
(
b
/
GiB_bytes
,
2
)
}
"
@
cache
@
cache
...
...
vllm/v1/core/kv_cache_utils.py
View file @
c907d221
...
@@ -14,7 +14,7 @@ from vllm.config import VllmConfig
...
@@ -14,7 +14,7 @@ from vllm.config import VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils.hashing
import
sha256_cbor
,
xxhash_cbor
from
vllm.utils.hashing
import
sha256_cbor
,
xxhash_cbor
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.mem_
constant
s
import
GiB_bytes
from
vllm.utils.mem_
util
s
import
format_gib
from
vllm.v1.kv_cache_interface
import
(
from
vllm.v1.kv_cache_interface
import
(
ChunkedLocalAttentionSpec
,
ChunkedLocalAttentionSpec
,
FullAttentionSpec
,
FullAttentionSpec
,
...
@@ -633,9 +633,9 @@ def _check_enough_kv_cache_memory(
...
@@ -633,9 +633,9 @@ def _check_enough_kv_cache_memory(
raise
ValueError
(
raise
ValueError
(
f
"To serve at least one request with the models's max seq len "
f
"To serve at least one request with the models's max seq len "
f
"(
{
max_model_len
}
), (
{
needed_memory
/
GiB_bytes
:.
2
f
}
GiB KV "
f
"(
{
max_model_len
}
), (
{
format_gib
(
needed_memory
)
}
GiB KV "
f
"cache is needed, which is larger than the available KV cache "
f
"cache is needed, which is larger than the available KV cache "
f
"memory (
{
available_memory
/
GiB_bytes
:.
2
f
}
GiB).
{
estimated_msg
}
"
f
"memory (
{
format_gib
(
available_memory
)
}
GiB).
{
estimated_msg
}
"
f
"Try increasing `gpu_memory_utilization` or decreasing `max_model_len` "
f
"Try increasing `gpu_memory_utilization` or decreasing `max_model_len` "
f
"when initializing the engine. "
f
"when initializing the engine. "
f
"See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ "
f
"See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ "
...
@@ -1441,10 +1441,10 @@ def _auto_fit_max_model_len(
...
@@ -1441,10 +1441,10 @@ def _auto_fit_max_model_len(
vllm_config
.
model_config
.
max_model_len
=
auto_fit_max
vllm_config
.
model_config
.
max_model_len
=
auto_fit_max
logger
.
info_once
(
logger
.
info_once
(
"Auto-fit max_model_len: reduced from %d to %d to fit in "
"Auto-fit max_model_len: reduced from %d to %d to fit in "
"available GPU memory (%
.2f
GiB available for KV cache)"
,
"available GPU memory (%
s
GiB available for KV cache)"
,
original_max
,
original_max
,
auto_fit_max
,
auto_fit_max
,
min_available_memory
/
GiB_bytes
,
format_gib
(
min_available_memory
)
,
scope
=
"local"
,
scope
=
"local"
,
)
)
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
c907d221
...
@@ -14,8 +14,7 @@ from vllm.config.compilation import CUDAGraphMode
...
@@ -14,8 +14,7 @@ from vllm.config.compilation import CUDAGraphMode
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.utils.mem_constants
import
GiB_bytes
from
vllm.utils.mem_utils
import
DeviceMemoryProfiler
,
format_gib
from
vllm.utils.mem_utils
import
DeviceMemoryProfiler
from
vllm.utils.platform_utils
import
is_pin_memory_available
from
vllm.utils.platform_utils
import
is_pin_memory_available
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
...
@@ -165,8 +164,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -165,8 +164,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
model_memory_usage
=
m
.
consumed_memory
self
.
model_memory_usage
=
m
.
consumed_memory
logger
.
info
(
logger
.
info
(
"Model loading took %
.4f
GiB and %.6f seconds"
,
"Model loading took %
s
GiB and %.6f seconds"
,
m
.
consumed_memory
/
GiB_bytes
,
format_gib
(
m
.
consumed_memory
)
,
time_after_load
-
time_before_load
,
time_after_load
-
time_before_load
,
)
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
c907d221
...
@@ -93,8 +93,7 @@ from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
...
@@ -93,8 +93,7 @@ from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from
vllm.utils
import
length_from_prompt_token_ids_or_embeds
from
vllm.utils
import
length_from_prompt_token_ids_or_embeds
from
vllm.utils.jsontree
import
json_map_leaves
from
vllm.utils.jsontree
import
json_map_leaves
from
vllm.utils.math_utils
import
cdiv
,
round_up
from
vllm.utils.math_utils
import
cdiv
,
round_up
from
vllm.utils.mem_constants
import
GiB_bytes
from
vllm.utils.mem_utils
import
DeviceMemoryProfiler
,
format_gib
from
vllm.utils.mem_utils
import
DeviceMemoryProfiler
from
vllm.utils.nvtx_pytorch_hooks
import
PytHooks
from
vllm.utils.nvtx_pytorch_hooks
import
PytHooks
from
vllm.utils.platform_utils
import
is_pin_memory_available
from
vllm.utils.platform_utils
import
is_pin_memory_available
from
vllm.utils.torch_utils
import
(
from
vllm.utils.torch_utils
import
(
...
@@ -3899,8 +3898,8 @@ class GPUModelRunner(
...
@@ -3899,8 +3898,8 @@ class GPUModelRunner(
logger
.
error
(
combined_msg
)
logger
.
error
(
combined_msg
)
raise
e
raise
e
logger
.
info_once
(
logger
.
info_once
(
"Model loading took %
.4f
GiB memory and %.6f seconds"
,
"Model loading took %
s
GiB memory and %.6f seconds"
,
self
.
model_memory_usage
/
GiB_bytes
,
format_gib
(
self
.
model_memory_usage
)
,
time_after_load
-
time_before_load
,
time_after_load
-
time_before_load
,
scope
=
"local"
,
scope
=
"local"
,
)
)
...
...
vllm/v1/worker/gpu_worker.py
View file @
c907d221
...
@@ -125,7 +125,7 @@ class Worker(WorkerBase):
...
@@ -125,7 +125,7 @@ class Worker(WorkerBase):
used_bytes
=
total
-
free_bytes_after_sleep
used_bytes
=
total
-
free_bytes_after_sleep
assert
freed_bytes
>=
0
,
"Memory usage increased after sleeping."
assert
freed_bytes
>=
0
,
"Memory usage increased after sleeping."
logger
.
info
(
logger
.
info
(
"Sleep mode freed %
f
GiB memory, %
f
GiB memory is still in use."
,
"Sleep mode freed %
s
GiB memory, %
s
GiB memory is still in use."
,
format_gib
(
freed_bytes
),
format_gib
(
freed_bytes
),
format_gib
(
used_bytes
),
format_gib
(
used_bytes
),
)
)
...
@@ -342,19 +342,19 @@ class Worker(WorkerBase):
...
@@ -342,19 +342,19 @@ class Worker(WorkerBase):
unrequested_memory
=
self
.
init_snapshot
.
free_memory
-
self
.
requested_memory
unrequested_memory
=
self
.
init_snapshot
.
free_memory
-
self
.
requested_memory
logger
.
debug
(
logger
.
debug
(
"Initial free memory: %
f
GiB; Requested memory: %f (util), %
f
GiB"
,
"Initial free memory: %
s
GiB; Requested memory: %f (util), %
s
GiB"
,
format_gib
(
self
.
init_snapshot
.
free_memory
),
format_gib
(
self
.
init_snapshot
.
free_memory
),
self
.
cache_config
.
gpu_memory_utilization
,
self
.
cache_config
.
gpu_memory_utilization
,
format_gib
(
self
.
requested_memory
),
format_gib
(
self
.
requested_memory
),
)
)
logger
.
debug
(
logger
.
debug
(
"Free memory after profiling: %
f
GiB (total), %
f
GiB (within requested)"
,
"Free memory after profiling: %
s
GiB (total), %
s
GiB (within requested)"
,
format_gib
(
free_gpu_memory
),
format_gib
(
free_gpu_memory
),
format_gib
(
free_gpu_memory
-
unrequested_memory
),
format_gib
(
free_gpu_memory
-
unrequested_memory
),
)
)
logger
.
debug
(
profile_result
)
logger
.
debug
(
profile_result
)
logger
.
info_once
(
logger
.
info_once
(
"Available KV cache memory: %
f
GiB"
,
"Available KV cache memory: %
s
GiB"
,
format_gib
(
self
.
available_kv_cache_memory_bytes
),
format_gib
(
self
.
available_kv_cache_memory_bytes
),
scope
=
"local"
,
scope
=
"local"
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment