Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f9310cbd
Unverified
Commit
f9310cbd
authored
Nov 21, 2024
by
Woosuk Kwon
Committed by
GitHub
Nov 21, 2024
Browse files
[V1] Fix Compilation config & Enable CUDA graph by default (#10528)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
7560ae5c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
62 additions
and
42 deletions
+62
-42
vllm/config.py
vllm/config.py
+2
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+34
-28
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+26
-13
No files found.
vllm/config.py
View file @
f9310cbd
...
@@ -2370,7 +2370,7 @@ class VllmConfig:
...
@@ -2370,7 +2370,7 @@ class VllmConfig:
if
self
.
compilation_config
is
None
:
if
self
.
compilation_config
is
None
:
self
.
compilation_config
=
CompilationConfig
()
self
.
compilation_config
=
CompilationConfig
()
if
envs
.
VLLM_USE_V1
:
if
envs
.
VLLM_USE_V1
and
not
self
.
model_config
.
enforce_eager
:
# NOTE(woosuk): Currently, we use inductor because the piecewise
# NOTE(woosuk): Currently, we use inductor because the piecewise
# CUDA graphs do not work properly with the custom CUDA kernels.
# CUDA graphs do not work properly with the custom CUDA kernels.
# FIXME(woosuk): Disable inductor to reduce the compilation time
# FIXME(woosuk): Disable inductor to reduce the compilation time
...
@@ -2380,6 +2380,7 @@ class VllmConfig:
...
@@ -2380,6 +2380,7 @@ class VllmConfig:
self
.
compilation_config
.
use_inductor
=
True
self
.
compilation_config
.
use_inductor
=
True
self
.
compilation_config
.
pass_config
.
enable_fusion
=
False
self
.
compilation_config
.
pass_config
.
enable_fusion
=
False
self
.
compilation_config
.
pass_config
.
enable_reshape
=
False
self
.
compilation_config
.
pass_config
.
enable_reshape
=
False
self
.
compilation_config
.
level
=
CompilationLevel
.
PIECEWISE
current_platform
.
check_and_update_config
(
self
)
current_platform
.
check_and_update_config
(
self
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
f9310cbd
import
gc
import
time
import
time
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Set
,
Tuple
...
@@ -515,7 +516,25 @@ class GPUModelRunner:
...
@@ -515,7 +516,25 @@ class GPUModelRunner:
logger
.
info
(
"Loading model weights took %.4f GB"
,
logger
.
info
(
"Loading model weights took %.4f GB"
,
self
.
model_memory_usage
/
float
(
2
**
30
))
self
.
model_memory_usage
/
float
(
2
**
30
))
def
_dummy_run
(
self
,
model
:
nn
.
Module
,
num_tokens
:
int
)
->
None
:
@
torch
.
inference_mode
()
def
_dummy_run
(
self
,
model
:
nn
.
Module
,
num_tokens
:
int
,
kv_caches
:
List
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
with
set_forward_context
(
None
):
hidden_states
=
model
(
input_ids
=
None
,
positions
=
self
.
positions
[:
num_tokens
],
kv_caches
=
kv_caches
,
attn_metadata
=
None
,
inputs_embeds
=
self
.
inputs_embeds
[:
num_tokens
])
return
hidden_states
def
profile_run
(
self
)
->
None
:
# TODO(woosuk): Profile the max memory usage of the encoder and
# the encoder cache.
# use an empty tensor instead of `None` to force Dynamo to pass
# use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than by specializing on the value `None`.
# it by reference, rather than by specializing on the value `None`.
# the `dtype` argument does not matter, and we use `float32` as
# the `dtype` argument does not matter, and we use `float32` as
...
@@ -527,23 +546,17 @@ class GPUModelRunner:
...
@@ -527,23 +546,17 @@ class GPUModelRunner:
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
for
_
in
range
(
self
.
num_attn_layers
)
for
_
in
range
(
self
.
num_attn_layers
)
]
]
with
set_forward_context
(
None
):
# noqa: SIM117
with
set_compile_context
(
self
.
cudagraph_batch_sizes
):
with
set_compile_context
(
self
.
cudagraph_batch_sizes
):
# Trigger compilation for general shape.
# Trigger compilation for general shape.
hidden_states
=
self
.
_dummy_run
(
self
.
model
,
self
.
max_num_tokens
,
model
(
input_ids
=
None
,
dummy_kv_caches
)
positions
=
self
.
positions
,
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
kv_caches
=
dummy_kv_caches
,
logits
=
logits
[:
self
.
max_num_tokens
]
attn_metadata
=
None
,
# TODO(woosuk): Consider the memory usage of the sampler.
inputs_embeds
=
self
.
inputs_embeds
)
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
# TODO(woosuk): Profile the max memory usage of the encoder and
# the encoder cache.
self
.
_dummy_run
(
self
.
model
,
self
.
max_num_tokens
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
del
hidden_states
,
logits
gc
.
collect
()
@
torch
.
inference_mode
()
def
capture_model
(
self
)
->
None
:
def
capture_model
(
self
)
->
None
:
if
not
self
.
use_cuda_graph
:
if
not
self
.
use_cuda_graph
:
logger
.
warning
(
logger
.
warning
(
...
@@ -554,18 +567,11 @@ class GPUModelRunner:
...
@@ -554,18 +567,11 @@ class GPUModelRunner:
start_time
=
time
.
perf_counter
()
start_time
=
time
.
perf_counter
()
start_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
start_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
with
set_forward_context
(
None
):
# Trigger CUDA graph capture for specific shapes.
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
# can reuse the memory pool allocated for the large shapes.
for
num_tokens
in
reversed
(
self
.
cudagraph_batch_sizes
):
for
num_tokens
in
reversed
(
self
.
cudagraph_batch_sizes
):
self
.
_dummy_run
(
self
.
model
,
num_tokens
,
self
.
kv_caches
)
self
.
model
(
input_ids
=
None
,
positions
=
self
.
positions
[:
num_tokens
],
kv_caches
=
self
.
kv_caches
,
attn_metadata
=
None
,
inputs_embeds
=
self
.
inputs_embeds
[:
num_tokens
],
)
end_time
=
time
.
perf_counter
()
end_time
=
time
.
perf_counter
()
end_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
end_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
...
...
vllm/v1/worker/gpu_worker.py
View file @
f9310cbd
...
@@ -105,35 +105,48 @@ class Worker:
...
@@ -105,35 +105,48 @@ class Worker:
# Profile the memory usage of the model and get the maximum number of
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
# cache blocks that can be allocated with the remaining free memory.
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
reset_peak_memory_stats
()
_
,
total_gpu_memory
=
torch
.
cuda
.
mem_get_info
()
# Execute a forward pass with dummy inputs to profile the memory usage
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
# of the model.
self
.
model_runner
.
profile_run
()
self
.
model_runner
.
profile_run
()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
free_gpu_memory
,
total_gpu_memory
=
torch
.
cuda
.
mem_get_info
()
free_gpu_memory
,
_
=
torch
.
cuda
.
mem_get_info
()
# NOTE(woosuk): Here we assume that the other processes using the same
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
# GPU did not change their memory usage during the profiling.
peak_memory
=
self
.
init_gpu_memory
-
free_gpu_memory
assert
self
.
init_gpu_memory
>
free_gpu_memory
,
(
assert
peak_memory
>
0
,
(
"Error in memory profiling. "
"Error in memory profiling. "
f
"Initial free memory
{
self
.
init_gpu_memory
}
, current free memory"
f
"Initial free memory
{
self
.
init_gpu_memory
}
, current free memory"
f
"
{
free_gpu_memory
}
. This happens when the GPU memory was "
f
"
{
free_gpu_memory
}
. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance."
)
"not properly cleaned up before initializing the vLLM instance."
)
# Get the peak memory allocation recorded by torch
peak_memory
=
torch
.
cuda
.
memory_stats
()[
"allocated_bytes.all.peak"
]
# Check for any memory left around that may have been allocated on the
# GPU outside of `torch`. NCCL operations, for example, can use a few
# GB during a forward pass.
torch
.
cuda
.
empty_cache
()
torch_allocated_bytes
=
torch
.
cuda
.
memory_stats
(
)[
"allocated_bytes.all.current"
]
total_allocated_bytes
=
torch
.
cuda
.
mem_get_info
(
)[
1
]
-
torch
.
cuda
.
mem_get_info
()[
0
]
non_torch_allocations
=
total_allocated_bytes
-
torch_allocated_bytes
if
non_torch_allocations
>
0
:
peak_memory
+=
non_torch_allocations
available_kv_cache_memory
=
(
total_gpu_memory
*
self
.
cache_config
.
gpu_memory_utilization
-
peak_memory
)
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
cache_block_size
=
_get_cache_block_size
(
self
.
cache_config
,
cache_block_size
=
_get_cache_block_size
(
self
.
cache_config
,
self
.
model_config
,
self
.
model_config
,
self
.
parallel_config
)
self
.
parallel_config
)
num_gpu_blocks
=
int
(
num_gpu_blocks
=
int
(
available_kv_cache_memory
//
cache_block_size
)
(
total_gpu_memory
*
self
.
cache_config
.
gpu_memory_utilization
-
peak_memory
)
//
cache_block_size
)
num_gpu_blocks
=
max
(
num_gpu_blocks
,
0
)
num_gpu_blocks
=
max
(
num_gpu_blocks
,
0
)
# if self.model_runner.lora_manager:
# self.model_runner.remove_all_loras()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
return
num_gpu_blocks
,
0
return
num_gpu_blocks
,
0
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
)
->
None
:
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
)
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment