Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0d81a1fe
Unverified
Commit
0d81a1fe
authored
Mar 18, 2026
by
Wentao Ye
Committed by
GitHub
Mar 18, 2026
Browse files
[V0 Deprecation] Deprecate virtual engine (#37195)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
6ae4c8d6
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
19 additions
and
39 deletions
+19
-39
tests/compile/passes/test_rope_kvcache_fusion.py
tests/compile/passes/test_rope_kvcache_fusion.py
+2
-2
tests/v1/kv_connector/unit/test_decode_bench_connector.py
tests/v1/kv_connector/unit/test_decode_bench_connector.py
+1
-1
tests/v1/kv_connector/unit/test_lmcache_integration.py
tests/v1/kv_connector/unit/test_lmcache_integration.py
+0
-1
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+0
-8
tests/v1/kv_connector/unit/test_offloading_connector.py
tests/v1/kv_connector/unit/test_offloading_connector.py
+0
-1
vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py
...tributed/kv_transfer/kv_connector/v1/example_connector.py
+1
-1
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
...er/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
+1
-3
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
...ted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+1
-1
vllm/forward_context.py
vllm/forward_context.py
+0
-7
vllm/model_executor/layers/attention/attention.py
vllm/model_executor/layers/attention/attention.py
+2
-2
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+2
-2
vllm/model_executor/layers/attention/static_sink_attention.py
.../model_executor/layers/attention/static_sink_attention.py
+1
-2
vllm/model_executor/layers/kda.py
vllm/model_executor/layers/kda.py
+1
-1
vllm/model_executor/layers/mamba/linear_attn.py
vllm/model_executor/layers/mamba/linear_attn.py
+1
-1
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+1
-1
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+1
-1
vllm/model_executor/layers/mamba/short_conv.py
vllm/model_executor/layers/mamba/short_conv.py
+1
-1
vllm/model_executor/models/bailing_moe_linear.py
vllm/model_executor/models/bailing_moe_linear.py
+1
-1
vllm/model_executor/models/extract_hidden_states.py
vllm/model_executor/models/extract_hidden_states.py
+1
-1
vllm/model_executor/models/olmo_hybrid.py
vllm/model_executor/models/olmo_hybrid.py
+1
-1
No files found.
tests/compile/passes/test_rope_kvcache_fusion.py
View file @
0d81a1fe
...
@@ -295,7 +295,7 @@ def test_rope_kvcache_fusion(
...
@@ -295,7 +295,7 @@ def test_rope_kvcache_fusion(
}
}
q_unfused
,
k_unfused
,
v_unfused
,
dummy
=
model
(
qkv_unfused
,
pos_unfused
)
q_unfused
,
k_unfused
,
v_unfused
,
dummy
=
model
(
qkv_unfused
,
pos_unfused
)
attn_layer
=
forward_context
.
no_compile_layers
[
model
.
layer_name
]
attn_layer
=
forward_context
.
no_compile_layers
[
model
.
layer_name
]
kv_cache_unfused
=
attn_layer
.
kv_cache
[
forward_context
.
virtual_engine
]
kv_cache_unfused
=
attn_layer
.
kv_cache
[
0
]
del
dummy
del
dummy
torch
.
_dynamo
.
mark_dynamic
(
qkv
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
qkv
,
0
)
...
@@ -309,7 +309,7 @@ def test_rope_kvcache_fusion(
...
@@ -309,7 +309,7 @@ def test_rope_kvcache_fusion(
}
}
q_fused
,
k_fused
,
v_fused
,
dummy
=
model_fused
(
qkv
,
pos
)
q_fused
,
k_fused
,
v_fused
,
dummy
=
model_fused
(
qkv
,
pos
)
attn_layer
=
forward_context
.
no_compile_layers
[
model
.
layer_name
]
attn_layer
=
forward_context
.
no_compile_layers
[
model
.
layer_name
]
kv_cache_fused
=
attn_layer
.
kv_cache
[
forward_context
.
virtual_engine
]
kv_cache_fused
=
attn_layer
.
kv_cache
[
0
]
del
dummy
del
dummy
assert
fusion_pass
.
matched_count
==
1
assert
fusion_pass
.
matched_count
==
1
...
...
tests/v1/kv_connector/unit/test_decode_bench_connector.py
View file @
0d81a1fe
...
@@ -86,7 +86,7 @@ class DecodeBenchTestRunner:
...
@@ -86,7 +86,7 @@ class DecodeBenchTestRunner:
self
.
_block_hasher
=
get_request_block_hasher
(
block_size
,
sha256
)
self
.
_block_hasher
=
get_request_block_hasher
(
block_size
,
sha256
)
self
.
_dummy_ctx
:
ForwardContext
=
ForwardContext
(
self
.
_dummy_ctx
:
ForwardContext
=
ForwardContext
(
no_compile_layers
=
{},
attn_metadata
=
{},
virtual_engine
=
0
,
slot_mapping
=
{}
no_compile_layers
=
{},
attn_metadata
=
{},
slot_mapping
=
{}
)
)
def
new_request
(
self
,
token_ids
:
list
[
int
])
->
Request
:
def
new_request
(
self
,
token_ids
:
list
[
int
])
->
Request
:
...
...
tests/v1/kv_connector/unit/test_lmcache_integration.py
View file @
0d81a1fe
...
@@ -211,7 +211,6 @@ def test_forward_context_interface():
...
@@ -211,7 +211,6 @@ def test_forward_context_interface():
from
vllm.forward_context
import
ForwardContext
from
vllm.forward_context
import
ForwardContext
assumes
(
ForwardContext
,
"no_compile_layers"
,
is_instance_of
=
dict
)
assumes
(
ForwardContext
,
"no_compile_layers"
,
is_instance_of
=
dict
)
assumes
(
ForwardContext
,
"virtual_engine"
)
assumes
(
ForwardContext
,
"attn_metadata"
)
assumes
(
ForwardContext
,
"attn_metadata"
)
...
...
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
0d81a1fe
...
@@ -599,7 +599,6 @@ class TestNixlHandshake:
...
@@ -599,7 +599,6 @@ class TestNixlHandshake:
dummy_ctx
=
ForwardContext
(
dummy_ctx
=
ForwardContext
(
no_compile_layers
=
{},
no_compile_layers
=
{},
attn_metadata
=
{},
attn_metadata
=
{},
virtual_engine
=
0
,
slot_mapping
=
{},
slot_mapping
=
{},
)
)
_before_load
=
time
.
perf_counter
()
_before_load
=
time
.
perf_counter
()
...
@@ -672,7 +671,6 @@ class TestNixlHandshake:
...
@@ -672,7 +671,6 @@ class TestNixlHandshake:
dummy_ctx
=
ForwardContext
(
dummy_ctx
=
ForwardContext
(
no_compile_layers
=
{},
no_compile_layers
=
{},
attn_metadata
=
{},
attn_metadata
=
{},
virtual_engine
=
0
,
slot_mapping
=
{},
slot_mapping
=
{},
)
)
_before_load
=
time
.
perf_counter
()
_before_load
=
time
.
perf_counter
()
...
@@ -908,7 +906,6 @@ class TestNixlHandshake:
...
@@ -908,7 +906,6 @@ class TestNixlHandshake:
dummy_ctx
=
ForwardContext
(
dummy_ctx
=
ForwardContext
(
no_compile_layers
=
{},
no_compile_layers
=
{},
attn_metadata
=
{},
attn_metadata
=
{},
virtual_engine
=
0
,
slot_mapping
=
{},
slot_mapping
=
{},
)
)
_before_load
=
time
.
perf_counter
()
_before_load
=
time
.
perf_counter
()
...
@@ -1079,7 +1076,6 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
...
@@ -1079,7 +1076,6 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
dummy_ctx
=
ForwardContext
(
dummy_ctx
=
ForwardContext
(
no_compile_layers
=
{},
no_compile_layers
=
{},
attn_metadata
=
{},
attn_metadata
=
{},
virtual_engine
=
0
,
slot_mapping
=
{},
slot_mapping
=
{},
)
)
connector
.
start_load_kv
(
dummy_ctx
)
connector
.
start_load_kv
(
dummy_ctx
)
...
@@ -1890,7 +1886,6 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_
...
@@ -1890,7 +1886,6 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_
dummy_ctx
=
ForwardContext
(
dummy_ctx
=
ForwardContext
(
no_compile_layers
=
{},
no_compile_layers
=
{},
attn_metadata
=
{},
attn_metadata
=
{},
virtual_engine
=
0
,
slot_mapping
=
{},
slot_mapping
=
{},
)
)
connector
.
start_load_kv
(
dummy_ctx
)
connector
.
start_load_kv
(
dummy_ctx
)
...
@@ -2059,7 +2054,6 @@ def test_transfer_failure_logging(
...
@@ -2059,7 +2054,6 @@ def test_transfer_failure_logging(
dummy_ctx
=
ForwardContext
(
dummy_ctx
=
ForwardContext
(
no_compile_layers
=
{},
no_compile_layers
=
{},
attn_metadata
=
{},
attn_metadata
=
{},
virtual_engine
=
0
,
slot_mapping
=
{},
slot_mapping
=
{},
)
)
...
@@ -2162,7 +2156,6 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
...
@@ -2162,7 +2156,6 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
dummy_ctx
=
ForwardContext
(
dummy_ctx
=
ForwardContext
(
no_compile_layers
=
{},
no_compile_layers
=
{},
attn_metadata
=
{},
attn_metadata
=
{},
virtual_engine
=
0
,
slot_mapping
=
{},
slot_mapping
=
{},
)
)
connector
.
start_load_kv
(
dummy_ctx
)
connector
.
start_load_kv
(
dummy_ctx
)
...
@@ -2215,7 +2208,6 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
...
@@ -2215,7 +2208,6 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
dummy_ctx
=
ForwardContext
(
dummy_ctx
=
ForwardContext
(
no_compile_layers
=
{},
no_compile_layers
=
{},
attn_metadata
=
{},
attn_metadata
=
{},
virtual_engine
=
0
,
slot_mapping
=
{},
slot_mapping
=
{},
)
)
connector
.
start_load_kv
(
dummy_ctx
)
connector
.
start_load_kv
(
dummy_ctx
)
...
...
tests/v1/kv_connector/unit/test_offloading_connector.py
View file @
0d81a1fe
...
@@ -261,7 +261,6 @@ class RequestRunner:
...
@@ -261,7 +261,6 @@ class RequestRunner:
self
.
_dummy_ctx
:
ForwardContext
=
ForwardContext
(
self
.
_dummy_ctx
:
ForwardContext
=
ForwardContext
(
no_compile_layers
=
{},
no_compile_layers
=
{},
attn_metadata
=
{},
attn_metadata
=
{},
virtual_engine
=
0
,
slot_mapping
=
{},
slot_mapping
=
{},
)
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py
View file @
0d81a1fe
...
@@ -185,7 +185,7 @@ class ExampleConnector(KVConnectorBase_V1):
...
@@ -185,7 +185,7 @@ class ExampleConnector(KVConnectorBase_V1):
if
kv_cache_attr
is
None
:
if
kv_cache_attr
is
None
:
continue
continue
kv_cache_layer
=
kv_cache_attr
[
forward_context
.
virtual_engine
]
kv_cache_layer
=
kv_cache_attr
[
0
]
filename
=
self
.
_generate_filename_debug
(
filename
=
self
.
_generate_filename_debug
(
layer_name
,
request
.
token_ids
,
request
.
mm_hashes
layer_name
,
request
.
token_ids
,
request
.
mm_hashes
...
...
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
View file @
0d81a1fe
...
@@ -778,9 +778,7 @@ class LMCacheConnectorV1Impl:
...
@@ -778,9 +778,7 @@ class LMCacheConnectorV1Impl:
continue
continue
if
layer_name
not
in
self
.
kv_caches
:
if
layer_name
not
in
self
.
kv_caches
:
self
.
kv_caches
[
layer_name
]
=
attn_layer
.
kv_cache
[
self
.
kv_caches
[
layer_name
]
=
attn_layer
.
kv_cache
[
0
]
forward_context
.
virtual_engine
]
####################
####################
# Worker side APIs
# Worker side APIs
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
View file @
0d81a1fe
...
@@ -214,7 +214,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -214,7 +214,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
if
kv_cache
is
None
:
if
kv_cache
is
None
:
continue
continue
layer
=
kv_cache
[
forward_context
.
virtual_engine
]
layer
=
kv_cache
[
0
]
kv_cache
=
self
.
p2p_nccl_engine
.
recv_tensor
(
kv_cache
=
self
.
p2p_nccl_engine
.
recv_tensor
(
request
.
request_id
+
"#"
+
layer_name
,
remote_address
request
.
request_id
+
"#"
+
layer_name
,
remote_address
...
...
vllm/forward_context.py
View file @
0d81a1fe
...
@@ -197,8 +197,6 @@ class ForwardContext:
...
@@ -197,8 +197,6 @@ class ForwardContext:
for each microbatch.
for each microbatch.
Set dynamically for each forward pass
Set dynamically for each forward pass
"""
"""
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine
:
int
# set dynamically for each forward pass
# set dynamically for each forward pass
# set dynamically for each forward pass
dp_metadata
:
DPMetadata
|
None
=
None
dp_metadata
:
DPMetadata
|
None
=
None
# determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
# determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
...
@@ -265,7 +263,6 @@ def is_forward_context_available() -> bool:
...
@@ -265,7 +263,6 @@ def is_forward_context_available() -> bool:
def
create_forward_context
(
def
create_forward_context
(
attn_metadata
:
Any
,
attn_metadata
:
Any
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
virtual_engine
:
int
=
0
,
dp_metadata
:
DPMetadata
|
None
=
None
,
dp_metadata
:
DPMetadata
|
None
=
None
,
cudagraph_runtime_mode
:
CUDAGraphMode
=
CUDAGraphMode
.
NONE
,
cudagraph_runtime_mode
:
CUDAGraphMode
=
CUDAGraphMode
.
NONE
,
batch_descriptor
:
BatchDescriptor
|
None
=
None
,
batch_descriptor
:
BatchDescriptor
|
None
=
None
,
...
@@ -282,7 +279,6 @@ def create_forward_context(
...
@@ -282,7 +279,6 @@ def create_forward_context(
return
ForwardContext
(
return
ForwardContext
(
no_compile_layers
=
vllm_config
.
compilation_config
.
static_forward_context
,
no_compile_layers
=
vllm_config
.
compilation_config
.
static_forward_context
,
all_moe_layers
=
all_moe_layers
,
all_moe_layers
=
all_moe_layers
,
virtual_engine
=
virtual_engine
,
attn_metadata
=
attn_metadata
,
attn_metadata
=
attn_metadata
,
slot_mapping
=
slot_mapping
or
{},
slot_mapping
=
slot_mapping
or
{},
dp_metadata
=
dp_metadata
,
dp_metadata
=
dp_metadata
,
...
@@ -313,7 +309,6 @@ def override_forward_context(forward_context: ForwardContext | None):
...
@@ -313,7 +309,6 @@ def override_forward_context(forward_context: ForwardContext | None):
def
set_forward_context
(
def
set_forward_context
(
attn_metadata
:
Any
,
attn_metadata
:
Any
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
virtual_engine
:
int
=
0
,
num_tokens
:
int
|
None
=
None
,
num_tokens
:
int
|
None
=
None
,
num_tokens_across_dp
:
torch
.
Tensor
|
None
=
None
,
num_tokens_across_dp
:
torch
.
Tensor
|
None
=
None
,
cudagraph_runtime_mode
:
CUDAGraphMode
=
CUDAGraphMode
.
NONE
,
cudagraph_runtime_mode
:
CUDAGraphMode
=
CUDAGraphMode
.
NONE
,
...
@@ -362,7 +357,6 @@ def set_forward_context(
...
@@ -362,7 +357,6 @@ def set_forward_context(
additional_kwargs
=
current_platform
.
set_additional_forward_context
(
additional_kwargs
=
current_platform
.
set_additional_forward_context
(
attn_metadata
=
attn_metadata
,
attn_metadata
=
attn_metadata
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
virtual_engine
=
virtual_engine
,
dp_metadata
=
dp_metadata
,
dp_metadata
=
dp_metadata
,
num_tokens
=
num_tokens
,
num_tokens
=
num_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
,
num_tokens_across_dp
=
num_tokens_across_dp
,
...
@@ -374,7 +368,6 @@ def set_forward_context(
...
@@ -374,7 +368,6 @@ def set_forward_context(
forward_context
=
create_forward_context
(
forward_context
=
create_forward_context
(
attn_metadata
,
attn_metadata
,
vllm_config
,
vllm_config
,
virtual_engine
,
dp_metadata
,
dp_metadata
,
cudagraph_runtime_mode
,
cudagraph_runtime_mode
,
batch_descriptor
,
batch_descriptor
,
...
...
vllm/model_executor/layers/attention/attention.py
View file @
0d81a1fe
...
@@ -589,7 +589,7 @@ def get_attention_context(
...
@@ -589,7 +589,7 @@ def get_attention_context(
- attn_metadata: Attention metadata for this specific layer, or None if
- attn_metadata: Attention metadata for this specific layer, or None if
no metadata available
no metadata available
- attn_layer: The attention layer instance (Attention or MLAAttention)
- attn_layer: The attention layer instance (Attention or MLAAttention)
- kv_cache: The KV cache tensor for current
virtual engine
- kv_cache: The KV cache tensor for current
forward pass
- slot_mapping: The slot mapping for this specific layer
- slot_mapping: The slot mapping for this specific layer
Note: attn_metadata may be None, but attn_layer and kv_cache are always
Note: attn_metadata may be None, but attn_layer and kv_cache are always
...
@@ -600,7 +600,7 @@ def get_attention_context(
...
@@ -600,7 +600,7 @@ def get_attention_context(
if
isinstance
(
attn_metadata
,
dict
):
if
isinstance
(
attn_metadata
,
dict
):
attn_metadata
=
attn_metadata
[
layer_name
]
attn_metadata
=
attn_metadata
[
layer_name
]
attn_layer
:
Attention
|
MLAAttention
=
forward_context
.
no_compile_layers
[
layer_name
]
attn_layer
:
Attention
|
MLAAttention
=
forward_context
.
no_compile_layers
[
layer_name
]
kv_cache
=
attn_layer
.
kv_cache
[
forward_context
.
virtual_engine
]
kv_cache
=
attn_layer
.
kv_cache
[
0
]
slot_mapping
=
forward_context
.
slot_mapping
slot_mapping
=
forward_context
.
slot_mapping
assert
isinstance
(
slot_mapping
,
dict
),
(
assert
isinstance
(
slot_mapping
,
dict
),
(
f
"Expected slot_mapping to be a dict, got
{
type
(
slot_mapping
)
}
. "
f
"Expected slot_mapping to be a dict, got
{
type
(
slot_mapping
)
}
. "
...
...
vllm/model_executor/layers/attention/mla_attention.py
View file @
0d81a1fe
...
@@ -480,7 +480,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
...
@@ -480,7 +480,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
attn_metadata
=
forward_context
.
attn_metadata
attn_metadata
=
forward_context
.
attn_metadata
if
isinstance
(
attn_metadata
,
dict
):
if
isinstance
(
attn_metadata
,
dict
):
attn_metadata
=
attn_metadata
[
self
.
layer_name
]
attn_metadata
=
attn_metadata
[
self
.
layer_name
]
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self_kv_cache
=
self
.
kv_cache
[
0
]
slot_mapping
=
forward_context
.
slot_mapping
slot_mapping
=
forward_context
.
slot_mapping
assert
isinstance
(
slot_mapping
,
dict
),
(
assert
isinstance
(
slot_mapping
,
dict
),
(
...
@@ -940,7 +940,7 @@ def unified_mla_kv_cache_update(
...
@@ -940,7 +940,7 @@ def unified_mla_kv_cache_update(
return
torch
.
empty
(
0
,
device
=
kv_c_normed
.
device
,
dtype
=
kv_c_normed
.
dtype
)
return
torch
.
empty
(
0
,
device
=
kv_c_normed
.
device
,
dtype
=
kv_c_normed
.
dtype
)
attn_layer
=
forward_context
.
no_compile_layers
[
layer_name
]
attn_layer
=
forward_context
.
no_compile_layers
[
layer_name
]
kv_cache
=
attn_layer
.
kv_cache
[
forward_context
.
virtual_engine
]
kv_cache
=
attn_layer
.
kv_cache
[
0
]
slot_mapping
=
forward_context
.
slot_mapping
slot_mapping
=
forward_context
.
slot_mapping
assert
isinstance
(
slot_mapping
,
dict
),
(
assert
isinstance
(
slot_mapping
,
dict
),
(
...
...
vllm/model_executor/layers/attention/static_sink_attention.py
View file @
0d81a1fe
...
@@ -168,8 +168,7 @@ class StaticSinkAttention(Attention, CustomOp):
...
@@ -168,8 +168,7 @@ class StaticSinkAttention(Attention, CustomOp):
"sink_key and sink_value have not been prepared"
"sink_key and sink_value have not been prepared"
)
)
if
not
self
.
sink_populated
:
if
not
self
.
sink_populated
:
forward_context
:
ForwardContext
=
get_forward_context
()
self_kv_cache
=
self
.
kv_cache
[
0
]
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
torch
.
ops
.
vllm
.
maybe_populate_sink
(
self_kv_cache
,
self
.
layer_name
)
torch
.
ops
.
vllm
.
maybe_populate_sink
(
self_kv_cache
,
self
.
layer_name
)
return
super
().
forward
(
query
,
key
,
value
,
output_shape
)
return
super
().
forward
(
query
,
key
,
value
,
output_shape
)
...
...
vllm/model_executor/layers/kda.py
View file @
0d81a1fe
...
@@ -306,7 +306,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
...
@@ -306,7 +306,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
non_spec_query_start_loc
=
attn_metadata
.
non_spec_query_start_loc
non_spec_query_start_loc
=
attn_metadata
.
non_spec_query_start_loc
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
# noqa: E501
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
# noqa: E501
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
constant_caches
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
constant_caches
=
self
.
kv_cache
[
0
]
q_proj_states
=
q_proj_states
[:
num_actual_tokens
]
q_proj_states
=
q_proj_states
[:
num_actual_tokens
]
k_proj_states
=
k_proj_states
[:
num_actual_tokens
]
k_proj_states
=
k_proj_states
[:
num_actual_tokens
]
...
...
vllm/model_executor/layers/mamba/linear_attn.py
View file @
0d81a1fe
...
@@ -413,7 +413,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
...
@@ -413,7 +413,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
qkvact
=
qkvact
.
view
((
qkv
.
shape
[
0
],
self
.
tp_heads
,
-
1
))
qkvact
=
qkvact
.
view
((
qkv
.
shape
[
0
],
self
.
tp_heads
,
-
1
))
q
,
k
,
v
=
torch
.
split
(
qkvact
,
[
self
.
head_dim
]
*
3
,
dim
=-
1
)
q
,
k
,
v
=
torch
.
split
(
qkvact
,
[
self
.
head_dim
]
*
3
,
dim
=-
1
)
if
attn_metadata
is
not
None
:
if
attn_metadata
is
not
None
:
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
][
0
]
kv_cache
=
self
.
kv_cache
[
0
][
0
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
clear_linear_attention_cache_for_new_sequences
(
clear_linear_attention_cache_for_new_sequences
(
kv_cache
,
state_indices_tensor
,
attn_metadata
kv_cache
,
state_indices_tensor
,
attn_metadata
...
...
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
0d81a1fe
...
@@ -267,7 +267,7 @@ class MambaMixer(MambaBase, PluggableLayer):
...
@@ -267,7 +267,7 @@ class MambaMixer(MambaBase, PluggableLayer):
query_start_loc_p
=
attn_metadata
.
query_start_loc_p
query_start_loc_p
=
attn_metadata
.
query_start_loc_p
state_indices_tensor_p
=
attn_metadata
.
state_indices_tensor_p
state_indices_tensor_p
=
attn_metadata
.
state_indices_tensor_p
state_indices_tensor_d
=
attn_metadata
.
state_indices_tensor_d
state_indices_tensor_d
=
attn_metadata
.
state_indices_tensor_d
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self_kv_cache
=
self
.
kv_cache
[
0
]
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
ssm_state
=
self_kv_cache
[
1
]
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
0d81a1fe
...
@@ -575,7 +575,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
...
@@ -575,7 +575,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
assert
isinstance
(
attn_metadata
,
dict
)
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
attn_metadata
=
attn_metadata
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self_kv_cache
=
self
.
kv_cache
[
0
]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
ssm_state
=
self_kv_cache
[
1
]
...
...
vllm/model_executor/layers/mamba/short_conv.py
View file @
0d81a1fe
...
@@ -117,7 +117,7 @@ class ShortConv(MambaBase, CustomOp):
...
@@ -117,7 +117,7 @@ class ShortConv(MambaBase, CustomOp):
assert
isinstance
(
attn_metadata
,
dict
)
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
attn_metadata
=
attn_metadata
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
ShortConvAttentionMetadata
)
assert
isinstance
(
attn_metadata
,
ShortConvAttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self_kv_cache
=
self
.
kv_cache
[
0
]
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
state_indices_tensor_p
=
attn_metadata
.
state_indices_tensor_p
state_indices_tensor_p
=
attn_metadata
.
state_indices_tensor_p
state_indices_tensor_d
=
attn_metadata
.
state_indices_tensor_d
state_indices_tensor_d
=
attn_metadata
.
state_indices_tensor_d
...
...
vllm/model_executor/models/bailing_moe_linear.py
View file @
0d81a1fe
...
@@ -709,7 +709,7 @@ class BailingMoELinearAttention(nn.Module, MambaBase):
...
@@ -709,7 +709,7 @@ class BailingMoELinearAttention(nn.Module, MambaBase):
# Get KV cache and state indices
# Get KV cache and state indices
if
attn_metadata
is
not
None
:
if
attn_metadata
is
not
None
:
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
][
0
]
kv_cache
=
self
.
kv_cache
[
0
][
0
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
clear_linear_attention_cache_for_new_sequences
(
clear_linear_attention_cache_for_new_sequences
(
kv_cache
,
state_indices_tensor
,
attn_metadata
kv_cache
,
state_indices_tensor
,
attn_metadata
...
...
vllm/model_executor/models/extract_hidden_states.py
View file @
0d81a1fe
...
@@ -51,7 +51,7 @@ def unified_kv_cache_update(
...
@@ -51,7 +51,7 @@ def unified_kv_cache_update(
"""
"""
forward_context
=
get_forward_context
()
forward_context
=
get_forward_context
()
attn_layer
=
forward_context
.
no_compile_layers
[
layer_name
]
attn_layer
=
forward_context
.
no_compile_layers
[
layer_name
]
kv_cache
=
attn_layer
.
kv_cache
[
forward_context
.
virtual_engine
]
kv_cache
=
attn_layer
.
kv_cache
[
0
]
slot_mapping
=
forward_context
.
slot_mapping
slot_mapping
=
forward_context
.
slot_mapping
assert
isinstance
(
slot_mapping
,
dict
),
(
assert
isinstance
(
slot_mapping
,
dict
),
(
...
...
vllm/model_executor/models/olmo_hybrid.py
View file @
0d81a1fe
...
@@ -428,7 +428,7 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase):
...
@@ -428,7 +428,7 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase):
non_spec_token_indx
=
attn_metadata
.
non_spec_token_indx
non_spec_token_indx
=
attn_metadata
.
non_spec_token_indx
spec_state_indices_tensor
=
attn_metadata
.
spec_state_indices_tensor
spec_state_indices_tensor
=
attn_metadata
.
spec_state_indices_tensor
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self_kv_cache
=
self
.
kv_cache
[
0
]
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
ssm_state
=
self_kv_cache
[
1
]
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment