Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
40899855
Unverified
Commit
40899855
authored
Nov 05, 2024
by
Woosuk Kwon
Committed by
GitHub
Nov 05, 2024
Browse files
[V1] Integrate Piecewise CUDA graphs (#10058)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
9d59b755
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
133 additions
and
36 deletions
+133
-36
vllm/compilation/backends.py
vllm/compilation/backends.py
+5
-2
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+21
-14
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+107
-20
No files found.
vllm/compilation/backends.py
View file @
40899855
...
...
@@ -496,7 +496,10 @@ class PiecewiseBackend:
return
entry
.
runnable
(
*
args
)
if
self
.
is_first_graph
:
logger
.
info
(
"Capturing a cudagraph for shape %s"
,
# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every shape.
# We only log it in the debug mode.
logger
.
debug
(
"Capturing a cudagraph for shape %s"
,
runtime_shape
)
input_addresses
=
[
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
40899855
...
...
@@ -51,6 +51,7 @@ class FlashAttentionMetadata:
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
num_actual_tokens
:
int
# Number of tokens excluding padding.
max_query_len
:
int
query_start_loc
:
torch
.
Tensor
max_seq_len
:
int
...
...
@@ -134,7 +135,9 @@ class FlashAttentionImpl(AttentionImpl):
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
"key/v_scale is not supported in FlashAttention."
)
output
=
torch
.
ops
.
vllm
.
unified_flash_attention
(
output
=
torch
.
empty_like
(
query
)
torch
.
ops
.
vllm
.
unified_flash_attention
(
output
,
query
,
key
,
value
,
...
...
@@ -154,6 +157,7 @@ class FlashAttentionImpl(AttentionImpl):
def
unified_flash_attention
(
output
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
...
...
@@ -168,17 +172,17 @@ def unified_flash_attention(
window_size
:
Optional
[
List
[
int
]]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
)
->
None
:
current_metadata
=
get_forward_context
()
if
current_metadata
is
None
:
# Profiling run.
return
torch
.
empty_like
(
query
)
return
assert
current_metadata
is
not
None
assert
isinstance
(
current_metadata
,
FlashAttentionMetadata
)
attn_metadata
:
FlashAttentionMetadata
=
current_metadata
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
num_heads
,
head_size
)
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
...
...
@@ -188,18 +192,18 @@ def unified_flash_attention(
key_cache
=
kv_cache
[
0
]
value_cache
=
kv_cache
[
1
]
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
k
v
_cache
[
0
]
,
k
v_cache
[
1
]
,
key
[:
num_actual_tokens
]
,
value
[:
num_actual_tokens
]
,
k
ey
_cache
,
v
alue
_cache
,
attn_metadata
.
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
,
)
output
=
flash_attn_varlen_func
(
q
=
query
,
attn_
output
=
flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
]
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
attn_metadata
.
query_start_loc
,
...
...
@@ -213,10 +217,13 @@ def unified_flash_attention(
block_table
=
attn_metadata
.
block_table
,
softcap
=
logits_soft_cap
,
)
return
output
.
view
(
num_tokens
,
hidden_size
)
attn_output
=
attn_output
.
view
(
num_actual_tokens
,
-
1
)
# TODO(woosuk): Optimize this.
output
[:
num_actual_tokens
].
copy_
(
attn_output
)
def
unified_flash_attention_fake
(
output
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
...
...
@@ -231,13 +238,13 @@ def unified_flash_attention_fake(
window_size
:
Optional
[
List
[
int
]]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
query
)
)
->
None
:
return
direct_register_custom_op
(
op_name
=
"unified_flash_attention"
,
op_func
=
unified_flash_attention
,
mutates_args
=
[
"kv_cache"
],
mutates_args
=
[
"kv_cache"
,
"output"
],
fake_impl
=
unified_flash_attention_fake
,
)
vllm/v1/worker/gpu_model_runner.py
View file @
40899855
import
os
import
time
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Set
from
unittest.mock
import
patch
...
...
@@ -7,11 +9,16 @@ import torch
import
torch.distributed
import
torch.nn
as
nn
from
vllm
import
envs
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.config
import
CompilationConfig
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.plugins
import
set_compilation_config
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
cdiv
,
is_pin_memory_available
)
...
...
@@ -86,6 +93,18 @@ class GPUModelRunner:
pin_memory
=
self
.
pin_memory
,
)
self
.
use_cuda_graph
=
(
envs
.
VLLM_TORCH_COMPILE_LEVEL
==
CompilationLevel
.
PIECEWISE
and
not
self
.
model_config
.
enforce_eager
)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
self
.
cudagraph_batch_sizes
=
[
1
,
2
,
4
]
+
[
i
for
i
in
range
(
8
,
513
,
8
)]
self
.
input_ids
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
positions
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
def
_update_states
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
# Remove stopped requests from the cached states.
# Keep the states of the pre-empted requests.
...
...
@@ -268,12 +287,16 @@ class GPUModelRunner:
seq_start_loc_np
[
0
]
=
0
np
.
cumsum
(
seq_lens
,
out
=
seq_start_loc_np
[
1
:])
input_ids
=
input_ids
.
to
(
self
.
device
,
non_blocking
=
True
)
positions
=
positions
.
to
(
self
.
device
,
non_blocking
=
True
).
long
()
self
.
input_ids
[:
total_num_scheduled_tokens
].
copy_
(
input_ids
,
non_blocking
=
True
)
self
.
positions
[:
total_num_scheduled_tokens
].
copy_
(
positions
,
non_blocking
=
True
)
query_start_loc
=
query_start_loc
.
to
(
self
.
device
,
non_blocking
=
True
)
seq_start_loc
=
seq_start_loc
.
to
(
self
.
device
,
non_blocking
=
True
)
slot_mapping
=
slot_mapping
.
to
(
self
.
device
,
non_blocking
=
True
).
long
()
attn_metadata
=
FlashAttentionMetadata
(
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
query_start_loc
=
query_start_loc
,
max_seq_len
=
max_seq_len
,
...
...
@@ -287,7 +310,7 @@ class GPUModelRunner:
# token from the partial request.
# TODO: Support prompt logprobs.
logits_indices
=
query_start_loc
[
1
:]
-
1
return
input_ids
,
positions
,
attn_metadata
,
logits_indices
return
attn_metadata
,
logits_indices
def
_prepare_sampling
(
self
,
...
...
@@ -310,16 +333,26 @@ class GPUModelRunner:
scheduler_output
:
"SchedulerOutput"
,
)
->
ModelRunnerOutput
:
self
.
_update_states
(
scheduler_output
)
inputs
=
self
.
_prepare_inputs
(
scheduler_output
)
input_ids
,
positions
,
attn_metadata
,
logits_indices
=
inputs
attn_metadata
,
logits_indices
=
self
.
_prepare_inputs
(
scheduler_output
)
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
(
self
.
use_cuda_graph
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens
=
self
.
_get_padded_batch_size
(
num_scheduled_tokens
)
else
:
# Eager mode.
num_input_tokens
=
num_scheduled_tokens
with
set_forward_context
(
attn_metadata
):
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
,
positions
=
self
.
positions
[:
num_input_tokens
]
,
kv_caches
=
self
.
kv_caches
,
attn_metadata
=
attn_metadata
,
attn_metadata
=
None
,
)
hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
hidden_states
=
hidden_states
[
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
...
...
@@ -371,6 +404,18 @@ class GPUModelRunner:
return
model_runner_output
def
load_model
(
self
)
->
None
:
if
self
.
use_cuda_graph
:
# FIXME(woosuk): Currently, the custom ops are not supported
# in the piecewise compilation mode. We rely on TorchInductor
# to optimize the model.
os
.
environ
[
"VLLM_CUSTOM_OPS"
]
=
"none"
set_compilation_config
(
CompilationConfig
(
use_cudagraph
=
True
,
non_cudagraph_ops
=
[
"vllm.unified_flash_attention"
],
use_inductor
=
True
,
))
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
with
DeviceMemoryProfiler
()
as
m
:
# noqa: SIM117
with
patch
(
"vllm.model_executor.layers.sampler.Sampler"
,
Sampler
):
...
...
@@ -381,27 +426,62 @@ class GPUModelRunner:
self
.
model_memory_usage
/
float
(
2
**
30
))
def
_dummy_run
(
self
,
model
:
nn
.
Module
,
num_tokens
:
int
)
->
None
:
input_ids
=
torch
.
zeros
(
num_tokens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
positions
=
torch
.
zeros
(
num_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
kv_caches
=
[
None
for
_
in
range
(
self
.
num_attn_layers
)]
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
=
None
)
return
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value `None`.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
# it is important to create tensors inside the loop, rather than
# multiplying the list, to avoid Dynamo from treating them as
# tensor aliasing.
dummy_kv_caches
=
[
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
for
_
in
range
(
self
.
num_attn_layers
)
]
with
set_forward_context
(
None
):
# noqa: SIM117
with
set_compile_context
(
self
.
cudagraph_batch_sizes
):
# Trigger compilation for general shape.
model
(
self
.
input_ids
,
self
.
positions
,
dummy_kv_caches
,
attn_metadata
=
None
)
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
self
.
_dummy_run
(
self
.
model
,
self
.
max_num_tokens
)
torch
.
cuda
.
synchronize
()
return
@
torch
.
inference_mode
()
def
capture_model
(
self
)
->
None
:
# TODO: Implement CUDA graph support.
if
not
self
.
use_cuda_graph
:
logger
.
warning
(
"Skipping CUDA graph capture. Please set "
"VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs."
,
CompilationLevel
.
PIECEWISE
)
return
start_time
=
time
.
perf_counter
()
start_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
with
set_forward_context
(
None
):
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for
num_tokens
in
reversed
(
self
.
cudagraph_batch_sizes
):
self
.
model
(
self
.
input_ids
[:
num_tokens
],
self
.
positions
[:
num_tokens
],
kv_caches
=
self
.
kv_caches
,
attn_metadata
=
None
,
)
end_time
=
time
.
perf_counter
()
end_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
elapsed_time
=
end_time
-
start_time
cuda_graph_size
=
start_free_gpu_memory
-
end_free_gpu_memory
# This usually takes 5~20 seconds.
logger
.
info
(
"Graph capturing finished in %.0f secs, took %.2f GiB"
,
elapsed_time
,
cuda_graph_size
/
(
1
<<
30
))
def
initialize_kv_cache
(
self
,
num_blocks
:
int
)
->
None
:
assert
len
(
self
.
kv_caches
)
==
0
kv_cache_shape
=
FlashAttentionBackend
.
get_kv_cache_shape
(
...
...
@@ -412,6 +492,13 @@ class GPUModelRunner:
dtype
=
self
.
kv_cache_dtype
,
device
=
self
.
device
))
def
_get_padded_batch_size
(
self
,
batch_size
:
int
)
->
Optional
[
int
]:
# TODO: Optimize this?
for
size
in
self
.
cudagraph_batch_sizes
:
if
batch_size
<=
size
:
return
size
return
None
@
dataclass
class
CachedRequestState
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment