Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
551603fe
Unverified
Commit
551603fe
authored
Dec 16, 2024
by
youkaichao
Committed by
GitHub
Dec 16, 2024
Browse files
[core] overhaul memory profiling and fix backward compatibility (#10511)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
efbce85f
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
236 additions
and
60 deletions
+236
-60
tests/entrypoints/llm/test_gpu_utilization.py
tests/entrypoints/llm/test_gpu_utilization.py
+25
-0
tests/entrypoints/llm/test_lazy_outlines.py
tests/entrypoints/llm/test_lazy_outlines.py
+1
-1
tests/test_utils.py
tests/test_utils.py
+42
-2
tests/worker/test_profile.py
tests/worker/test_profile.py
+9
-9
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+6
-5
vllm/utils.py
vllm/utils.py
+123
-2
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+2
-1
vllm/worker/worker.py
vllm/worker/worker.py
+28
-40
No files found.
tests/entrypoints/llm/test_gpu_utilization.py
0 → 100644
View file @
551603fe
from
vllm
import
LLM
,
SamplingParams
def
test_gpu_memory_utilization
():
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
)
# makes sure gpu_memory_utilization is per-instance limit,
# not a global limit
llms
=
[
LLM
(
model
=
"facebook/opt-125m"
,
gpu_memory_utilization
=
0.3
,
enforce_eager
=
True
)
for
i
in
range
(
3
)
]
for
llm
in
llms
:
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
for
output
in
outputs
:
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
tests/entrypoints/llm/test_lazy_outlines.py
View file @
551603fe
...
...
@@ -36,7 +36,7 @@ def run_lmfe(sample_regex):
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
enforce_eager
=
True
,
guided_decoding_backend
=
"lm-format-enforcer"
,
gpu_memory_utilization
=
0.
6
)
gpu_memory_utilization
=
0.
3
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
)
outputs
=
llm
.
generate
(
prompts
=
[
...
...
tests/test_utils.py
View file @
551603fe
...
...
@@ -5,11 +5,13 @@ from functools import partial
from
typing
import
AsyncIterator
,
Tuple
import
pytest
import
torch
from
vllm.utils
import
(
FlexibleArgumentParser
,
StoreBoolean
,
deprecate_kwargs
,
get_open_port
,
merge_async_iterators
,
supports_kw
)
get_open_port
,
memory_profiling
,
merge_async_iterators
,
supports_kw
)
from
.utils
import
error_on_warning
from
.utils
import
error_on_warning
,
fork_new_process_for_each_test
@
pytest
.
mark
.
asyncio
...
...
@@ -270,3 +272,41 @@ def test_supports_kw(callable,kw_name,requires_kw_only,
requires_kw_only
=
requires_kw_only
,
allow_var_kwargs
=
allow_var_kwargs
)
==
is_supported
@
fork_new_process_for_each_test
def
test_memory_profiling
():
# Fake out some model loading + inference memory usage to test profiling
# Memory used by other processes will show up as cuda usage outside of torch
from
vllm.distributed.device_communicators.cuda_wrapper
import
(
CudaRTLibrary
)
lib
=
CudaRTLibrary
()
# 512 MiB allocation outside of this instance
handle1
=
lib
.
cudaMalloc
(
512
*
1024
*
1024
)
baseline_memory_in_bytes
=
\
torch
.
cuda
.
mem_get_info
()[
1
]
-
torch
.
cuda
.
mem_get_info
()[
0
]
# load weights
weights
=
torch
.
randn
(
128
,
1024
,
1024
,
device
=
'cuda'
,
dtype
=
torch
.
float32
)
weights_memory_in_bytes
=
128
*
1024
*
1024
*
4
# 512 MiB
with
memory_profiling
(
baseline_memory_in_bytes
=
baseline_memory_in_bytes
,
weights_memory_in_bytes
=
weights_memory_in_bytes
)
as
result
:
# make a memory spike, 1 GiB
spike
=
torch
.
randn
(
256
,
1024
,
1024
,
device
=
'cuda'
,
dtype
=
torch
.
float32
)
del
spike
# Add some extra non-torch memory 256 MiB (simulate NCCL)
handle2
=
lib
.
cudaMalloc
(
256
*
1024
*
1024
)
# Check that the memory usage is within 5% of the expected values
non_torch_ratio
=
result
.
non_torch_increase_in_bytes
/
(
256
*
1024
*
1024
)
# noqa
torch_peak_ratio
=
result
.
torch_peak_increase_in_bytes
/
(
1024
*
1024
*
1024
)
# noqa
assert
abs
(
non_torch_ratio
-
1
)
<=
0.05
assert
abs
(
torch_peak_ratio
-
1
)
<=
0.05
del
weights
lib
.
cudaFree
(
handle1
)
lib
.
cudaFree
(
handle2
)
tests/worker/test_profile.py
View file @
551603fe
...
...
@@ -31,10 +31,6 @@ def test_gpu_memory_profiling():
is_driver_worker
=
True
,
)
# Load the model so we can profile it
worker
.
init_device
()
worker
.
load_model
()
# Set 10GiB as the total gpu ram to be device-agnostic
def
mock_mem_info
():
current_usage
=
torch
.
cuda
.
memory_stats
(
...
...
@@ -46,20 +42,24 @@ def test_gpu_memory_profiling():
from
unittest.mock
import
patch
with
patch
(
"torch.cuda.mem_get_info"
,
side_effect
=
mock_mem_info
):
# Load the model so we can profile it
worker
.
init_device
()
worker
.
load_model
()
gpu_blocks
,
_
=
worker
.
determine_num_available_blocks
()
# Peak vram usage by torch should be 0.7077 GiB
# Peak vram usage by torch should be 0.47 GiB
# Model weights take 0.25 GiB
# No memory should be allocated outside of torch
# 9.0 GiB should be the utilization target
# 8.2
923
GiB should be available for the KV cache
# 8.2
8
GiB should be available for the KV cache
block_size
=
CacheEngine
.
get_cache_block_size
(
engine_config
.
cache_config
,
engine_config
.
model_config
,
engine_config
.
parallel_config
)
expected_blocks
=
(
8.2
923
*
1024
**
3
)
//
block_size
expected_blocks
=
(
8.2
8
*
1024
**
3
)
//
block_size
# Check within a small tolerance for portability
# Hardware, kernel, or dependency changes could all affect memory
# utilization.
# A 10 block tolerance here should be about 6MB of wiggle room.
assert
abs
(
gpu_blocks
-
expected_blocks
)
<
10
# A 10
0
block tolerance here should be about 6
0
MB of wiggle room.
assert
abs
(
gpu_blocks
-
expected_blocks
)
<
10
0
vllm/engine/arg_utils.py
View file @
551603fe
...
...
@@ -487,11 +487,12 @@ class EngineArgs:
help
=
'The fraction of GPU memory to be used for the model '
'executor, which can range from 0 to 1. For example, a value of '
'0.5 would imply 50%% GPU memory utilization. If unspecified, '
'will use the default value of 0.9. This is a global gpu memory '
'utilization limit, for example if 50%% of the gpu memory is '
'already used before vLLM starts and --gpu-memory-utilization is '
'set to 0.9, then only 40%% of the gpu memory will be allocated '
'to the model executor.'
)
'will use the default value of 0.9. This is a per-instance '
'limit, and only applies to the current vLLM instance.'
'It does not matter if you have another vLLM instance running '
'on the same GPU. For example, if you have two vLLM instances '
'running on the same GPU, you can set the GPU memory utilization '
'to 0.5 for each instance.'
)
parser
.
add_argument
(
'--num-gpu-blocks-override'
,
type
=
int
,
...
...
vllm/utils.py
View file @
551603fe
...
...
@@ -23,10 +23,12 @@ import weakref
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Future
,
Task
from
collections
import
UserDict
,
defaultdict
from
collections.abc
import
Iterable
,
Mapping
from
dataclasses
import
dataclass
,
field
from
functools
import
lru_cache
,
partial
,
wraps
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
Dict
,
Generic
,
Hashable
,
List
,
Literal
,
Optional
,
OrderedDict
,
Set
,
Tuple
,
Type
,
TypeVar
,
Union
,
overload
)
Dict
,
Generator
,
Generic
,
Hashable
,
List
,
Literal
,
Optional
,
OrderedDict
,
Set
,
Tuple
,
Type
,
TypeVar
,
Union
,
overload
)
from
uuid
import
uuid4
import
numpy
as
np
...
...
@@ -1664,3 +1666,122 @@ def kill_process_tree(pid: int):
# Finally kill the parent
with
contextlib
.
suppress
(
ProcessLookupError
):
os
.
kill
(
pid
,
signal
.
SIGKILL
)
@
dataclass
class
MemorySnapshot
:
"""Memory snapshot."""
torch_peak_in_bytes
:
int
=
0
torch_memory_in_bytes
:
int
=
0
timestamp
:
float
=
0.0
def
measure
(
self
):
self
.
torch_peak_in_bytes
=
torch
.
cuda
.
memory_stats
(
)[
"allocated_bytes.all.peak"
]
self
.
torch_memory_in_bytes
=
torch
.
cuda
.
memory_stats
(
)[
"allocated_bytes.all.current"
]
self
.
timestamp
=
time
.
time
()
def
__sub__
(
self
,
other
:
"MemorySnapshot"
)
->
"MemorySnapshot"
:
"""support a - b"""
return
MemorySnapshot
(
torch_peak_in_bytes
=
self
.
torch_peak_in_bytes
-
other
.
torch_peak_in_bytes
,
torch_memory_in_bytes
=
self
.
torch_memory_in_bytes
-
other
.
torch_memory_in_bytes
,
timestamp
=
self
.
timestamp
-
other
.
timestamp
)
@
dataclass
class
MemoryProfilingResult
:
"""Memory profiling result.
"""
# noqa
baseline_memory_in_bytes
:
int
=
0
non_kv_cache_memory_in_bytes
:
int
=
0
torch_peak_increase_in_bytes
:
int
=
0
non_torch_increase_in_bytes
:
int
=
0
weights_memory_in_bytes
:
float
=
0
before_profile
:
MemorySnapshot
=
field
(
default_factory
=
MemorySnapshot
)
after_profile
:
MemorySnapshot
=
field
(
default_factory
=
MemorySnapshot
)
profile_time
:
float
=
0.0
@
contextlib
.
contextmanager
def
memory_profiling
(
baseline_memory_in_bytes
:
int
,
weights_memory_in_bytes
:
int
)
->
Generator
[
MemoryProfilingResult
,
None
,
None
]:
"""Memory profiling context manager.
baseline_memory_in_bytes: memory used by all the components other than
the current vLLM instance. It contains: memory used by other processes, memory
used by another vLLM instance in the same process, etc. It is usually measured
before the current vLLM instance initialize the device. And we assume it is
constant during the profiling of the current vLLM instance.
weights_memory_in_bytes: memory used by PyTorch when loading the model weights.
Note that, before loading the model weights, we also initialize the device
and distributed environment, which may consume some memory. This part is not
included in the weights_memory_in_bytes because PyTorch does not control it.
The memory in one GPU can be classified into 3 categories:
1. memory used by anything other than the current vLLM instance.
2. memory used by torch in the current vLLM instance.
3. memory used in the current vLLM instance, but not by torch.
A quantitive example:
Before creating the current vLLM instance:
category 1: 1 GiB
category 2: 0 GiB
category 3: 0 GiB
After creating the current vLLM instance and loading the model,
(i.e. before profiling):
category 1: 1 GiB
category 2: 2 GiB (model weights take 2 GiB)
category 3: 0.5 GiB (memory used by NCCL)
During profiling (peak):
category 1: 1 GiB
category 2: 4 GiB (peak activation tensors take 2 GiB)
category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
After profiling:
category 1: 1 GiB
category 2: 3 GiB (after garbage-collecting activation tensors)
category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
In this case, non-kv cache takes 5 GiB in total, including:
a. 2 GiB used by the model weights (category 2)
b. 2 GiB reserved for the peak activation tensors (category 2)
c. 1 GiB used by non-torch components (category 3)
The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.
The increase of ``torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
(c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.
"""
# noqa
torch
.
cuda
.
reset_peak_memory_stats
()
result
=
MemoryProfilingResult
()
result
.
baseline_memory_in_bytes
=
baseline_memory_in_bytes
# the part of memory used for holding the model weights
result
.
weights_memory_in_bytes
=
weights_memory_in_bytes
result
.
before_profile
.
measure
()
yield
result
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
result
.
after_profile
.
measure
()
diff
=
result
.
after_profile
-
result
.
before_profile
result
.
torch_peak_increase_in_bytes
=
diff
.
torch_peak_in_bytes
current_cuda_memory_bytes
=
torch
.
cuda
.
mem_get_info
(
)[
1
]
-
torch
.
cuda
.
mem_get_info
()[
0
]
result
.
non_torch_increase_in_bytes
=
current_cuda_memory_bytes
-
baseline_memory_in_bytes
-
weights_memory_in_bytes
-
diff
.
torch_memory_in_bytes
# noqa
result
.
profile_time
=
diff
.
timestamp
result
.
non_kv_cache_memory_in_bytes
=
result
.
non_torch_increase_in_bytes
+
result
.
torch_peak_increase_in_bytes
+
result
.
weights_memory_in_bytes
# noqa
vllm/worker/multi_step_model_runner.py
View file @
551603fe
...
...
@@ -645,7 +645,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
return
model_input
def
load_model
(
self
)
->
None
:
return
self
.
_base_model_runner
.
load_model
()
self
.
_base_model_runner
.
load_model
()
self
.
model_memory_usage
=
self
.
_base_model_runner
.
model_memory_usage
def
save_sharded_state
(
self
,
...
...
vllm/worker/worker.py
View file @
551603fe
"""A GPU worker class."""
import
gc
import
os
import
time
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
import
torch
...
...
@@ -22,6 +21,7 @@ from vllm.platforms import current_platform
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
)
from
vllm.utils
import
GiB_bytes
,
memory_profiling
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
from
vllm.worker.model_runner
import
GPUModelRunnerBase
,
ModelRunner
...
...
@@ -192,33 +192,22 @@ class Worker(LocalOrDistributedWorkerBase):
torch
.
cuda
.
reset_peak_memory_stats
()
free_memory_pre_profile
,
total_gpu_memory
=
torch
.
cuda
.
mem_get_info
()
start_time
=
time
.
time
()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
with
memory_profiling
(
baseline_memory_in_bytes
=
total_gpu_memory
-
self
.
init_gpu_memory
,
weights_memory_in_bytes
=
self
.
model_runner
.
model_memory_usage
)
as
result
:
self
.
model_runner
.
profile_run
()
torch
.
cuda
.
synchronize
()
self
.
_assert_memory_footprint_increased_during_profiling
()
# Get the peak memory allocation recorded by torch
peak_memory
=
torch
.
cuda
.
memory_stats
()[
"allocated_bytes.all.peak"
]
# Check for any memory left around that may have been allocated on the
# gpu outside of `torch`. NCCL operations, for example, can use a few
# GB during a forward pass
torch
.
cuda
.
empty_cache
()
torch_allocated_bytes
=
torch
.
cuda
.
memory_stats
(
)[
"allocated_bytes.all.current"
]
total_allocated_bytes
=
torch
.
cuda
.
mem_get_info
(
)[
1
]
-
torch
.
cuda
.
mem_get_info
()[
0
]
non_torch_allocations
=
total_allocated_bytes
-
torch_allocated_bytes
if
non_torch_allocations
>
0
:
peak_memory
+=
non_torch_allocations
available_kv_cache_memory
=
(
total_gpu_memory
*
self
.
cache_config
.
gpu_memory_utilization
-
peak_memory
)
memory_for_current_instance
=
total_gpu_memory
*
\
self
.
cache_config
.
gpu_memory_utilization
available_kv_cache_memory
=
(
memory_for_current_instance
-
result
.
non_kv_cache_memory_in_bytes
)
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
...
...
@@ -233,24 +222,23 @@ class Worker(LocalOrDistributedWorkerBase):
num_gpu_blocks
=
max
(
num_gpu_blocks
,
0
)
num_cpu_blocks
=
max
(
num_cpu_blocks
,
0
)
end_time
=
time
.
time
()
logger
.
info
(
"Memory profiling results: "
"duration=%.2f seconds, "
"total_gpu_memory=%.2fGiB, "
"initial_memory_usage=%.2fGiB, "
"peak_torch_memory=%.2fGiB, "
"memory_usage_post_profile=%.2fGiB, "
"non_torch_memory=%.2fGiB, "
"kv_cache_size=%.2fGiB, "
"gpu_memory_utilization=%.2f."
,
end_time
-
start_time
,
total_gpu_memory
/
(
1024
**
3
),
(
total_gpu_memory
-
free_memory_pre_profile
)
/
(
1024
**
3
),
(
peak_memory
-
non_torch_allocations
)
/
(
1024
**
3
),
total_allocated_bytes
/
(
1024
**
3
),
non_torch_allocations
/
(
1024
**
3
),
available_kv_cache_memory
/
(
1024
**
3
),
self
.
cache_config
.
gpu_memory_utilization
)
msg
=
(
f
"Memory profiling takes
{
result
.
profile_time
:.
2
f
}
seconds
\n
"
"the current vLLM instance can use "
"total_gpu_memory "
f
"(
{
(
total_gpu_memory
/
GiB_bytes
):.
2
f
}
GiB)"
" x gpu_memory_utilization "
f
"(
{
self
.
cache_config
.
gpu_memory_utilization
:.
2
f
}
)"
f
" =
{
(
memory_for_current_instance
/
GiB_bytes
):.
2
f
}
GiB
\n
"
"model weights take "
f
"
{
(
result
.
weights_memory_in_bytes
/
GiB_bytes
):.
2
f
}
GiB;"
" non_torch_memory takes "
f
"
{
(
result
.
non_torch_increase_in_bytes
/
GiB_bytes
):.
2
f
}
GiB;"
" PyTorch activation peak memory takes "
f
"
{
(
result
.
torch_peak_increase_in_bytes
/
GiB_bytes
):.
2
f
}
GiB;"
" the rest of the memory reserved for KV Cache is "
f
"
{
(
available_kv_cache_memory
/
GiB_bytes
):.
2
f
}
GiB."
)
logger
.
info
(
msg
)
# Final cleanup
if
self
.
model_runner
.
lora_manager
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment