Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
11d3976b
Unverified
Commit
11d3976b
authored
Feb 19, 2026
by
zhrrr
Committed by
GitHub
Feb 18, 2026
Browse files
[Model Runner V2] support piecewise & mixed cudagraph (#32771)
Signed-off-by:
zhuhaoran
<
zhuhaoran.zhr@alibaba-inc.com
>
parent
40da9625
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
343 additions
and
108 deletions
+343
-108
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+179
-58
vllm/v1/worker/gpu/dp_utils.py
vllm/v1/worker/gpu/dp_utils.py
+36
-19
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+26
-14
vllm/v1/worker/gpu/spec_decode/eagle.py
vllm/v1/worker/gpu/spec_decode/eagle.py
+27
-9
vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py
vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py
+75
-8
No files found.
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
11d3976b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
,
Iterable
from
collections.abc
import
Callable
from
typing
import
Any
from
typing
import
Any
import
numpy
as
np
import
numpy
as
np
...
@@ -11,7 +11,8 @@ from tqdm import tqdm
...
@@ -11,7 +11,8 @@ from tqdm import tqdm
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.distributed.parallel_state
import
graph_capture
,
is_global_first_rank
from
vllm.distributed.parallel_state
import
graph_capture
,
is_global_first_rank
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.backend
import
AttentionMetadataBuilder
from
vllm.v1.attention.backend
import
AttentionMetadataBuilder
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
(
from
vllm.v1.worker.gpu.attn_utils
import
(
...
@@ -34,14 +35,27 @@ class CudaGraphManager:
...
@@ -34,14 +35,27 @@ class CudaGraphManager:
self
.
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
self
.
max_num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
self
.
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
self
.
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
self
.
uniform_decode_query_len
=
1
spec_config
=
vllm_config
.
speculative_config
if
spec_config
is
not
None
:
self
.
uniform_decode_query_len
+=
spec_config
.
num_speculative_tokens
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
compilation_config
=
vllm_config
.
compilation_config
assert
self
.
compilation_config
is
not
None
assert
self
.
compilation_config
is
not
None
self
.
cudagraph_mode
=
self
.
compilation_config
.
cudagraph_mode
self
.
cudagraph_mode
=
self
.
compilation_config
.
cudagraph_mode
self
.
cudagraph_sizes
=
get_cudagraph_sizes
(
use_uniform_decode_cudagraph
=
(
self
.
cudagraph_mode
.
decode_mode
()
==
CUDAGraphMode
.
FULL
and
self
.
cudagraph_mode
.
separate_routine
()
)
self
.
cudagraph_sizes
,
self
.
uniform_decode_cudagraph_sizes
=
get_cudagraph_sizes
(
self
.
compilation_config
.
cudagraph_capture_sizes
,
self
.
compilation_config
.
cudagraph_capture_sizes
,
self
.
max_num_reqs
,
self
.
max_num_reqs
,
self
.
max_num_tokens
,
self
.
max_num_tokens
,
self
.
cudagraph_mode
,
self
.
cudagraph_mode
,
self
.
uniform_decode_query_len
,
use_uniform_decode_cudagraph
,
)
)
self
.
graphs
:
dict
[
int
,
torch
.
cuda
.
CUDAGraph
]
=
{}
self
.
graphs
:
dict
[
int
,
torch
.
cuda
.
CUDAGraph
]
=
{}
...
@@ -54,20 +68,16 @@ class CudaGraphManager:
...
@@ -54,20 +68,16 @@ class CudaGraphManager:
return
len
(
self
.
cudagraph_sizes
)
>
0
return
len
(
self
.
cudagraph_sizes
)
>
0
def
get_cudagraph_size
(
def
get_cudagraph_size
(
self
,
self
,
num_tokens
:
int
,
uniform_decode
:
bool
=
False
num_tokens_after_padding
:
int
,
num_tokens_per_request
:
Iterable
[
int
],
)
->
int
|
None
:
)
->
int
|
None
:
return
get_cudagraph_size
(
if
uniform_decode
and
self
.
uniform_decode_cudagraph_sizes
:
num_tokens_after_padding
,
return
self
.
uniform_decode_cudagraph_sizes
.
get
(
num_tokens
)
num_tokens_per_request
,
return
self
.
cudagraph_sizes
.
get
(
num_tokens
)
self
.
cudagraph_sizes
,
self
.
cudagraph_mode
,
)
def
capture_graph
(
def
capture_graph
(
self
,
self
,
num_tokens
:
int
,
num_tokens
:
int
,
capture_cg_mode
:
CUDAGraphMode
,
model
:
nn
.
Module
,
model
:
nn
.
Module
,
input_buffers
:
InputBuffers
,
input_buffers
:
InputBuffers
,
mrope_positions
:
torch
.
Tensor
|
None
,
mrope_positions
:
torch
.
Tensor
|
None
,
...
@@ -75,8 +85,25 @@ class CudaGraphManager:
...
@@ -75,8 +85,25 @@ class CudaGraphManager:
block_tables
:
BlockTables
,
block_tables
:
BlockTables
,
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
has_lora
:
bool
=
False
,
uniform_decode
:
bool
=
False
,
)
->
None
:
)
->
None
:
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
# select and check capture function
assert
capture_cg_mode
in
[
CUDAGraphMode
.
PIECEWISE
,
CUDAGraphMode
.
FULL
],
(
f
"Invalid capture_cudagraph_mode for capture:
{
capture_cg_mode
}
"
)
if
capture_cg_mode
==
CUDAGraphMode
.
PIECEWISE
:
capture_fn
=
self
.
_capture_piecewise_graph
else
:
capture_fn
=
self
.
_capture_full_graph
# prepare inputs
if
uniform_decode
:
num_reqs
=
min
(
cdiv
(
num_tokens
,
self
.
uniform_decode_query_len
),
self
.
max_num_reqs
,
)
else
:
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
input_ids
=
input_buffers
.
input_ids
[:
num_tokens
]
input_ids
=
input_buffers
.
input_ids
[:
num_tokens
]
positions
=
input_buffers
.
positions
[:
num_tokens
]
positions
=
input_buffers
.
positions
[:
num_tokens
]
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
...
@@ -92,6 +119,9 @@ class CudaGraphManager:
...
@@ -92,6 +119,9 @@ class CudaGraphManager:
attn_metadata_builders
,
attn_metadata_builders
,
self
.
max_model_len
,
self
.
max_model_len
,
kv_cache_config
,
kv_cache_config
,
uniform_decode_query_len
=
(
self
.
uniform_decode_query_len
if
uniform_decode
else
0
),
)
)
num_tokens_across_dp
=
make_num_tokens_across_dp
(
self
.
dp_size
,
num_tokens
)
num_tokens_across_dp
=
make_num_tokens_across_dp
(
self
.
dp_size
,
num_tokens
)
...
@@ -112,13 +142,40 @@ class CudaGraphManager:
...
@@ -112,13 +142,40 @@ class CudaGraphManager:
if
self
.
hidden_states
is
None
:
if
self
.
hidden_states
is
None
:
self
.
hidden_states
=
torch
.
empty_like
(
hidden_states
)
self
.
hidden_states
=
torch
.
empty_like
(
hidden_states
)
capture_fn
(
num_tokens
=
num_tokens
,
num_reqs
=
num_reqs
,
model
=
model
,
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
num_tokens_across_dp
=
num_tokens_across_dp
,
attn_metadata
=
attn_metadata
,
slot_mappings
=
slot_mappings
,
has_lora
=
has_lora
,
)
def
_capture_full_graph
(
self
,
num_tokens
:
int
,
num_reqs
:
int
,
model
:
nn
.
Module
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
|
None
,
num_tokens_across_dp
:
torch
.
Tensor
,
attn_metadata
:
dict
[
str
,
Any
]
|
None
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
None
,
has_lora
:
bool
=
False
,
)
->
None
:
assert
attn_metadata
is
not
None
# Capture the graph.
# Capture the graph.
assert
num_tokens
not
in
self
.
graphs
assert
num_tokens
not
in
self
.
graphs
graph
=
torch
.
cuda
.
CUDAGraph
()
graph
=
torch
.
cuda
.
CUDAGraph
()
with
(
with
(
set_forward_context
(
set_forward_context
(
attn_metadata
,
attn_metadata
=
attn_metadata
,
self
.
vllm_config
,
vllm_config
=
self
.
vllm_config
,
num_tokens
=
num_tokens
,
num_tokens
=
num_tokens
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
,
num_tokens_across_dp
=
num_tokens_across_dp
,
num_tokens_across_dp
=
num_tokens_across_dp
,
...
@@ -131,9 +188,44 @@ class CudaGraphManager:
...
@@ -131,9 +188,44 @@ class CudaGraphManager:
positions
=
positions
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
)
)
assert
self
.
hidden_states
is
not
None
self
.
hidden_states
[:
num_tokens
]
=
hidden_states
self
.
hidden_states
[:
num_tokens
]
=
hidden_states
self
.
graphs
[
num_tokens
]
=
graph
self
.
graphs
[
num_tokens
]
=
graph
def
_capture_piecewise_graph
(
self
,
num_tokens
:
int
,
num_reqs
:
int
,
model
:
nn
.
Module
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
|
None
,
num_tokens_across_dp
:
torch
.
Tensor
,
attn_metadata
:
dict
[
str
,
Any
]
|
None
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
None
,
has_lora
:
bool
=
False
,
)
->
None
:
# create batch descriptor for piecewise cudagraph dispatch key
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
num_tokens
,
has_lora
=
has_lora
)
# Capture run - CUDAGraphWrapper inside torch.compile will auto capture.
with
set_forward_context
(
attn_metadata
=
None
,
# piecewise no need attn_metadata
vllm_config
=
self
.
vllm_config
,
num_tokens
=
num_tokens
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
,
num_tokens_across_dp
=
num_tokens_across_dp
,
batch_descriptor
=
batch_descriptor
,
slot_mapping
=
slot_mappings
,
):
hidden_states
=
model
(
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
)
assert
self
.
hidden_states
is
not
None
self
.
hidden_states
[:
num_tokens
]
=
hidden_states
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
capture
(
def
capture
(
self
,
self
,
...
@@ -144,11 +236,11 @@ class CudaGraphManager:
...
@@ -144,11 +236,11 @@ class CudaGraphManager:
block_tables
:
BlockTables
,
block_tables
:
BlockTables
,
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
has_lora
:
bool
=
False
,
)
->
None
:
)
->
None
:
capture_graphs
(
common_kwargs
=
dict
(
self
.
cudagraph_sizes
,
device
=
self
.
device
,
self
.
device
,
capture_fn
=
self
.
capture_graph
,
self
.
capture_graph
,
model
=
model
,
model
=
model
,
input_buffers
=
input_buffers
,
input_buffers
=
input_buffers
,
mrope_positions
=
mrope_positions
,
mrope_positions
=
mrope_positions
,
...
@@ -156,10 +248,50 @@ class CudaGraphManager:
...
@@ -156,10 +248,50 @@ class CudaGraphManager:
block_tables
=
block_tables
,
block_tables
=
block_tables
,
attn_metadata_builders
=
attn_metadata_builders
,
attn_metadata_builders
=
attn_metadata_builders
,
kv_cache_config
=
kv_cache_config
,
kv_cache_config
=
kv_cache_config
,
has_lora
=
has_lora
,
)
)
def
run
(
self
,
num_tokens
:
int
)
->
torch
.
Tensor
:
# Phase 1: Capture for mixed prefill-decode batches if needed.
assert
num_tokens
in
self
.
graphs
mixed_mode
=
self
.
cudagraph_mode
.
mixed_mode
()
if
mixed_mode
!=
CUDAGraphMode
.
NONE
:
capture_graphs
(
cudagraph_sizes
=
self
.
cudagraph_sizes
,
capture_cudagraph_mode
=
mixed_mode
,
desc
=
f
"Capturing CUDA graphs (mixed,
{
mixed_mode
.
name
}
)"
,
uniform_decode
=
False
,
**
common_kwargs
,
)
# Phase 2: Capture FULL graphs for uniform decode batches if needed.
# This is only needed if we use a separate routine for decode batches
# and the decode_mode is FULL.
if
self
.
uniform_decode_cudagraph_sizes
:
capture_graphs
(
cudagraph_sizes
=
self
.
uniform_decode_cudagraph_sizes
,
capture_cudagraph_mode
=
CUDAGraphMode
.
FULL
,
desc
=
"Capturing CUDA graphs (decode, FULL)"
,
uniform_decode
=
True
,
**
common_kwargs
,
)
def
get_cudagraph_runtime_mode
(
self
,
num_reqs
:
int
,
num_tokens
:
int
,
max_query_len
:
int
)
->
tuple
[
CUDAGraphMode
,
int
|
None
]:
is_uniform_decode
=
(
max_query_len
==
self
.
uniform_decode_query_len
)
and
(
num_tokens
==
max_query_len
*
num_reqs
)
cudagraph_size
=
self
.
get_cudagraph_size
(
num_tokens
,
is_uniform_decode
)
if
cudagraph_size
is
None
:
cudagraph_mode
=
CUDAGraphMode
.
NONE
elif
is_uniform_decode
:
cudagraph_mode
=
self
.
cudagraph_mode
.
decode_mode
()
else
:
cudagraph_mode
=
self
.
cudagraph_mode
.
mixed_mode
()
return
cudagraph_mode
,
cudagraph_size
def
run_fullgraph
(
self
,
num_tokens
:
int
)
->
torch
.
Tensor
:
assert
num_tokens
in
self
.
graphs
,
f
"No cudagraph for
{
num_tokens
}
tokens"
self
.
graphs
[
num_tokens
].
replay
()
self
.
graphs
[
num_tokens
].
replay
()
assert
self
.
hidden_states
is
not
None
assert
self
.
hidden_states
is
not
None
return
self
.
hidden_states
[:
num_tokens
]
return
self
.
hidden_states
[:
num_tokens
]
...
@@ -170,22 +302,18 @@ def get_cudagraph_sizes(
...
@@ -170,22 +302,18 @@ def get_cudagraph_sizes(
max_num_reqs
:
int
,
max_num_reqs
:
int
,
max_num_tokens
:
int
,
max_num_tokens
:
int
,
cudagraph_mode
:
CUDAGraphMode
,
cudagraph_mode
:
CUDAGraphMode
,
)
->
dict
[
int
,
int
]:
uniform_decode_query_len
:
int
=
1
,
if
not
cudagraph_mode
.
has_full_cudagraphs
():
uniform_decode_cudagraph
:
bool
=
False
,
return
{}
)
->
tuple
[
dict
[
int
,
int
],
dict
[
int
,
int
]]:
# Support both FULL and PIECEWISE cudagraph modes
if
cudagraph_mode
==
CUDAGraphMode
.
NONE
:
return
{},
{}
if
not
capture_sizes
:
if
not
capture_sizes
:
return
{}
return
{}
,
{}
capture_sizes
=
sorted
(
capture_sizes
)
capture_sizes
=
sorted
(
capture_sizes
)
# Limit the capture sizes to the max number of requests or tokens.
upper_bound
=
(
max_num_reqs
if
cudagraph_mode
==
CUDAGraphMode
.
FULL_DECODE_ONLY
else
max_num_tokens
)
capture_sizes
=
[
x
for
x
in
capture_sizes
if
x
<=
upper_bound
]
if
not
capture_sizes
:
if
not
capture_sizes
:
return
{}
return
{}
,
{}
cudagraph_sizes
:
dict
[
int
,
int
]
=
{}
cudagraph_sizes
:
dict
[
int
,
int
]
=
{}
for
i
in
range
(
1
,
capture_sizes
[
-
1
]
+
1
):
for
i
in
range
(
1
,
capture_sizes
[
-
1
]
+
1
):
...
@@ -193,45 +321,34 @@ def get_cudagraph_sizes(
...
@@ -193,45 +321,34 @@ def get_cudagraph_sizes(
if
i
<=
x
:
if
i
<=
x
:
cudagraph_sizes
[
i
]
=
x
cudagraph_sizes
[
i
]
=
x
break
break
return
cudagraph_sizes
def
get_cudagraph_size
(
num_tokens_after_dp_padding
:
int
,
num_tokens_per_request
:
Iterable
[
int
],
cudagraph_sizes
:
dict
[
int
,
int
],
cudagraph_mode
:
CUDAGraphMode
,
)
->
int
|
None
:
if
not
cudagraph_mode
.
has_full_cudagraphs
():
# No full CUDA graph is used.
return
None
size
=
cudagraph_sizes
.
get
(
num_tokens_after_dp_padding
)
if
size
is
None
:
# No CUDA graph for this size.
return
None
is_mixed
=
any
(
x
>
1
for
x
in
num_tokens_per_request
)
uniform_decode_cudagraph_sizes
:
dict
[
int
,
int
]
=
{}
if
is_mixed
and
cudagraph_mode
.
mixed_mode
()
!=
CUDAGraphMode
.
FULL
:
if
uniform_decode_cudagraph
:
# Prefill is included, and this mode doesn't use CUDA graph for it.
max_num_tokens
=
max_num_reqs
*
uniform_decode_query_len
return
None
uniform_decode_cudagraph_sizes
=
{
return
size
k
:
v
for
k
,
v
in
cudagraph_sizes
.
items
()
if
v
<=
max_num_tokens
and
v
>=
uniform_decode_query_len
}
return
cudagraph_sizes
,
uniform_decode_cudagraph_sizes
def
capture_graphs
(
def
capture_graphs
(
cudagraph_sizes
:
dict
[
int
,
int
],
cudagraph_sizes
:
dict
[
int
,
int
],
device
:
torch
.
device
,
device
:
torch
.
device
,
capture_fn
:
Callable
,
capture_fn
:
Callable
,
capture_cudagraph_mode
:
CUDAGraphMode
,
desc
:
str
=
"Capturing CUDA graphs"
,
**
capture_kwargs
,
**
capture_kwargs
,
)
->
None
:
)
->
None
:
# Capture larger graphs first.
# Capture larger graphs first.
sizes_to_capture
=
sorted
(
set
(
cudagraph_sizes
.
values
()),
reverse
=
True
)
sizes_to_capture
=
sorted
(
set
(
cudagraph_sizes
.
values
()),
reverse
=
True
)
if
is_global_first_rank
():
if
is_global_first_rank
():
sizes_to_capture
=
tqdm
(
sizes_to_capture
,
desc
=
"Capturing CUDA graphs"
)
sizes_to_capture
=
tqdm
(
sizes_to_capture
,
desc
=
desc
)
with
graph_capture
(
device
=
device
):
with
graph_capture
(
device
=
device
):
for
size
in
sizes_to_capture
:
for
size
in
sizes_to_capture
:
capture_fn
(
size
,
**
capture_kwargs
)
capture_fn
(
size
,
capture_cudagraph_mode
,
**
capture_kwargs
)
def
prepare_inputs_to_capture
(
def
prepare_inputs_to_capture
(
...
@@ -242,8 +359,12 @@ def prepare_inputs_to_capture(
...
@@ -242,8 +359,12 @@ def prepare_inputs_to_capture(
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
max_model_len
:
int
,
max_model_len
:
int
,
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
uniform_decode_query_len
:
int
=
0
,
)
->
tuple
[
dict
[
str
,
Any
],
dict
[
str
,
torch
.
Tensor
]]:
)
->
tuple
[
dict
[
str
,
Any
],
dict
[
str
,
torch
.
Tensor
]]:
num_tokens_per_req
=
num_tokens
//
num_reqs
if
uniform_decode_query_len
>
0
:
num_tokens_per_req
=
uniform_decode_query_len
else
:
num_tokens_per_req
=
num_tokens
//
num_reqs
query_start_loc_np
=
np
.
arange
(
num_reqs
+
1
,
dtype
=
np
.
int32
)
*
num_tokens_per_req
query_start_loc_np
=
np
.
arange
(
num_reqs
+
1
,
dtype
=
np
.
int32
)
*
num_tokens_per_req
query_start_loc_np
[
-
1
]
=
num_tokens
query_start_loc_np
[
-
1
]
=
num_tokens
...
...
vllm/v1/worker/gpu/dp_utils.py
View file @
11d3976b
...
@@ -13,48 +13,65 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N
...
@@ -13,48 +13,65 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N
def
get_batch_metadata_across_dp
(
def
get_batch_metadata_across_dp
(
num_tokens
:
int
,
cudagraph_size
:
int
,
dp_size
:
int
,
dp_rank
:
int
num_tokens
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
cudagraph_size
:
int
,
cudagraph_runtime_mode
:
int
,
dp_size
:
int
,
dp_rank
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
dp_size
>
1
assert
dp_size
>
1
# Use CPU group to avoid CPU-GPU synchronization.
# Use CPU group to avoid CPU-GPU synchronization.
group
=
get_dp_group
().
cpu_group
group
=
get_dp_group
().
cpu_group
tensor
=
torch
.
zeros
(
2
,
dp_size
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
tensor
=
torch
.
zeros
(
3
,
dp_size
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
tensor
[
0
][
dp_rank
]
=
num_tokens
tensor
[
0
][
dp_rank
]
=
num_tokens
tensor
[
1
][
dp_rank
]
=
cudagraph_size
tensor
[
1
][
dp_rank
]
=
cudagraph_size
tensor
[
2
][
dp_rank
]
=
cudagraph_runtime_mode
dist
.
all_reduce
(
tensor
,
group
=
group
)
dist
.
all_reduce
(
tensor
,
group
=
group
)
return
tensor
[
0
],
tensor
[
1
]
return
tensor
[
0
],
tensor
[
1
]
,
tensor
[
2
]
def
get_cudagraph_and_dp_padding
(
def
get_cudagraph_and_dp_padding
(
num_tokens
:
int
,
cudagraph_size
:
int
|
None
,
dp_size
:
int
,
dp_rank
:
int
num_tokens
:
int
,
)
->
tuple
[
bool
,
int
,
torch
.
Tensor
|
None
]:
cudagraph_size
:
int
|
None
,
cudagraph_runtime_mode
:
int
,
dp_size
:
int
,
dp_rank
:
int
,
)
->
tuple
[
int
,
torch
.
Tensor
|
None
,
int
]:
if
dp_size
==
1
:
if
dp_size
==
1
:
if
cudagraph_size
is
not
None
:
if
cudagraph_size
is
not
None
:
return
True
,
cudagraph_size
,
None
return
cudagraph_size
,
None
,
cudagraph_runtime_mode
else
:
else
:
return
False
,
num_tokens
,
None
return
num_tokens
,
None
,
cudagraph_runtime_mode
# Convert None to -1 for sync (indicates no cudagraph available)
if
num_tokens
==
0
:
if
num_tokens
==
0
:
cudagraph_size
=
0
cudagraph_size
=
0
elif
cudagraph_size
is
None
:
elif
cudagraph_size
is
None
:
cudagraph_size
=
-
1
cudagraph_size
=
-
1
num_tokens_across_dp
,
cudagraph_size_across_dp
=
get_batch_metadata_across_dp
(
num_tokens
,
cudagraph_size
,
dp_size
,
dp_rank
num_tokens_across_dp
,
cudagraph_size_across_dp
,
cudagraph_mode_across_dp
=
(
get_batch_metadata_across_dp
(
num_tokens
,
cudagraph_size
,
cudagraph_runtime_mode
,
dp_size
,
dp_rank
)
)
)
if
torch
.
all
(
num_tokens_across_dp
==
0
).
item
():
if
torch
.
all
(
num_tokens_across_dp
==
0
).
item
():
# All ranks have zero tokens to run.
# All ranks have zero tokens to run.
return
False
,
0
,
None
return
0
,
None
,
0
# Synchronize cudagraph_runtime_mode across ranks by taking the minimum.
synced_cudagraph_mode
=
int
(
cudagraph_mode_across_dp
.
min
().
item
())
# Check if all ranks have valid cudagraph_size.
all_have_cudagraph
=
torch
.
all
(
cudagraph_size_across_dp
!=
-
1
).
item
()
if
torch
.
all
(
cudagraph_size_across_dp
!=
-
1
).
item
():
if
synced_cudagraph_mode
!=
0
and
all_have_cudagraph
:
# All ranks use CUDA graph or have zero tokens.
# All ranks use cudagraph. Pad to max cudagraph_size.
# Use CUDA graph for all ranks.
# Pad all ranks to the maximum CUDA graph size.
max_cudagraph_size
=
int
(
cudagraph_size_across_dp
.
max
().
item
())
max_cudagraph_size
=
int
(
cudagraph_size_across_dp
.
max
().
item
())
num_tokens_across_dp
[:]
=
max_cudagraph_size
num_tokens_across_dp
[:]
=
max_cudagraph_size
return
True
,
max_cudagraph_size
,
num_tokens_across_dp
return
max_cudagraph_size
,
num_tokens_across_dp
,
synced_cudagraph_mode
else
:
else
:
# Some ranks do not use CUDA graph. Use eager mode for all ranks.
# Fall back to eager mode (no cudagraph).
# No padding is needed except for ranks that have no tokens to run.
# Either some rank doesn't have cudagraph size or mode is NONE.
synced_cudagraph_mode
=
0
num_tokens_across_dp
=
torch
.
clamp
(
num_tokens_across_dp
,
min
=
1
)
num_tokens_across_dp
=
torch
.
clamp
(
num_tokens_across_dp
,
min
=
1
)
num_tokens_after_padding
=
int
(
num_tokens_across_dp
[
dp_rank
].
item
())
num_tokens_after_padding
=
int
(
num_tokens_across_dp
[
dp_rank
].
item
())
return
False
,
num_tokens_after_padding
,
num_tokens_across_dp
return
num_tokens_after_padding
,
num_tokens_across_dp
,
synced_cudagraph_mode
vllm/v1/worker/gpu/model_runner.py
View file @
11d3976b
...
@@ -15,7 +15,7 @@ from vllm.distributed.parallel_state import (
...
@@ -15,7 +15,7 @@ from vllm.distributed.parallel_state import (
get_pp_group
,
get_pp_group
,
prepare_communication_buffer_for_model
,
prepare_communication_buffer_for_model
,
)
)
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
@@ -140,7 +140,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -140,7 +140,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
do_spec_decode
=
False
self
.
do_spec_decode
=
False
self
.
num_speculative_steps
=
0
self
.
num_speculative_steps
=
0
self
.
speculator
=
None
self
.
speculator
=
None
self
.
req_states
=
RequestState
(
self
.
req_states
=
RequestState
(
max_num_reqs
=
self
.
max_num_reqs
,
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
max_model_len
=
self
.
max_model_len
,
...
@@ -458,6 +457,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -458,6 +457,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_tables
=
self
.
block_tables
,
block_tables
=
self
.
block_tables
,
attn_metadata_builders
=
self
.
attn_metadata_builders
,
attn_metadata_builders
=
self
.
attn_metadata_builders
,
kv_cache_config
=
self
.
kv_cache_config
,
kv_cache_config
=
self
.
kv_cache_config
,
has_lora
=
self
.
lora_config
is
not
None
,
)
)
if
self
.
do_spec_decode
:
if
self
.
do_spec_decode
:
self
.
speculator
.
capture_model
()
self
.
speculator
.
capture_model
()
...
@@ -884,19 +884,26 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -884,19 +884,26 @@ class GPUModelRunner(LoRAModelRunnerMixin):
empty_output
=
self
.
kv_connector
.
no_forward
(
scheduler_output
)
empty_output
=
self
.
kv_connector
.
no_forward
(
scheduler_output
)
return
empty_output
return
empty_output
# Get the CUDA graph size. None means no CUDA graph is used.
# Get local cudagraph mode and size.
cudagraph_size
=
self
.
cudagraph_manager
.
get_cudagraph_size
(
local_cudagraph_mode
,
local_cudagraph_size
=
(
scheduler_output
.
total_num_scheduled_tokens
,
self
.
cudagraph_manager
.
get_cudagraph_runtime_mode
(
scheduler_output
.
num_scheduled_tokens
.
values
(),
num_reqs
=
len
(
scheduler_output
.
num_scheduled_tokens
),
num_tokens
=
scheduler_output
.
total_num_scheduled_tokens
,
max_query_len
=
max
(
scheduler_output
.
num_scheduled_tokens
.
values
()),
)
)
)
use_cudagraph
,
num_tokens_after_padding
,
num_tokens_across_dp
=
(
# DP sync: num_tokens + cudagraph_size + cudagraph_mode
num_tokens_after_padding
,
num_tokens_across_dp
,
synced_cudagraph_mode
=
(
get_cudagraph_and_dp_padding
(
get_cudagraph_and_dp_padding
(
scheduler_output
.
total_num_scheduled_tokens
,
scheduler_output
.
total_num_scheduled_tokens
,
cudagraph_size
,
local_cudagraph_size
,
local_cudagraph_mode
.
value
,
self
.
parallel_config
.
data_parallel_size
,
self
.
parallel_config
.
data_parallel_size
,
self
.
parallel_config
.
data_parallel_rank
,
self
.
parallel_config
.
data_parallel_rank
,
)
)
)
)
cudagraph_runtime_mode
=
CUDAGraphMode
(
synced_cudagraph_mode
)
if
num_tokens_after_padding
==
0
:
if
num_tokens_after_padding
==
0
:
# All DP ranks have zero tokens to run.
# All DP ranks have zero tokens to run.
empty_output
=
self
.
kv_connector
.
no_forward
(
scheduler_output
)
empty_output
=
self
.
kv_connector
.
no_forward
(
scheduler_output
)
...
@@ -946,16 +953,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -946,16 +953,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# FIXME(woosuk): Fix warmup for LoRA.
# FIXME(woosuk): Fix warmup for LoRA.
# Run model.
# Run model.
if
use_
cudagraph
:
if
cudagraph
_runtime_mode
==
CUDAGraphMode
.
FULL
:
#
Run CUDA graph
.
#
Use explicit cudagraph replay for FULL mode
.
# NOTE(woosuk): Here, we don't need to pass the input tensors,
# NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers.
# because they are already copied to the CUDA graph input buffers.
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
hidden_states
=
self
.
cudagraph_manager
.
run
(
hidden_states
=
self
.
cudagraph_manager
.
run
_fullgraph
(
input_batch
.
num_tokens_after_padding
input_batch
.
num_tokens_after_padding
)
)
else
:
else
:
#
Run PyTorch model in eager mode
.
#
For piecewise and eager mode, just call model()
.
positions
=
input_batch
.
positions
positions
=
input_batch
.
positions
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
assert
input_batch
.
mrope_positions
is
not
None
assert
input_batch
.
mrope_positions
is
not
None
...
@@ -970,13 +977,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -970,13 +977,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
inputs_embeds
=
None
inputs_embeds
=
None
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
input_batch
.
num_tokens_after_padding
,
has_lora
=
self
.
lora_config
is
not
None
,
)
with
set_forward_context
(
with
set_forward_context
(
input_batch
.
attn_metadata
,
input_batch
.
attn_metadata
,
self
.
vllm_config
,
self
.
vllm_config
,
num_tokens
=
input_batch
.
num_tokens_after_padding
,
num_tokens
=
input_batch
.
num_tokens_after_padding
,
# TODO(woosuk): Support piecewise CUDA graph.
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
,
num_tokens_across_dp
=
num_tokens_across_dp
,
num_tokens_across_dp
=
num_tokens_across_dp
,
batch_descriptor
=
batch_descriptor
,
slot_mapping
=
input_batch
.
slot_mappings
,
slot_mapping
=
input_batch
.
slot_mappings
,
):
):
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
...
...
vllm/v1/worker/gpu/spec_decode/eagle.py
View file @
11d3976b
...
@@ -7,7 +7,7 @@ import torch.nn as nn
...
@@ -7,7 +7,7 @@ import torch.nn as nn
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
...
@@ -103,14 +103,17 @@ class EagleSpeculator:
...
@@ -103,14 +103,17 @@ class EagleSpeculator:
attn_metadata
:
dict
[
str
,
Any
]
|
None
,
attn_metadata
:
dict
[
str
,
Any
]
|
None
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
None
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
None
,
num_tokens_across_dp
:
torch
.
Tensor
|
None
,
num_tokens_across_dp
:
torch
.
Tensor
|
None
,
cudagraph_runtime_mode
:
CUDAGraphMode
=
CUDAGraphMode
.
NONE
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
num_tokens
)
with
set_forward_context
(
with
set_forward_context
(
attn_metadata
,
attn_metadata
,
self
.
vllm_config
,
self
.
vllm_config
,
num_tokens
=
num_tokens
,
num_tokens
=
num_tokens
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
num_tokens_across_dp
=
num_tokens_across_dp
,
num_tokens_across_dp
=
num_tokens_across_dp
,
slot_mapping
=
slot_mappings
,
slot_mapping
=
slot_mappings
,
batch_descriptor
=
batch_descriptor
,
):
):
ret_hidden_states
=
self
.
model
(
ret_hidden_states
=
self
.
model
(
input_ids
=
self
.
input_buffers
.
input_ids
[:
num_tokens
],
input_ids
=
self
.
input_buffers
.
input_ids
[:
num_tokens
],
...
@@ -127,9 +130,11 @@ class EagleSpeculator:
...
@@ -127,9 +130,11 @@ class EagleSpeculator:
def
generate_draft
(
def
generate_draft
(
self
,
self
,
num_reqs
:
int
,
num_reqs
:
int
,
num_tokens_padded
:
int
,
attn_metadata
:
dict
[
str
,
Any
],
attn_metadata
:
dict
[
str
,
Any
],
slot_mappings
:
dict
[
str
,
torch
.
Tensor
],
slot_mappings
:
dict
[
str
,
torch
.
Tensor
],
num_tokens_across_dp
:
torch
.
Tensor
|
None
,
num_tokens_across_dp
:
torch
.
Tensor
|
None
,
cudagraph_runtime_mode
:
CUDAGraphMode
=
CUDAGraphMode
.
NONE
,
)
->
None
:
)
->
None
:
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
...
@@ -137,8 +142,14 @@ class EagleSpeculator:
...
@@ -137,8 +142,14 @@ class EagleSpeculator:
for
step
in
range
(
1
,
self
.
num_speculative_steps
):
for
step
in
range
(
1
,
self
.
num_speculative_steps
):
# Run the eagle model.
# Run the eagle model.
last_hidden_states
,
hidden_states
=
self
.
run_model
(
last_hidden_states
,
hidden_states
=
self
.
run_model
(
num_reqs
,
attn_metadata
,
slot_mappings
,
num_tokens_across_dp
num_tokens_padded
,
attn_metadata
,
slot_mappings
,
num_tokens_across_dp
,
cudagraph_runtime_mode
,
)
)
last_hidden_states
=
last_hidden_states
[:
num_reqs
]
hidden_states
=
hidden_states
[:
num_reqs
]
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
)
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
...
@@ -283,12 +294,14 @@ class EagleSpeculator:
...
@@ -283,12 +294,14 @@ class EagleSpeculator:
)
)
cudagraph_size
=
self
.
cudagraph_manager
.
get_cudagraph_size
(
num_reqs
)
cudagraph_size
=
self
.
cudagraph_manager
.
get_cudagraph_size
(
num_reqs
)
if
cudagraph_size
is
not
None
:
cudagraph_mode
=
self
.
cudagraph_manager
.
cudagraph_mode
# Run CUDA graph.
if
cudagraph_size
is
not
None
and
cudagraph_mode
==
CUDAGraphMode
.
FULL
:
self
.
cudagraph_manager
.
run
(
cudagraph_size
)
# Run full CUDA graph.
self
.
cudagraph_manager
.
run_fullgraph
(
cudagraph_size
)
return
self
.
draft_tokens
[:
num_reqs
]
return
self
.
draft_tokens
[:
num_reqs
]
# Run eager mode.
# Run eager or piecewise CUDA graph.
num_tokens_padded
=
cudagraph_size
if
cudagraph_size
is
not
None
else
num_reqs
query_start_loc_cpu
=
torch
.
arange
(
query_start_loc_cpu
=
torch
.
arange
(
num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
"cpu"
num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
)
...
@@ -312,8 +325,13 @@ class EagleSpeculator:
...
@@ -312,8 +325,13 @@ class EagleSpeculator:
slot_mappings
,
self
.
kv_cache_config
slot_mappings
,
self
.
kv_cache_config
)
)
self
.
generate_draft
(
self
.
generate_draft
(
num_reqs
,
attn_metadata
,
slot_mappings_by_layer
,
num_tokens_across_dp
=
None
num_reqs
,
)
# FIXME
num_tokens_padded
,
attn_metadata
,
slot_mappings_by_layer
,
num_tokens_across_dp
=
None
,
# FIXME
cudagraph_runtime_mode
=
cudagraph_mode
,
)
return
self
.
draft_tokens
[:
num_reqs
]
return
self
.
draft_tokens
[:
num_reqs
]
...
...
vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py
View file @
11d3976b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
typing
import
Any
import
torch
import
torch
...
@@ -31,16 +32,17 @@ class EagleCudaGraphManager:
...
@@ -31,16 +32,17 @@ class EagleCudaGraphManager:
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
compilation_config
=
vllm_config
.
compilation_config
assert
self
.
compilation_config
is
not
None
assert
self
.
compilation_config
is
not
None
self
.
cudagraph_mode
=
self
.
compilation_config
.
cudagraph_mode
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
if
self
.
cudagraph_mode
==
CUDAGraphMode
.
FULL
:
self
.
cudagraph_mode
=
self
.
compilation_config
.
cudagraph_mode
.
decode_mode
()
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
self
.
cudagraph_mode
=
CUDAGraphMode
.
FULL_DECODE_ONLY
self
.
cudagraph_sizes
=
get_cudagraph_sizes
(
# only need to capture uniform decode cudagraph sizes (the 2nd return value)
_
,
self
.
cudagraph_sizes
=
get_cudagraph_sizes
(
self
.
compilation_config
.
cudagraph_capture_sizes
,
self
.
compilation_config
.
cudagraph_capture_sizes
,
self
.
max_num_reqs
,
self
.
max_num_reqs
,
self
.
max_num_tokens
,
self
.
max_num_tokens
,
self
.
cudagraph_mode
,
self
.
cudagraph_mode
,
uniform_decode_query_len
=
1
,
uniform_decode_cudagraph
=
True
,
)
)
self
.
graphs
:
dict
[
int
,
torch
.
cuda
.
CUDAGraph
]
=
{}
self
.
graphs
:
dict
[
int
,
torch
.
cuda
.
CUDAGraph
]
=
{}
...
@@ -54,12 +56,21 @@ class EagleCudaGraphManager:
...
@@ -54,12 +56,21 @@ class EagleCudaGraphManager:
def
capture_graph
(
def
capture_graph
(
self
,
self
,
num_tokens
:
int
,
num_tokens
:
int
,
capture_cg_mode
:
CUDAGraphMode
,
generate_fn
:
Callable
,
generate_fn
:
Callable
,
input_buffers
:
InputBuffers
,
input_buffers
:
InputBuffers
,
block_tables
:
BlockTables
,
block_tables
:
BlockTables
,
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
)
->
None
:
)
->
None
:
assert
capture_cg_mode
in
[
CUDAGraphMode
.
PIECEWISE
,
CUDAGraphMode
.
FULL
],
(
f
"Invalid capture_cudagraph_mode for capture:
{
capture_cg_mode
}
"
)
if
capture_cg_mode
==
CUDAGraphMode
.
PIECEWISE
:
capture_fn
=
self
.
_capture_piecewise_graph
else
:
capture_fn
=
self
.
_capture_full_graph
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
attn_metadata
,
slot_mappings
=
prepare_inputs_to_capture
(
attn_metadata
,
slot_mappings
=
prepare_inputs_to_capture
(
num_reqs
,
num_reqs
,
...
@@ -69,19 +80,70 @@ class EagleCudaGraphManager:
...
@@ -69,19 +80,70 @@ class EagleCudaGraphManager:
attn_metadata_builders
,
attn_metadata_builders
,
self
.
max_model_len
,
self
.
max_model_len
,
kv_cache_config
,
kv_cache_config
,
uniform_decode_query_len
=
1
,
)
)
num_tokens_across_dp
=
make_num_tokens_across_dp
(
self
.
dp_size
,
num_tokens
)
num_tokens_across_dp
=
make_num_tokens_across_dp
(
self
.
dp_size
,
num_tokens
)
# Warm up.
# Warm up.
generate_fn
(
num_tokens
,
attn_metadata
,
slot_mappings
,
num_tokens_across_dp
)
generate_fn
(
num_reqs
,
num_tokens
,
attn_metadata
,
slot_mappings
,
num_tokens_across_dp
,
CUDAGraphMode
.
NONE
,
)
# Capture the graph.
# Capture the graph.
capture_fn
(
num_reqs
=
num_reqs
,
num_tokens
=
num_tokens
,
generate_fn
=
generate_fn
,
attn_metadata
=
attn_metadata
,
slot_mappings
=
slot_mappings
,
num_tokens_across_dp
=
num_tokens_across_dp
,
)
def
_capture_full_graph
(
self
,
num_reqs
:
int
,
num_tokens
:
int
,
generate_fn
:
Callable
,
attn_metadata
:
dict
[
str
,
Any
],
slot_mappings
:
dict
[
str
,
torch
.
Tensor
],
num_tokens_across_dp
:
torch
.
Tensor
,
)
->
None
:
assert
num_tokens
not
in
self
.
graphs
assert
num_tokens
not
in
self
.
graphs
graph
=
torch
.
cuda
.
CUDAGraph
()
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
,
self
.
pool
):
with
torch
.
cuda
.
graph
(
graph
,
self
.
pool
):
generate_fn
(
num_tokens
,
attn_metadata
,
slot_mappings
,
num_tokens_across_dp
)
generate_fn
(
num_reqs
,
num_tokens
,
attn_metadata
,
slot_mappings
,
num_tokens_across_dp
,
CUDAGraphMode
.
NONE
,
)
self
.
graphs
[
num_tokens
]
=
graph
self
.
graphs
[
num_tokens
]
=
graph
def
_capture_piecewise_graph
(
self
,
num_reqs
:
int
,
num_tokens
:
int
,
generate_fn
:
Callable
,
attn_metadata
:
dict
[
str
,
Any
],
slot_mappings
:
dict
[
str
,
torch
.
Tensor
],
num_tokens_across_dp
:
torch
.
Tensor
,
)
->
None
:
generate_fn
(
num_reqs
,
num_tokens
,
attn_metadata
,
slot_mappings
,
num_tokens_across_dp
,
CUDAGraphMode
.
PIECEWISE
,
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
capture
(
def
capture
(
self
,
self
,
...
@@ -91,10 +153,15 @@ class EagleCudaGraphManager:
...
@@ -91,10 +153,15 @@ class EagleCudaGraphManager:
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
)
->
None
:
)
->
None
:
if
self
.
cudagraph_mode
==
CUDAGraphMode
.
NONE
:
return
capture_graphs
(
capture_graphs
(
self
.
cudagraph_sizes
,
self
.
cudagraph_sizes
,
self
.
device
,
self
.
device
,
self
.
capture_graph
,
self
.
capture_graph
,
capture_cudagraph_mode
=
self
.
cudagraph_mode
,
desc
=
f
"Capturing eagle CUDA graphs (
{
self
.
cudagraph_mode
.
name
}
)"
,
generate_fn
=
generate_fn
,
generate_fn
=
generate_fn
,
input_buffers
=
input_buffers
,
input_buffers
=
input_buffers
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
...
@@ -102,6 +169,6 @@ class EagleCudaGraphManager:
...
@@ -102,6 +169,6 @@ class EagleCudaGraphManager:
kv_cache_config
=
kv_cache_config
,
kv_cache_config
=
kv_cache_config
,
)
)
def
run
(
self
,
num_tokens
:
int
)
->
None
:
def
run
_fullgraph
(
self
,
num_tokens
:
int
)
->
None
:
assert
num_tokens
in
self
.
graphs
assert
num_tokens
in
self
.
graphs
self
.
graphs
[
num_tokens
].
replay
()
self
.
graphs
[
num_tokens
].
replay
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment