Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b71fbd06
Unverified
Commit
b71fbd06
authored
Feb 21, 2026
by
Woosuk Kwon
Committed by
GitHub
Feb 21, 2026
Browse files
[Model Runner V2] Support attention group (#35036)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
74d90b1c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
87 additions
and
55 deletions
+87
-55
vllm/v1/worker/gpu/attn_utils.py
vllm/v1/worker/gpu/attn_utils.py
+55
-30
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+17
-7
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+5
-8
vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
+5
-5
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
+5
-5
No files found.
vllm/v1/worker/gpu/attn_utils.py
View file @
b71fbd06
...
@@ -7,17 +7,14 @@ import torch
...
@@ -7,17 +7,14 @@ import torch
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
AttentionBackend
,
CommonAttentionMetadata
AttentionBackend
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
)
from
vllm.v1.kv_cache_interface
import
(
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
AttentionSpec
,
KVCacheConfig
,
KVCacheConfig
,
KVCacheSpec
,
KVCacheSpec
,
UniformTypeKVCacheSpecs
,
)
)
from
vllm.v1.worker.utils
import
bind_kv_cache
from
vllm.v1.worker.utils
import
AttentionGroup
,
bind_kv_cache
def
get_kv_cache_spec
(
vllm_config
:
VllmConfig
)
->
dict
[
str
,
KVCacheSpec
]:
def
get_kv_cache_spec
(
vllm_config
:
VllmConfig
)
->
dict
[
str
,
KVCacheSpec
]:
...
@@ -35,29 +32,56 @@ def init_attn_backend(
...
@@ -35,29 +32,56 @@ def init_attn_backend(
kv_cache_config
:
KVCacheConfig
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
kv_cache_config
:
KVCacheConfig
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
):
attn_backends
:
dict
[
str
,
type
[
AttentionBackend
]]
=
{}
attn_backends
:
dict
[
str
,
type
[
AttentionBackend
]]
=
{}
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
]
=
[]
attn_groups
:
list
[
list
[
AttentionGroup
]]
=
[]
flashinfer_workspace
:
torch
.
Tensor
|
None
=
None
attn_backend_workspace
:
torch
.
Tensor
|
None
=
None
for
kv_cache_group_spec
in
kv_cache_config
.
kv_cache_groups
:
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
kv_cache_config
.
kv_cache_groups
):
layer_names
=
kv_cache_group_spec
.
layer_names
layer_names
=
kv_cache_group_spec
.
layer_names
any_layer_name
=
next
(
iter
(
layer_names
))
layer_type
=
cast
(
type
[
Any
],
AttentionLayerBase
)
layer_type
=
cast
(
type
[
Any
],
AttentionLayerBase
)
attn_layers
=
get_layers_from_vllm_config
(
vllm_config
,
layer_type
,
layer_names
)
attn_layers
=
get_layers_from_vllm_config
(
vllm_config
,
layer_type
,
layer_names
)
attn_backend
=
attn_layers
[
any_layer_name
].
get_attn_backend
()
group_map
:
dict
[
tuple
[
tuple
[
str
,
str
],
KVCacheSpec
],
AttentionGroup
]
=
{}
group_order
:
list
[
tuple
[
tuple
[
str
,
str
],
KVCacheSpec
]]
=
[]
for
layer_name
in
layer_names
:
for
layer_name
in
layer_names
:
attn_backend
=
attn_layers
[
layer_name
].
get_attn_backend
()
attn_backends
[
layer_name
]
=
attn_backend
attn_backends
[
layer_name
]
=
attn_backend
attn_metadata_builder
=
attn_backend
.
get_builder_cls
()(
layer_kv_cache_spec
:
KVCacheSpec
=
kv_cache_group_spec
.
kv_cache_spec
kv_cache_group_spec
.
kv_cache_spec
,
layer_names
,
vllm_config
,
device
if
isinstance
(
layer_kv_cache_spec
,
UniformTypeKVCacheSpecs
):
)
layer_kv_cache_spec
=
layer_kv_cache_spec
.
kv_cache_specs
[
layer_name
]
attn_metadata_builders
.
append
(
attn_metadata_builder
)
# type: ignore
key
=
(
attn_backend
.
full_cls_name
(),
layer_kv_cache_spec
)
if
attn_backend
.
get_name
()
==
"FLASHINFER"
:
if
key
not
in
group_map
:
if
flashinfer_workspace
is
None
:
group_map
[
key
]
=
AttentionGroup
(
flashinfer_workspace
=
attn_metadata_builder
.
_get_workspace_buffer
()
attn_backend
,
[
layer_name
],
layer_kv_cache_spec
,
kv_cache_group_id
,
)
group_order
.
append
(
key
)
else
:
else
:
attn_metadata_builder
.
set_workspace_buffer
(
flashinfer_workspace
)
group_map
[
key
].
layer_names
.
append
(
layer_name
)
return
attn_backends
,
attn_metadata_builders
groups
=
[
group_map
[
key
]
for
key
in
group_order
]
for
group
in
groups
:
group
.
create_metadata_builders
(
vllm_config
=
vllm_config
,
device
=
device
,
kernel_block_size
=
None
,
num_metadata_builders
=
1
,
)
builder
=
group
.
get_metadata_builder
(
0
)
if
attn_backend_workspace
is
None
:
if
hasattr
(
builder
,
"_get_workspace_buffer"
):
attn_backend_workspace
=
builder
.
_get_workspace_buffer
()
else
:
if
hasattr
(
builder
,
"set_workspace_buffer"
):
builder
.
set_workspace_buffer
(
attn_backend_workspace
)
attn_groups
.
append
(
groups
)
return
attn_backends
,
attn_groups
def
_allocate_kv_cache
(
kv_cache_config
:
KVCacheConfig
,
device
:
torch
.
device
):
def
_allocate_kv_cache
(
kv_cache_config
:
KVCacheConfig
,
device
:
torch
.
device
):
...
@@ -144,7 +168,7 @@ def build_slot_mappings_by_layer(
...
@@ -144,7 +168,7 @@ def build_slot_mappings_by_layer(
def
build_attn_metadata
(
def
build_attn_metadata
(
attn_
metadata_builder
s
:
list
[
Attention
MetadataBuilder
],
attn_
group
s
:
list
[
list
[
Attention
Group
]
],
num_reqs
:
int
,
num_reqs
:
int
,
num_tokens
:
int
,
num_tokens
:
int
,
query_start_loc_gpu
:
torch
.
Tensor
,
query_start_loc_gpu
:
torch
.
Tensor
,
...
@@ -162,8 +186,8 @@ def build_attn_metadata(
...
@@ -162,8 +186,8 @@ def build_attn_metadata(
dcp_local_seq_lens
=
dcp_local_seq_lens
[:
num_reqs
]
dcp_local_seq_lens
=
dcp_local_seq_lens
[:
num_reqs
]
attn_metadata
:
dict
[
str
,
Any
]
=
{}
attn_metadata
:
dict
[
str
,
Any
]
=
{}
kv_cache_groups
=
kv_cache_config
.
kv_cache_groups
num_
kv_cache_groups
=
len
(
kv_cache_config
.
kv_cache_groups
)
for
i
,
kv_cache_spec
in
enumerate
(
kv_cache_groups
):
for
i
in
range
(
num_
kv_cache_groups
):
block_table
=
block_tables
[
i
]
block_table
=
block_tables
[
i
]
slot_mapping
=
slot_mappings
[
i
]
slot_mapping
=
slot_mappings
[
i
]
...
@@ -181,10 +205,11 @@ def build_attn_metadata(
...
@@ -181,10 +205,11 @@ def build_attn_metadata(
dcp_local_seq_lens
=
dcp_local_seq_lens
,
dcp_local_seq_lens
=
dcp_local_seq_lens
,
)
)
attn_metadata_builder
=
attn_metadata_builders
[
i
]
for
attn_group
in
attn_groups
[
i
]:
metadata
=
attn_metadata_builder
.
build
(
attn_metadata_builder
=
attn_group
.
get_metadata_builder
(
0
)
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
metadata
=
attn_metadata_builder
.
build
(
)
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
for
layer_name
in
kv_cache_spec
.
layer_names
:
)
attn_metadata
[
layer_name
]
=
metadata
for
layer_name
in
attn_group
.
layer_names
:
attn_metadata
[
layer_name
]
=
metadata
return
attn_metadata
return
attn_metadata
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
b71fbd06
...
@@ -13,7 +13,6 @@ from vllm.config.compilation import CUDAGraphMode
...
@@ -13,7 +13,6 @@ from vllm.config.compilation import CUDAGraphMode
from
vllm.distributed.parallel_state
import
graph_capture
,
is_global_first_rank
from
vllm.distributed.parallel_state
import
graph_capture
,
is_global_first_rank
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.backend
import
AttentionMetadataBuilder
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
(
from
vllm.v1.worker.gpu.attn_utils
import
(
build_attn_metadata
,
build_attn_metadata
,
...
@@ -22,6 +21,7 @@ from vllm.v1.worker.gpu.attn_utils import (
...
@@ -22,6 +21,7 @@ from vllm.v1.worker.gpu.attn_utils import (
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.dp_utils
import
make_num_tokens_across_dp
from
vllm.v1.worker.gpu.dp_utils
import
make_num_tokens_across_dp
from
vllm.v1.worker.gpu.input_batch
import
InputBuffers
from
vllm.v1.worker.gpu.input_batch
import
InputBuffers
from
vllm.v1.worker.utils
import
AttentionGroup
class
CudaGraphManager
:
class
CudaGraphManager
:
...
@@ -83,7 +83,7 @@ class CudaGraphManager:
...
@@ -83,7 +83,7 @@ class CudaGraphManager:
mrope_positions
:
torch
.
Tensor
|
None
,
mrope_positions
:
torch
.
Tensor
|
None
,
inputs_embeds
:
torch
.
Tensor
|
None
,
inputs_embeds
:
torch
.
Tensor
|
None
,
block_tables
:
BlockTables
,
block_tables
:
BlockTables
,
attn_
metadata_builder
s
:
list
[
Attention
MetadataBuilder
],
attn_
group
s
:
list
[
list
[
Attention
Group
]
],
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
has_lora
:
bool
=
False
,
has_lora
:
bool
=
False
,
uniform_decode
:
bool
=
False
,
uniform_decode
:
bool
=
False
,
...
@@ -116,7 +116,7 @@ class CudaGraphManager:
...
@@ -116,7 +116,7 @@ class CudaGraphManager:
num_tokens
,
num_tokens
,
input_buffers
,
input_buffers
,
block_tables
,
block_tables
,
attn_
metadata_builder
s
,
attn_
group
s
,
self
.
max_model_len
,
self
.
max_model_len
,
kv_cache_config
,
kv_cache_config
,
uniform_decode_query_len
=
(
uniform_decode_query_len
=
(
...
@@ -232,7 +232,7 @@ class CudaGraphManager:
...
@@ -232,7 +232,7 @@ class CudaGraphManager:
mrope_positions
:
torch
.
Tensor
|
None
,
mrope_positions
:
torch
.
Tensor
|
None
,
inputs_embeds
:
torch
.
Tensor
|
None
,
inputs_embeds
:
torch
.
Tensor
|
None
,
block_tables
:
BlockTables
,
block_tables
:
BlockTables
,
attn_
metadata_builder
s
:
list
[
Attention
MetadataBuilder
],
attn_
group
s
:
list
[
list
[
Attention
Group
]
],
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
has_lora
:
bool
=
False
,
has_lora
:
bool
=
False
,
)
->
None
:
)
->
None
:
...
@@ -244,7 +244,7 @@ class CudaGraphManager:
...
@@ -244,7 +244,7 @@ class CudaGraphManager:
mrope_positions
=
mrope_positions
,
mrope_positions
=
mrope_positions
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
attn_
metadata_builders
=
attn_metadata_builder
s
,
attn_
groups
=
attn_group
s
,
kv_cache_config
=
kv_cache_config
,
kv_cache_config
=
kv_cache_config
,
has_lora
=
has_lora
,
has_lora
=
has_lora
,
)
)
...
@@ -286,6 +286,16 @@ class CudaGraphManager:
...
@@ -286,6 +286,16 @@ class CudaGraphManager:
cudagraph_mode
=
self
.
cudagraph_mode
.
decode_mode
()
cudagraph_mode
=
self
.
cudagraph_mode
.
decode_mode
()
else
:
else
:
cudagraph_mode
=
self
.
cudagraph_mode
.
mixed_mode
()
cudagraph_mode
=
self
.
cudagraph_mode
.
mixed_mode
()
if
(
cudagraph_mode
==
CUDAGraphMode
.
FULL
and
cudagraph_size
is
not
None
and
cudagraph_size
not
in
self
.
graphs
):
# If graph wasn't captured yet, fall back to eager.
# This might happen when the dummy run is called before capture.
cudagraph_mode
=
CUDAGraphMode
.
NONE
cudagraph_size
=
None
return
cudagraph_mode
,
cudagraph_size
return
cudagraph_mode
,
cudagraph_size
def
run_fullgraph
(
self
,
num_tokens
:
int
)
->
torch
.
Tensor
:
def
run_fullgraph
(
self
,
num_tokens
:
int
)
->
torch
.
Tensor
:
...
@@ -354,7 +364,7 @@ def prepare_inputs_to_capture(
...
@@ -354,7 +364,7 @@ def prepare_inputs_to_capture(
num_tokens
:
int
,
num_tokens
:
int
,
input_buffers
:
InputBuffers
,
input_buffers
:
InputBuffers
,
block_tables
:
BlockTables
,
block_tables
:
BlockTables
,
attn_
metadata_builder
s
:
list
[
Attention
MetadataBuilder
],
attn_
group
s
:
list
[
list
[
Attention
Group
]
],
max_model_len
:
int
,
max_model_len
:
int
,
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
uniform_decode_query_len
:
int
=
0
,
uniform_decode_query_len
:
int
=
0
,
...
@@ -386,7 +396,7 @@ def prepare_inputs_to_capture(
...
@@ -386,7 +396,7 @@ def prepare_inputs_to_capture(
)
)
attn_metadata
=
build_attn_metadata
(
attn_metadata
=
build_attn_metadata
(
attn_
metadata_builders
=
attn_metadata_builder
s
,
attn_
groups
=
attn_group
s
,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_tokens
=
num_tokens
,
num_tokens
=
num_tokens
,
query_start_loc_gpu
=
query_start_loc
,
query_start_loc_gpu
=
query_start_loc
,
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
b71fbd06
...
@@ -283,7 +283,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -283,7 +283,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
cp_interleave
=
self
.
cp_interleave
,
cp_interleave
=
self
.
cp_interleave
,
)
)
self
.
attn_backends
,
self
.
attn_
metadata_builder
s
=
init_attn_backend
(
self
.
attn_backends
,
self
.
attn_
group
s
=
init_attn_backend
(
self
.
kv_cache_config
,
self
.
vllm_config
,
self
.
device
self
.
kv_cache_config
,
self
.
vllm_config
,
self
.
device
)
)
check_attention_cp_compatibility
(
self
.
vllm_config
)
check_attention_cp_compatibility
(
self
.
vllm_config
)
...
@@ -291,7 +291,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -291,7 +291,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# HACK(woosuk)
# HACK(woosuk)
self
.
speculator
.
set_attn
(
self
.
speculator
.
set_attn
(
self
.
kv_cache_config
,
self
.
kv_cache_config
,
self
.
attn_
metadata_builder
s
,
self
.
attn_
group
s
,
self
.
block_tables
,
self
.
block_tables
,
)
)
...
@@ -305,9 +305,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -305,9 +305,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
)
self
.
kv_connector
=
get_kv_connector
(
self
.
vllm_config
,
kv_caches_dict
)
self
.
kv_connector
=
get_kv_connector
(
self
.
vllm_config
,
kv_caches_dict
)
# Attention groups are not supported.
self
.
attn_groups
=
[]
# type: ignore
def
prepare_dummy_attn_metadata
(
self
,
input_batch
:
InputBatch
)
->
None
:
def
prepare_dummy_attn_metadata
(
self
,
input_batch
:
InputBatch
)
->
None
:
block_tables
=
self
.
block_tables
.
get_dummy_block_tables
(
input_batch
.
num_reqs
)
block_tables
=
self
.
block_tables
.
get_dummy_block_tables
(
input_batch
.
num_reqs
)
slot_mappings
=
self
.
block_tables
.
get_dummy_slot_mappings
(
slot_mappings
=
self
.
block_tables
.
get_dummy_slot_mappings
(
...
@@ -317,7 +314,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -317,7 +314,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
slot_mappings
,
self
.
kv_cache_config
slot_mappings
,
self
.
kv_cache_config
)
)
attn_metadata
=
build_attn_metadata
(
attn_metadata
=
build_attn_metadata
(
attn_
metadata_builders
=
self
.
attn_metadata_builder
s
,
attn_
groups
=
self
.
attn_group
s
,
num_reqs
=
input_batch
.
num_reqs
,
num_reqs
=
input_batch
.
num_reqs
,
num_tokens
=
input_batch
.
num_tokens
,
num_tokens
=
input_batch
.
num_tokens
,
query_start_loc_gpu
=
input_batch
.
query_start_loc
,
query_start_loc_gpu
=
input_batch
.
query_start_loc
,
...
@@ -477,7 +474,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -477,7 +474,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
mrope_positions
=
mrope_positions
,
mrope_positions
=
mrope_positions
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
block_tables
=
self
.
block_tables
,
block_tables
=
self
.
block_tables
,
attn_
metadata_builders
=
self
.
attn_metadata_builder
s
,
attn_
groups
=
self
.
attn_group
s
,
kv_cache_config
=
self
.
kv_cache_config
,
kv_cache_config
=
self
.
kv_cache_config
,
has_lora
=
self
.
lora_config
is
not
None
,
has_lora
=
self
.
lora_config
is
not
None
,
)
)
...
@@ -712,7 +709,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -712,7 +709,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Layer name -> attention metadata.
# Layer name -> attention metadata.
attn_metadata
=
build_attn_metadata
(
attn_metadata
=
build_attn_metadata
(
attn_
metadata_builders
=
self
.
attn_metadata_builder
s
,
attn_
groups
=
self
.
attn_group
s
,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_tokens
=
num_tokens
,
num_tokens
=
num_tokens
,
query_start_loc_gpu
=
query_start_loc
,
query_start_loc_gpu
=
query_start_loc
,
...
...
vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
View file @
b71fbd06
...
@@ -7,7 +7,6 @@ import torch
...
@@ -7,7 +7,6 @@ import torch
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.v1.attention.backend
import
AttentionMetadataBuilder
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.cudagraph_utils
import
(
from
vllm.v1.worker.gpu.cudagraph_utils
import
(
...
@@ -17,6 +16,7 @@ from vllm.v1.worker.gpu.cudagraph_utils import (
...
@@ -17,6 +16,7 @@ from vllm.v1.worker.gpu.cudagraph_utils import (
)
)
from
vllm.v1.worker.gpu.dp_utils
import
make_num_tokens_across_dp
from
vllm.v1.worker.gpu.dp_utils
import
make_num_tokens_across_dp
from
vllm.v1.worker.gpu.input_batch
import
InputBuffers
from
vllm.v1.worker.gpu.input_batch
import
InputBuffers
from
vllm.v1.worker.utils
import
AttentionGroup
class
EagleCudaGraphManager
:
class
EagleCudaGraphManager
:
...
@@ -60,7 +60,7 @@ class EagleCudaGraphManager:
...
@@ -60,7 +60,7 @@ class EagleCudaGraphManager:
generate_fn
:
Callable
,
generate_fn
:
Callable
,
input_buffers
:
InputBuffers
,
input_buffers
:
InputBuffers
,
block_tables
:
BlockTables
,
block_tables
:
BlockTables
,
attn_
metadata_builder
s
:
list
[
Attention
MetadataBuilder
],
attn_
group
s
:
list
[
list
[
Attention
Group
]
],
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
)
->
None
:
)
->
None
:
assert
capture_cg_mode
in
[
CUDAGraphMode
.
PIECEWISE
,
CUDAGraphMode
.
FULL
],
(
assert
capture_cg_mode
in
[
CUDAGraphMode
.
PIECEWISE
,
CUDAGraphMode
.
FULL
],
(
...
@@ -77,7 +77,7 @@ class EagleCudaGraphManager:
...
@@ -77,7 +77,7 @@ class EagleCudaGraphManager:
num_tokens
,
num_tokens
,
input_buffers
,
input_buffers
,
block_tables
,
block_tables
,
attn_
metadata_builder
s
,
attn_
group
s
,
self
.
max_model_len
,
self
.
max_model_len
,
kv_cache_config
,
kv_cache_config
,
uniform_decode_query_len
=
1
,
uniform_decode_query_len
=
1
,
...
@@ -150,7 +150,7 @@ class EagleCudaGraphManager:
...
@@ -150,7 +150,7 @@ class EagleCudaGraphManager:
generate_fn
:
Callable
,
generate_fn
:
Callable
,
input_buffers
:
InputBuffers
,
input_buffers
:
InputBuffers
,
block_tables
:
BlockTables
,
block_tables
:
BlockTables
,
attn_
metadata_builder
s
:
list
[
Attention
MetadataBuilder
],
attn_
group
s
:
list
[
list
[
Attention
Group
]
],
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
)
->
None
:
)
->
None
:
if
self
.
cudagraph_mode
==
CUDAGraphMode
.
NONE
:
if
self
.
cudagraph_mode
==
CUDAGraphMode
.
NONE
:
...
@@ -165,7 +165,7 @@ class EagleCudaGraphManager:
...
@@ -165,7 +165,7 @@ class EagleCudaGraphManager:
generate_fn
=
generate_fn
,
generate_fn
=
generate_fn
,
input_buffers
=
input_buffers
,
input_buffers
=
input_buffers
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
attn_
metadata_builders
=
attn_metadata_builder
s
,
attn_
groups
=
attn_group
s
,
kv_cache_config
=
kv_cache_config
,
kv_cache_config
=
kv_cache_config
,
)
)
...
...
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
View file @
b71fbd06
...
@@ -10,7 +10,6 @@ from vllm.config.compilation import CUDAGraphMode
...
@@ -10,7 +10,6 @@ from vllm.config.compilation import CUDAGraphMode
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.backend
import
AttentionMetadataBuilder
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
(
from
vllm.v1.worker.gpu.attn_utils
import
(
build_attn_metadata
,
build_attn_metadata
,
...
@@ -21,6 +20,7 @@ from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
...
@@ -21,6 +20,7 @@ from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
from
vllm.v1.worker.gpu.spec_decode.eagle.cudagraph
import
EagleCudaGraphManager
from
vllm.v1.worker.gpu.spec_decode.eagle.cudagraph
import
EagleCudaGraphManager
from
vllm.v1.worker.gpu.spec_decode.eagle.utils
import
load_eagle_model
from
vllm.v1.worker.gpu.spec_decode.eagle.utils
import
load_eagle_model
from
vllm.v1.worker.utils
import
AttentionGroup
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -78,11 +78,11 @@ class EagleSpeculator:
...
@@ -78,11 +78,11 @@ class EagleSpeculator:
def
set_attn
(
def
set_attn
(
self
,
self
,
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
attn_
metadata_builder
s
:
list
[
Attention
MetadataBuilder
],
attn_
group
s
:
list
[
list
[
Attention
Group
]
],
block_tables
:
BlockTables
,
block_tables
:
BlockTables
,
)
->
None
:
)
->
None
:
self
.
kv_cache_config
=
kv_cache_config
self
.
kv_cache_config
=
kv_cache_config
self
.
attn_
metadata_builders
=
attn_metadata_builder
s
self
.
attn_
groups
=
attn_group
s
self
.
block_tables
=
block_tables
self
.
block_tables
=
block_tables
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
@@ -174,7 +174,7 @@ class EagleSpeculator:
...
@@ -174,7 +174,7 @@ class EagleSpeculator:
self
.
generate_draft
,
self
.
generate_draft
,
self
.
input_buffers
,
self
.
input_buffers
,
self
.
block_tables
,
self
.
block_tables
,
self
.
attn_
metadata_builder
s
,
self
.
attn_
group
s
,
self
.
kv_cache_config
,
self
.
kv_cache_config
,
)
)
...
@@ -298,7 +298,7 @@ class EagleSpeculator:
...
@@ -298,7 +298,7 @@ class EagleSpeculator:
# FIXME(woosuk): This is UNSAFE!!
# FIXME(woosuk): This is UNSAFE!!
attn_metadata
=
build_attn_metadata
(
attn_metadata
=
build_attn_metadata
(
attn_
metadata_builders
=
self
.
attn_metadata_builder
s
,
attn_
groups
=
self
.
attn_group
s
,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_tokens
=
num_reqs
,
num_tokens
=
num_reqs
,
query_start_loc_gpu
=
query_start_loc
,
query_start_loc_gpu
=
query_start_loc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment