Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
da02cb4b
Unverified
Commit
da02cb4b
authored
Jan 18, 2025
by
youkaichao
Committed by
GitHub
Jan 18, 2025
Browse files
[core] further polish memory profiling (#12126)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
c09503dd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
85 additions
and
67 deletions
+85
-67
tests/test_utils.py
tests/test_utils.py
+12
-14
vllm/utils.py
vllm/utils.py
+56
-39
vllm/worker/worker.py
vllm/worker/worker.py
+17
-14
No files found.
tests/test_utils.py
View file @
da02cb4b
...
@@ -9,10 +9,10 @@ import torch
...
@@ -9,10 +9,10 @@ import torch
from
vllm_test_utils
import
monitor
from
vllm_test_utils
import
monitor
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.utils
import
(
FlexibleArgumentParser
,
PlaceholderModule
,
from
vllm.utils
import
(
FlexibleArgumentParser
,
MemorySnapshot
,
StoreBoolean
,
bind_kv_cache
,
deprecate_kwargs
,
PlaceholderModule
,
StoreBoolean
,
bind_kv_cache
,
get_open_port
,
memory_profiling
,
merge_async_iterators
,
deprecate_kwargs
,
get_open_port
,
memory_profiling
,
supports_kw
)
merge_async_iterators
,
supports_kw
)
from
.utils
import
error_on_warning
,
fork_new_process_for_each_test
from
.utils
import
error_on_warning
,
fork_new_process_for_each_test
...
@@ -284,14 +284,13 @@ def test_memory_profiling():
...
@@ -284,14 +284,13 @@ def test_memory_profiling():
# 512 MiB allocation outside of this instance
# 512 MiB allocation outside of this instance
handle1
=
lib
.
cudaMalloc
(
512
*
1024
*
1024
)
handle1
=
lib
.
cudaMalloc
(
512
*
1024
*
1024
)
baseline_memory_in_bytes
=
\
baseline_snapshot
=
MemorySnapshot
()
torch
.
cuda
.
mem_get_info
()[
1
]
-
torch
.
cuda
.
mem_get_info
()[
0
]
# load weights
# load weights
weights
=
torch
.
randn
(
128
,
1024
,
1024
,
device
=
'cuda'
,
dtype
=
torch
.
float32
)
weights
=
torch
.
randn
(
128
,
1024
,
1024
,
device
=
'cuda'
,
dtype
=
torch
.
float32
)
weights_memory
_in_bytes
=
128
*
1024
*
1024
*
4
# 512 MiB
weights_memory
=
128
*
1024
*
1024
*
4
# 512 MiB
def
measure_current_non_torch
():
def
measure_current_non_torch
():
free
,
total
=
torch
.
cuda
.
mem_get_info
()
free
,
total
=
torch
.
cuda
.
mem_get_info
()
...
@@ -300,8 +299,8 @@ def test_memory_profiling():
...
@@ -300,8 +299,8 @@ def test_memory_profiling():
current_non_torch
=
current_used
-
current_torch
current_non_torch
=
current_used
-
current_torch
return
current_non_torch
return
current_non_torch
with
memory_profiling
(
baseline_
memory_in_bytes
=
baseline_memory_in_bytes
,
with
memory_profiling
(
baseline_
snapshot
=
baseline_snapshot
,
weights_memory
_in_bytes
=
weights_memory
_in_bytes
)
as
result
,
\
weights_memory
=
weights_memory
)
as
result
,
\
monitor
(
measure_current_non_torch
)
as
monitored_values
:
monitor
(
measure_current_non_torch
)
as
monitored_values
:
# make a memory spike, 1 GiB
# make a memory spike, 1 GiB
spike
=
torch
.
randn
(
256
,
1024
,
1024
,
device
=
'cuda'
,
dtype
=
torch
.
float32
)
spike
=
torch
.
randn
(
256
,
1024
,
1024
,
device
=
'cuda'
,
dtype
=
torch
.
float32
)
...
@@ -316,13 +315,12 @@ def test_memory_profiling():
...
@@ -316,13 +315,12 @@ def test_memory_profiling():
assert
measured_diff
==
256
*
1024
*
1024
assert
measured_diff
==
256
*
1024
*
1024
# Check that the memory usage is within 5% of the expected values
# Check that the memory usage is within 5% of the expected values
# 5% tolerance is caused by
PyTorch caching allocator,
# 5% tolerance is caused by
cuda runtime.
# we cannot control
PyTorch's behavior of its internal buffer
s,
# we cannot control
cuda runtime in the granularity of byte
s,
# which causes a small error (<10 MiB in practice)
# which causes a small error (<10 MiB in practice)
non_torch_ratio
=
result
.
non_torch_increase_in_bytes
/
(
256
*
1024
*
1024
)
# noqa
non_torch_ratio
=
result
.
non_torch_increase
/
(
256
*
1024
*
1024
)
# noqa
torch_peak_ratio
=
result
.
torch_peak_increase_in_bytes
/
(
1024
*
1024
*
1024
)
# noqa
assert
abs
(
non_torch_ratio
-
1
)
<=
0.05
assert
abs
(
non_torch_ratio
-
1
)
<=
0.05
assert
abs
(
torch_peak_
ratio
-
1
)
<=
0.05
assert
result
.
torch_peak_
increase
==
1024
*
1024
*
1024
del
weights
del
weights
lib
.
cudaFree
(
handle1
)
lib
.
cudaFree
(
handle1
)
lib
.
cudaFree
(
handle2
)
lib
.
cudaFree
(
handle2
)
...
...
vllm/utils.py
View file @
da02cb4b
...
@@ -1923,36 +1923,57 @@ def kill_process_tree(pid: int):
...
@@ -1923,36 +1923,57 @@ def kill_process_tree(pid: int):
@
dataclass
@
dataclass
class
MemorySnapshot
:
class
MemorySnapshot
:
"""Memory snapshot."""
"""Memory snapshot."""
torch_peak_in_bytes
:
int
=
0
torch_peak
:
int
=
0
torch_memory_in_bytes
:
int
=
0
cuda_memory
:
int
=
0
torch_memory
:
int
=
0
non_torch_memory
:
int
=
0
timestamp
:
float
=
0.0
timestamp
:
float
=
0.0
auto_measure
:
bool
=
True
def
__post_init__
(
self
):
if
self
.
auto_measure
:
self
.
measure
()
def
measure
(
self
):
def
measure
(
self
):
self
.
torch_peak_in_bytes
=
torch
.
cuda
.
max_memory_reserved
()
# we measure the torch peak memory usage via allocated_bytes,
# rather than `torch.cuda.memory_reserved()` .
# After `torch.cuda.reset_peak_memory_stats()`,
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
# when we call `torch.cuda.empty_cache()` or OOM happens.
self
.
torch_peak
=
torch
.
cuda
.
memory_stats
().
get
(
"allocated_bytes.all.peak"
,
0
)
self
.
cuda_memory
=
torch
.
cuda
.
mem_get_info
(
)[
1
]
-
torch
.
cuda
.
mem_get_info
()[
0
]
# torch.cuda.memory_reserved() is how many bytes
# torch.cuda.memory_reserved() is how many bytes
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
self
.
torch_memory_in_bytes
=
torch
.
cuda
.
memory_reserved
()
# this is used to measure the non-torch memory usage
self
.
torch_memory
=
torch
.
cuda
.
memory_reserved
()
self
.
non_torch_memory
=
self
.
cuda_memory
-
self
.
torch_memory
self
.
timestamp
=
time
.
time
()
self
.
timestamp
=
time
.
time
()
def
__sub__
(
self
,
other
:
"MemorySnapshot"
)
->
"MemorySnapshot"
:
def
__sub__
(
self
,
other
:
"MemorySnapshot"
)
->
"MemorySnapshot"
:
"""support a - b"""
return
MemorySnapshot
(
return
MemorySnapshot
(
torch_peak_in_bytes
=
self
.
torch_peak_in_bytes
-
torch_peak
=
self
.
torch_peak
-
other
.
torch_peak
,
other
.
torch_peak_in_bytes
,
cuda_memory
=
self
.
cuda_memory
-
other
.
cuda_memory
,
torch_memory_in_bytes
=
self
.
torch_memory_in_bytes
-
torch_memory
=
self
.
torch_memory
-
other
.
torch_memory
,
other
.
torch_memory_in_bytes
,
non_torch_memory
=
self
.
non_torch_memory
-
other
.
non_torch_memory
,
timestamp
=
self
.
timestamp
-
other
.
timestamp
)
timestamp
=
self
.
timestamp
-
other
.
timestamp
,
auto_measure
=
False
,
)
@
dataclass
@
dataclass
class
MemoryProfilingResult
:
class
MemoryProfilingResult
:
"""Memory profiling result.
"""Memory profiling result.
All numbers are in bytes.
"""
# noqa
"""
baseline_memory_in_bytes
:
int
=
0
non_kv_cache_memory
:
int
=
0
non_kv_cache_memory_in_bytes
:
int
=
0
torch_peak_increase
:
int
=
0
torch_
peak_
increase
_in_bytes
:
int
=
0
non_
torch_increase
:
int
=
0
non_torch_increase_in_bytes
:
in
t
=
0
weights_memory
:
floa
t
=
0
weights_memory_in_bytes
:
float
=
0
before_create
:
MemorySnapshot
=
field
(
default_factory
=
MemorySnapshot
)
before_profile
:
MemorySnapshot
=
field
(
default_factory
=
MemorySnapshot
)
before_profile
:
MemorySnapshot
=
field
(
default_factory
=
MemorySnapshot
)
after_profile
:
MemorySnapshot
=
field
(
default_factory
=
MemorySnapshot
)
after_profile
:
MemorySnapshot
=
field
(
default_factory
=
MemorySnapshot
)
profile_time
:
float
=
0.0
profile_time
:
float
=
0.0
...
@@ -1960,18 +1981,14 @@ class MemoryProfilingResult:
...
@@ -1960,18 +1981,14 @@ class MemoryProfilingResult:
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
memory_profiling
(
def
memory_profiling
(
baseline_
memory_in_bytes
:
int
,
weights_memory_in_bytes
:
int
baseline_
snapshot
:
MemorySnapshot
,
)
->
Generator
[
MemoryProfilingResult
,
None
,
None
]:
weights_memory
:
int
)
->
Generator
[
MemoryProfilingResult
,
None
,
None
]:
"""Memory profiling context manager.
"""Memory profiling context manager.
baseline_memory_in_bytes: memory used by all the components other than
baseline_snapshot: the memory snapshot taken before the current vLLM instance is created.
the current vLLM instance. It contains: memory used by other processes, memory
weights_memory: memory used by PyTorch when loading the model weights.
used by another vLLM instance in the same process, etc. It is usually measured
before the current vLLM instance initializes the device. And we assume it is
constant during the profiling of the current vLLM instance.
weights_memory_in_bytes: memory used by PyTorch when loading the model weights.
Note that, before loading the model weights, we also initialize the device
Note that, before loading the model weights, we also initialize the device
and distributed environment, which may consume some memory. This part is not
and distributed environment, which may consume some memory. This part is not
included in the weights_memory
_in_bytes
because PyTorch does not control it.
included in the weights_memory because PyTorch does not control it.
The memory in one GPU can be classified into 3 categories:
The memory in one GPU can be classified into 3 categories:
1. memory used by anything other than the current vLLM instance.
1. memory used by anything other than the current vLLM instance.
...
@@ -2006,20 +2023,21 @@ def memory_profiling(
...
@@ -2006,20 +2023,21 @@ def memory_profiling(
b. 2 GiB reserved for the peak activation tensors (category 2)
b. 2 GiB reserved for the peak activation tensors (category 2)
c. 1 GiB used by non-torch components (category 3)
c. 1 GiB used by non-torch components (category 3)
The memory used for loading weights (a.) is directly given from the argument `weights_memory
_in_bytes
`.
The memory used for loading weights (a.) is directly given from the argument `weights_memory`.
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]`
after
profiling gives (b.).
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]`
during
profiling gives (b.).
(c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
The increase of `non_torch_memory` from creating the current vLLM instance until after profiling gives (c.).
subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`.
"""
# noqa
"""
# noqa
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
reset_peak_memory_stats
()
torch
.
cuda
.
reset_peak_memory_stats
()
result
=
MemoryProfilingResult
()
result
=
MemoryProfilingResult
()
result
.
b
aseline_memory_in_by
te
s
=
baseline_
memory_in_bytes
result
.
b
efore_crea
te
=
baseline_
snapshot
# the part of memory used for holding the model weights
# the part of memory used for holding the model weights
result
.
weights_memory
_in_bytes
=
weights_memory
_in_bytes
result
.
weights_memory
=
weights_memory
result
.
before_profile
.
measure
()
result
.
before_profile
.
measure
()
...
@@ -2030,13 +2048,12 @@ def memory_profiling(
...
@@ -2030,13 +2048,12 @@ def memory_profiling(
result
.
after_profile
.
measure
()
result
.
after_profile
.
measure
()
diff
=
result
.
after_profile
-
result
.
before_profile
diff_profile
=
result
.
after_profile
-
result
.
before_profile
result
.
torch_peak_increase_in_bytes
=
diff
.
torch_peak_in_bytes
diff_from_create
=
result
.
after_profile
-
result
.
before_create
current_cuda_memory_bytes
=
torch
.
cuda
.
mem_get_info
(
result
.
torch_peak_increase
=
diff_profile
.
torch_peak
)[
1
]
-
torch
.
cuda
.
mem_get_info
()[
0
]
result
.
non_torch_increase
=
diff_from_create
.
non_torch_memory
result
.
non_torch_increase_in_bytes
=
current_cuda_memory_bytes
-
baseline_memory_in_bytes
-
weights_memory_in_bytes
-
diff
.
torch_memory_in_bytes
# noqa
result
.
profile_time
=
diff_profile
.
timestamp
result
.
profile_time
=
diff
.
timestamp
result
.
non_kv_cache_memory
=
result
.
non_torch_increase
+
result
.
torch_peak_increase
+
result
.
weights_memory
# noqa
result
.
non_kv_cache_memory_in_bytes
=
result
.
non_torch_increase_in_bytes
+
result
.
torch_peak_increase_in_bytes
+
result
.
weights_memory_in_bytes
# noqa
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
...
...
vllm/worker/worker.py
View file @
da02cb4b
...
@@ -21,7 +21,8 @@ from vllm.platforms import current_platform
...
@@ -21,7 +21,8 @@ from vllm.platforms import current_platform
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
)
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
)
from
vllm.utils
import
GiB_bytes
,
bind_kv_cache
,
memory_profiling
from
vllm.utils
import
(
GiB_bytes
,
MemorySnapshot
,
bind_kv_cache
,
memory_profiling
)
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
from
vllm.worker.model_runner
import
GPUModelRunnerBase
,
ModelRunner
from
vllm.worker.model_runner
import
GPUModelRunnerBase
,
ModelRunner
...
@@ -137,7 +138,8 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -137,7 +138,8 @@ class Worker(LocalOrDistributedWorkerBase):
_check_if_gpu_supports_dtype
(
self
.
model_config
.
dtype
)
_check_if_gpu_supports_dtype
(
self
.
model_config
.
dtype
)
gc
.
collect
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
self
.
init_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
torch
.
cuda
.
reset_peak_memory_stats
()
self
.
baseline_snapshot
=
MemorySnapshot
()
else
:
else
:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Not support device type:
{
self
.
device_config
.
device
}
"
)
f
"Not support device type:
{
self
.
device_config
.
device
}
"
)
...
@@ -192,10 +194,9 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -192,10 +194,9 @@ class Worker(LocalOrDistributedWorkerBase):
# Execute a forward pass with dummy inputs to profile the memory usage
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
# of the model.
with
memory_profiling
(
baseline_memory_in_bytes
=
total_gpu_memory
-
with
memory_profiling
(
self
.
init_gpu_memory
,
self
.
baseline_snapshot
,
weights_memory_in_bytes
=
self
.
model_runner
.
weights_memory
=
self
.
model_runner
.
model_memory_usage
)
as
result
:
model_memory_usage
)
as
result
:
self
.
model_runner
.
profile_run
()
self
.
model_runner
.
profile_run
()
self
.
_assert_memory_footprint_increased_during_profiling
()
self
.
_assert_memory_footprint_increased_during_profiling
()
...
@@ -203,7 +204,7 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -203,7 +204,7 @@ class Worker(LocalOrDistributedWorkerBase):
memory_for_current_instance
=
total_gpu_memory
*
\
memory_for_current_instance
=
total_gpu_memory
*
\
self
.
cache_config
.
gpu_memory_utilization
self
.
cache_config
.
gpu_memory_utilization
available_kv_cache_memory
=
(
memory_for_current_instance
-
available_kv_cache_memory
=
(
memory_for_current_instance
-
result
.
non_kv_cache_memory
_in_bytes
)
result
.
non_kv_cache_memory
)
# Calculate the number of blocks that can be allocated with the
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
# profiled peak memory.
...
@@ -226,11 +227,11 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -226,11 +227,11 @@ class Worker(LocalOrDistributedWorkerBase):
f
"(
{
self
.
cache_config
.
gpu_memory_utilization
:.
2
f
}
)"
f
"(
{
self
.
cache_config
.
gpu_memory_utilization
:.
2
f
}
)"
f
" =
{
(
memory_for_current_instance
/
GiB_bytes
):.
2
f
}
GiB
\n
"
f
" =
{
(
memory_for_current_instance
/
GiB_bytes
):.
2
f
}
GiB
\n
"
"model weights take "
"model weights take "
f
"
{
(
result
.
weights_memory
_in_bytes
/
GiB_bytes
):.
2
f
}
GiB;"
f
"
{
(
result
.
weights_memory
/
GiB_bytes
):.
2
f
}
GiB;"
" non_torch_memory takes "
" non_torch_memory takes "
f
"
{
(
result
.
non_torch_increase
_in_bytes
/
GiB_bytes
):.
2
f
}
GiB;"
f
"
{
(
result
.
non_torch_increase
/
GiB_bytes
):.
2
f
}
GiB;"
" PyTorch activation peak memory takes "
" PyTorch activation peak memory takes "
f
"
{
(
result
.
torch_peak_increase
_in_bytes
/
GiB_bytes
):.
2
f
}
GiB;"
f
"
{
(
result
.
torch_peak_increase
/
GiB_bytes
):.
2
f
}
GiB;"
" the rest of the memory reserved for KV Cache is "
" the rest of the memory reserved for KV Cache is "
f
"
{
(
available_kv_cache_memory
/
GiB_bytes
):.
2
f
}
GiB."
)
f
"
{
(
available_kv_cache_memory
/
GiB_bytes
):.
2
f
}
GiB."
)
...
@@ -246,11 +247,13 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -246,11 +247,13 @@ class Worker(LocalOrDistributedWorkerBase):
def
_assert_memory_footprint_increased_during_profiling
(
self
):
def
_assert_memory_footprint_increased_during_profiling
(
self
):
# NOTE(woosuk): Here we assume that the other processes using the same
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
# GPU did not change their memory usage during the profiling.
free_gpu_memory
,
_
=
torch
.
cuda
.
mem_get_info
()
free_gpu_memory
,
total
=
torch
.
cuda
.
mem_get_info
()
assert
self
.
init_gpu_memory
-
free_gpu_memory
>
0
,
(
cuda_memory
=
total
-
free_gpu_memory
assert
self
.
baseline_snapshot
.
cuda_memory
<
cuda_memory
,
(
"Error in memory profiling. "
"Error in memory profiling. "
f
"Initial free memory
{
self
.
init_gpu_memory
}
, current free memory"
f
"Initial used memory
{
self
.
baseline_snapshot
.
cuda_memory
}
, "
f
"
{
free_gpu_memory
}
. This happens when the GPU memory was "
f
"currently used memory
{
cuda_memory
}
. "
f
"This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance."
)
"not properly cleaned up before initializing the vLLM instance."
)
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment