Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
838cedad
Unverified
Commit
838cedad
authored
Apr 27, 2025
by
Chen Zhang
Committed by
GitHub
Apr 27, 2025
Browse files
[Bugfix] Get a specific type of layer from forward context (#17222)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
4283a28c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
28 additions
and
23 deletions
+28
-23
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+2
-4
vllm/config.py
vllm/config.py
+15
-1
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+3
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+5
-10
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+3
-4
No files found.
vllm/attention/backends/flashinfer.py
View file @
838cedad
...
@@ -38,7 +38,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
...
@@ -38,7 +38,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
make_tensor_with_pad
)
make_tensor_with_pad
)
...
@@ -140,12 +140,10 @@ def get_per_layer_parameters(
...
@@ -140,12 +140,10 @@ def get_per_layer_parameters(
to use during `plan`.
to use during `plan`.
"""
"""
layers
=
vllm_config
.
compilation_config
.
static_forward_context
layers
=
get_layers_from_vllm_config
(
vllm_config
,
Attention
)
per_layer_params
:
Dict
[
str
,
PerLayerParameters
]
=
{}
per_layer_params
:
Dict
[
str
,
PerLayerParameters
]
=
{}
for
key
,
layer
in
layers
.
items
():
for
key
,
layer
in
layers
.
items
():
assert
isinstance
(
layer
,
Attention
)
impl
=
layer
.
impl
impl
=
layer
.
impl
assert
isinstance
(
impl
,
FlashInferImpl
)
assert
isinstance
(
impl
,
FlashInferImpl
)
...
...
vllm/config.py
View file @
838cedad
...
@@ -3445,7 +3445,8 @@ class CompilationConfig(BaseModel):
...
@@ -3445,7 +3445,8 @@ class CompilationConfig(BaseModel):
compilation_time
:
float
=
PrivateAttr
compilation_time
:
float
=
PrivateAttr
# Per-model forward context
# Per-model forward context
# Map from layer name to the attention cls
# Map from layer name to layer objects that need to be accessed outside
# model code, e.g., Attention, FusedMoE when dp_size>1.
static_forward_context
:
dict
[
str
,
Any
]
=
PrivateAttr
static_forward_context
:
dict
[
str
,
Any
]
=
PrivateAttr
def
compute_hash
(
self
)
->
str
:
def
compute_hash
(
self
)
->
str
:
...
@@ -4079,3 +4080,16 @@ def assert_hashable(text):
...
@@ -4079,3 +4080,16 @@ def assert_hashable(text):
f
"vLLM tried to hash some configs that may have Python objects ids "
f
"vLLM tried to hash some configs that may have Python objects ids "
f
"in them. This is a bug, please file an issue. "
f
"in them. This is a bug, please file an issue. "
f
"Text being hashed:
{
text
}
"
)
f
"Text being hashed:
{
text
}
"
)
T
=
TypeVar
(
"T"
)
def get_layers_from_vllm_config(vllm_config: VllmConfig,
                                layer_type: type[T]) -> dict[str, T]:
    """Return the registered layers that are instances of ``layer_type``.

    Walks ``vllm_config.compilation_config.static_forward_context`` and keeps
    only the entries whose value passes ``isinstance(layer, layer_type)``, so
    subclasses of ``layer_type`` are included as well.

    Args:
        vllm_config: Config object whose ``compilation_config`` holds the
            static forward context (a mapping of layer name to layer object).
        layer_type: Class used to filter the registered layers.

    Returns:
        A new dict mapping layer name to layer object, in the iteration order
        of the static forward context. Empty when nothing matches.
    """
    forward_context = vllm_config.compilation_config.static_forward_context
    selected: dict[str, T] = {}
    for name, module in forward_context.items():
        # The context can hold several kinds of layers; keep only the
        # requested kind.
        if isinstance(module, layer_type):
            selected[name] = module
    return selected
vllm/v1/attention/backends/flashinfer.py
View file @
838cedad
...
@@ -14,7 +14,8 @@ import vllm.envs as envs
...
@@ -14,7 +14,8 @@ import vllm.envs as envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionType
)
AttentionType
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
from
vllm.config
import
(
VllmConfig
,
get_current_vllm_config
,
get_layers_from_vllm_config
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.flash_attn
import
use_cascade_attention
from
vllm.v1.attention.backends.flash_attn
import
use_cascade_attention
...
@@ -81,12 +82,10 @@ def get_per_layer_parameters(
...
@@ -81,12 +82,10 @@ def get_per_layer_parameters(
to use during `plan`.
to use during `plan`.
"""
"""
layers
=
vllm_config
.
compilation_config
.
static_forward_context
layers
=
get_layers_from_vllm_config
(
vllm_config
,
Attention
)
per_layer_params
:
dict
[
str
,
PerLayerParameters
]
=
{}
per_layer_params
:
dict
[
str
,
PerLayerParameters
]
=
{}
for
key
,
layer
in
layers
.
items
():
for
key
,
layer
in
layers
.
items
():
assert
isinstance
(
layer
,
Attention
)
impl
=
layer
.
impl
impl
=
layer
.
impl
assert
isinstance
(
impl
,
FlashInferImpl
)
assert
isinstance
(
impl
,
FlashInferImpl
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
838cedad
...
@@ -12,13 +12,13 @@ import torch.nn as nn
...
@@ -12,13 +12,13 @@ import torch.nn as nn
from
vllm.attention
import
AttentionType
,
get_attn_backend
from
vllm.attention
import
AttentionType
,
get_attn_backend
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
get_layers_from_vllm_config
)
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
has_kv_transfer_group
)
has_kv_transfer_group
)
from
vllm.distributed.parallel_state
import
get_pp_group
,
graph_capture
from
vllm.distributed.parallel_state
import
get_pp_group
,
graph_capture
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
@@ -1733,17 +1733,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1733,17 +1733,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
format. Layers that do not need KV cache are not included.
format. Layers that do not need KV cache are not included.
"""
"""
forward_ctx
=
self
.
vllm_config
.
compilation_config
.
static_forward_context
layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
use_mla
=
self
.
vllm_config
.
model_config
.
use_mla
use_mla
=
self
.
vllm_config
.
model_config
.
use_mla
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
for
layer_name
,
attn_module
in
layers
.
items
():
if
isinstance
(
attn_module
,
FusedMoE
):
# TODO: Support other attention modules, e.g., cross-attention
continue
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention
assert
isinstance
(
attn_module
,
Attention
)
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
if
attn_module
.
sliding_window
is
not
None
:
if
attn_module
.
sliding_window
is
not
None
:
kv_cache_spec
[
layer_name
]
=
SlidingWindowSpec
(
kv_cache_spec
[
layer_name
]
=
SlidingWindowSpec
(
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
838cedad
...
@@ -17,7 +17,7 @@ import vllm.envs as envs
...
@@ -17,7 +17,7 @@ import vllm.envs as envs
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
...
@@ -429,11 +429,10 @@ class TPUModelRunner:
...
@@ -429,11 +429,10 @@ class TPUModelRunner:
format. Layers that do not need KV cache are not included.
format. Layers that do not need KV cache are not included.
"""
"""
forward_ctx
=
self
.
vllm_config
.
compilation_config
.
static_forward_context
layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
for
layer_name
,
attn_module
in
layers
.
items
():
assert
isinstance
(
attn_module
,
Attention
)
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
if
attn_module
.
sliding_window
is
not
None
:
if
attn_module
.
sliding_window
is
not
None
:
kv_cache_spec
[
layer_name
]
=
SlidingWindowSpec
(
kv_cache_spec
[
layer_name
]
=
SlidingWindowSpec
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment