Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9ddac563
Unverified
Commit
9ddac563
authored
Jan 15, 2025
by
Shanshan Shen
Committed by
GitHub
Jan 15, 2025
Browse files
[Platform] move current_memory_usage() into platform (#11369)
Signed-off-by:
Shanshan Shen
<
467638484@qq.com
>
parent
1a51b9f8
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
31 additions
and
7 deletions
+31
-7
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+7
-0
vllm/platforms/interface.py
vllm/platforms/interface.py
+9
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+7
-0
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+7
-0
vllm/utils.py
vllm/utils.py
+1
-7
No files found.
vllm/platforms/cuda.py
View file @
9ddac563
...
...
@@ -143,6 +143,13 @@ class CudaPlatformBase(Platform):
if
cache_config
and
cache_config
.
block_size
is
None
:
cache_config
.
block_size
=
16
@
classmethod
def
get_current_memory_usage
(
cls
,
device
:
Optional
[
torch
.
types
.
Device
]
=
None
)
->
float
:
torch
.
cuda
.
reset_peak_memory_stats
(
device
)
return
torch
.
cuda
.
max_memory_allocated
(
device
)
@
classmethod
def
get_attn_backend_cls
(
cls
,
selected_backend
,
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
use_v1
)
->
str
:
...
...
vllm/platforms/interface.py
View file @
9ddac563
...
...
@@ -277,6 +277,15 @@ class Platform:
return
False
return
True
@
classmethod
def
get_current_memory_usage
(
cls
,
device
:
Optional
[
torch
.
types
.
Device
]
=
None
)
->
float
:
"""
Return the memory usage in bytes.
"""
raise
NotImplementedError
@
classmethod
def
get_punica_wrapper
(
cls
)
->
str
:
"""
...
...
vllm/platforms/rocm.py
View file @
9ddac563
...
...
@@ -157,3 +157,10 @@ class RocmPlatform(Platform):
@
classmethod
def
get_punica_wrapper
(
cls
)
->
str
:
return
"vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
@
classmethod
def
get_current_memory_usage
(
cls
,
device
:
Optional
[
torch
.
types
.
Device
]
=
None
)
->
float
:
torch
.
cuda
.
reset_peak_memory_stats
(
device
)
return
torch
.
cuda
.
max_memory_allocated
(
device
)
vllm/platforms/xpu.py
View file @
9ddac563
...
...
@@ -94,3 +94,10 @@ class XPUPlatform(Platform):
def
is_pin_memory_available
(
cls
):
logger
.
warning
(
"Pin memory is not supported on XPU."
)
return
False
@
classmethod
def
get_current_memory_usage
(
cls
,
device
:
Optional
[
torch
.
types
.
Device
]
=
None
)
->
float
:
torch
.
xpu
.
reset_peak_memory_stats
(
device
)
return
torch
.
xpu
.
max_memory_allocated
(
device
)
vllm/utils.py
View file @
9ddac563
...
...
@@ -710,13 +710,7 @@ class DeviceMemoryProfiler:
def
current_memory_usage
(
self
)
->
float
:
# Return the memory usage in bytes.
from
vllm.platforms
import
current_platform
if
current_platform
.
is_cuda_alike
():
torch
.
cuda
.
reset_peak_memory_stats
(
self
.
device
)
mem
=
torch
.
cuda
.
max_memory_allocated
(
self
.
device
)
elif
current_platform
.
is_xpu
():
torch
.
xpu
.
reset_peak_memory_stats
(
self
.
device
)
# type: ignore
mem
=
torch
.
xpu
.
max_memory_allocated
(
self
.
device
)
# type: ignore
return
mem
return
current_platform
.
get_current_memory_usage
(
self
.
device
)
def
__enter__
(
self
):
self
.
initial_memory
=
self
.
current_memory_usage
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment