Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ce094624
Unverified
Commit
ce094624
authored
Jan 14, 2026
by
Shanshan Shen
Committed by
GitHub
Jan 14, 2026
Browse files
[Misc] Make mem utils can be reused by other platforms (#32322)
Signed-off-by: shen-shanshan <467638484@qq.com>
parent
3f28174c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing 3 changed files with 21 additions and 25 deletions
+21
-25
vllm/platforms/interface.py
vllm/platforms/interface.py
+12
-8
vllm/utils/mem_utils.py
vllm/utils/mem_utils.py
+9
-13
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+0
-4
No files found.
vllm/platforms/interface.py
View file @
ce094624
...
...
@@ -589,14 +589,18 @@ class Platform:
def
__getattr__
(
self
,
key
:
str
):
device
=
getattr
(
torch
,
self
.
device_type
,
None
)
if
device
is
not
None
and
hasattr
(
device
,
key
):
return
getattr
(
device
,
key
)
else
:
logger
.
warning
(
"Current platform %s does not have '%s' attribute."
,
self
.
device_type
,
key
,
)
return
None
attr
=
getattr
(
device
,
key
)
# NOTE: `hasattr(device, key)` returning True only rules out an AttributeError;
# the value of the attribute itself can still be `None`.
if
attr
is
not
None
:
return
attr
logger
.
warning
(
"Current platform %s does not have '%s' attribute."
,
self
.
device_type
,
key
,
)
return
None
def
get_global_graph_pool
(
self
)
->
Any
:
"""
...
...
vllm/utils/mem_utils.py
View file @
ce094624
...
...
@@ -11,6 +11,8 @@ import psutil
import
torch
import
torch.types
from
vllm.platforms
import
current_platform
from
.mem_constants
import
GiB_bytes
,
MiB_bytes
...
...
@@ -45,8 +47,6 @@ class DeviceMemoryProfiler:
def
current_memory_usage
(
self
)
->
float
:
# Return the memory usage in bytes.
from
vllm.platforms
import
current_platform
gc
.
collect
()
return
current_platform
.
get_current_memory_usage
(
self
.
device
)
...
...
@@ -80,8 +80,6 @@ class MemorySnapshot:
def
__post_init__
(
self
)
->
None
:
if
self
.
device
is
None
:
from
vllm.platforms
import
current_platform
device_fn
=
current_platform
.
current_device
assert
device_fn
is
not
None
self
.
device_
=
torch
.
device
(
device_fn
())
...
...
@@ -92,8 +90,6 @@ class MemorySnapshot:
self
.
measure
()
def
measure
(
self
)
->
None
:
from
vllm.platforms
import
current_platform
device
=
self
.
device_
# we measure the torch peak memory usage via allocated_bytes,
...
...
@@ -101,11 +97,11 @@ class MemorySnapshot:
# After `torch.cuda.reset_peak_memory_stats()`,
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
# when we call `torch.cuda.empty_cache()` or OOM happens.
self
.
torch_peak
=
torch
.
cuda
.
memory_stats
(
device
).
get
(
self
.
torch_peak
=
current_platform
.
memory_stats
(
device
).
get
(
"allocated_bytes.all.peak"
,
0
)
self
.
free_memory
,
self
.
total_memory
=
torch
.
cuda
.
mem_get_info
(
device
)
self
.
free_memory
,
self
.
total_memory
=
current_platform
.
mem_get_info
(
device
)
shared_sysmem_device_mem_sms
=
((
8
,
7
),
(
11
,
0
),
(
12
,
1
))
# Orin, Thor, Spark
if
(
current_platform
.
is_cuda
()
...
...
@@ -130,7 +126,7 @@ class MemorySnapshot:
# torch.cuda.memory_reserved() is how many bytes
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
# this is used to measure the non-torch memory usage
self
.
torch_memory
=
torch
.
cuda
.
memory_reserved
(
device
)
self
.
torch_memory
=
current_platform
.
memory_reserved
(
device
)
self
.
non_torch_memory
=
self
.
cuda_memory
-
self
.
torch_memory
self
.
timestamp
=
time
.
time
()
...
...
@@ -159,7 +155,7 @@ class MemorySnapshot:
f
"torch_peak=
{
format_gib
(
self
.
torch_peak
)
}
GiB, "
f
"free_memory=
{
format_gib
(
self
.
free_memory
)
}
GiB, "
f
"total_memory=
{
format_gib
(
self
.
total_memory
)
}
GiB, "
f
"cu
da
_memory=
{
format_gib
(
self
.
cuda_memory
)
}
GiB, "
f
"
{
cu
rrent_platform
.
device_name
}
_memory=
{
format_gib
(
self
.
cuda_memory
)
}
GiB, "
f
"torch_memory=
{
format_gib
(
self
.
torch_memory
)
}
GiB, "
f
"non_torch_memory=
{
format_gib
(
self
.
non_torch_memory
)
}
GiB, "
f
"timestamp=
{
self
.
timestamp
}
, "
...
...
@@ -254,8 +250,8 @@ def memory_profiling(
until after profiling to get (c.).
"""
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
reset_peak_memory_stats
(
baseline_snapshot
.
device_
)
current_platform
.
empty_cache
()
current_platform
.
reset_peak_memory_stats
(
baseline_snapshot
.
device_
)
result
=
MemoryProfilingResult
(
before_create
=
baseline_snapshot
,
...
...
@@ -268,7 +264,7 @@ def memory_profiling(
yield
result
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
current_platform
.
empty_cache
()
result
.
after_profile
.
measure
()
...
...
vllm/v1/worker/gpu_worker.py
View file @
ce094624
...
...
@@ -312,9 +312,6 @@ class Worker(WorkerBase):
logger
.
info
(
msg
)
return
kv_cache_memory_bytes
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
reset_peak_memory_stats
()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
with
memory_profiling
(
...
...
@@ -360,7 +357,6 @@ class Worker(WorkerBase):
format_gib
(
self
.
available_kv_cache_memory_bytes
),
scope
=
"local"
,
)
gc
.
collect
()
return
int
(
self
.
available_kv_cache_memory_bytes
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment