Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
838cedad
Unverified
Commit
838cedad
authored
Apr 27, 2025
by
Chen Zhang
Committed by
GitHub
Apr 27, 2025
Browse files
[Bugfix] Get a specific type of layer from forward context (#17222)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
4283a28c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
28 additions
and
23 deletions
+28
-23
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+2
-4
vllm/config.py
vllm/config.py
+15
-1
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+3
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+5
-10
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+3
-4
No files found.
vllm/attention/backends/flashinfer.py
View file @
838cedad
...
...
@@ -38,7 +38,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
is_block_tables_empty
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
make_tensor_with_pad
)
...
...
@@ -140,12 +140,10 @@ def get_per_layer_parameters(
to use during `plan`.
"""
layers
=
vllm_config
.
compilation_config
.
static_forward_context
layers
=
get_layers_from_vllm_config
(
vllm_config
,
Attention
)
per_layer_params
:
Dict
[
str
,
PerLayerParameters
]
=
{}
for
key
,
layer
in
layers
.
items
():
assert
isinstance
(
layer
,
Attention
)
impl
=
layer
.
impl
assert
isinstance
(
impl
,
FlashInferImpl
)
...
...
vllm/config.py
View file @
838cedad
...
...
@@ -3445,7 +3445,8 @@ class CompilationConfig(BaseModel):
compilation_time
:
float
=
PrivateAttr
# Per-model forward context
# Map from layer name to the attention cls
# Map from layer name to layer objects that need to be accessed outside
# model code, e.g., Attention, FusedMOE when dp_size>1.
static_forward_context
:
dict
[
str
,
Any
]
=
PrivateAttr
def
compute_hash
(
self
)
->
str
:
...
...
@@ -4079,3 +4080,16 @@ def assert_hashable(text):
f
"vLLM tried to hash some configs that may have Python objects ids "
f
"in them. This is a bug, please file an issue. "
f
"Text being hashed:
{
text
}
"
)
T
=
TypeVar
(
"T"
)
def
get_layers_from_vllm_config
(
vllm_config
:
VllmConfig
,
layer_type
:
type
[
T
])
->
dict
[
str
,
T
]:
return
{
layer_name
:
layer
for
layer_name
,
layer
in
vllm_config
.
compilation_config
.
static_forward_context
.
items
()
if
isinstance
(
layer
,
layer_type
)
}
vllm/v1/attention/backends/flashinfer.py
View file @
838cedad
...
...
@@ -14,7 +14,8 @@ import vllm.envs as envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionType
)
from
vllm.attention.layer
import
Attention
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
from
vllm.config
import
(
VllmConfig
,
get_current_vllm_config
,
get_layers_from_vllm_config
)
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.flash_attn
import
use_cascade_attention
...
...
@@ -81,12 +82,10 @@ def get_per_layer_parameters(
to use during `plan`.
"""
layers
=
vllm_config
.
compilation_config
.
static_forward_context
layers
=
get_layers_from_vllm_config
(
vllm_config
,
Attention
)
per_layer_params
:
dict
[
str
,
PerLayerParameters
]
=
{}
for
key
,
layer
in
layers
.
items
():
assert
isinstance
(
layer
,
Attention
)
impl
=
layer
.
impl
assert
isinstance
(
impl
,
FlashInferImpl
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
838cedad
...
...
@@ -12,13 +12,13 @@ import torch.nn as nn
from
vllm.attention
import
AttentionType
,
get_attn_backend
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
get_layers_from_vllm_config
)
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
has_kv_transfer_group
)
from
vllm.distributed.parallel_state
import
get_pp_group
,
graph_capture
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
...
@@ -1733,17 +1733,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
format. Layers that do not need KV cache are not included.
"""
forward_ctx
=
self
.
vllm_config
.
compilation_config
.
static_forward_context
layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
use_mla
=
self
.
vllm_config
.
model_config
.
use_mla
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
if
isinstance
(
attn_module
,
FusedMoE
):
continue
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention
assert
isinstance
(
attn_module
,
Attention
)
for
layer_name
,
attn_module
in
layers
.
items
():
# TODO: Support other attention modules, e.g., cross-attention
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
if
attn_module
.
sliding_window
is
not
None
:
kv_cache_spec
[
layer_name
]
=
SlidingWindowSpec
(
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
838cedad
...
...
@@ -17,7 +17,7 @@ import vllm.envs as envs
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.layer
import
Attention
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
...
...
@@ -429,11 +429,10 @@ class TPUModelRunner:
format. Layers that do not need KV cache are not included.
"""
forward_ctx
=
self
.
vllm_config
.
compilation_config
.
static_forward_context
layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
assert
isinstance
(
attn_module
,
Attention
)
for
layer_name
,
attn_module
in
layers
.
items
():
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
if
attn_module
.
sliding_window
is
not
None
:
kv_cache_spec
[
layer_name
]
=
SlidingWindowSpec
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment