"vscode:/vscode.git/clone" did not exist on "1b1e8e05ff3c26b98e4161bd3c8671e86fb145f4"
Unverified commit 0d81a1fe, authored by Wentao Ye, committed by GitHub
Browse files

[V0 Deprecation] Deprecate virtual engine (#37195)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 6ae4c8d6
......@@ -262,7 +262,7 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self_kv_cache = self.kv_cache[0]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
......
......@@ -842,7 +842,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
a=a,
core_attn_out=core_attn_out,
attn_metadata=attn_metadata,
virtual_engine=forward_context.virtual_engine,
)
has_initial_state = attn_metadata.has_initial_state
......@@ -853,7 +852,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
non_spec_token_indx = attn_metadata.non_spec_token_indx
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self_kv_cache = self.kv_cache[0]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
......@@ -1036,13 +1035,12 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
a: torch.Tensor,
core_attn_out: torch.Tensor,
attn_metadata: GDNAttentionMetadata,
virtual_engine: int,
):
"""
Core attention computation with a packed non-spec decode fast path.
"""
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[virtual_engine]
self_kv_cache = self.kv_cache[0]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
......
......@@ -510,7 +510,7 @@ def bind_kv_cache(
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
# NOTE: Keep list wrapper for layers that index kv_cache by engine slot.
forward_context[layer_name].kv_cache = [kv_cache]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment