Unverified commit 0d81a1fe, authored by Wentao Ye, committed by GitHub
Browse files

[V0 Deprecation] Deprecate virtual engine (#37195)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 6ae4c8d6
...@@ -262,7 +262,7 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer): ...@@ -262,7 +262,7 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata) assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine] self_kv_cache = self.kv_cache[0]
# conv_state = (..., dim, width-1) yet contiguous along 'dim' # conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
......
...@@ -842,7 +842,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -842,7 +842,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
a=a, a=a,
core_attn_out=core_attn_out, core_attn_out=core_attn_out,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
virtual_engine=forward_context.virtual_engine,
) )
has_initial_state = attn_metadata.has_initial_state has_initial_state = attn_metadata.has_initial_state
...@@ -853,7 +852,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -853,7 +852,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
non_spec_token_indx = attn_metadata.non_spec_token_indx non_spec_token_indx = attn_metadata.non_spec_token_indx
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[forward_context.virtual_engine] self_kv_cache = self.kv_cache[0]
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
...@@ -1036,13 +1035,12 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -1036,13 +1035,12 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
a: torch.Tensor, a: torch.Tensor,
core_attn_out: torch.Tensor, core_attn_out: torch.Tensor,
attn_metadata: GDNAttentionMetadata, attn_metadata: GDNAttentionMetadata,
virtual_engine: int,
): ):
""" """
Core attention computation with a packed non-spec decode fast path. Core attention computation with a packed non-spec decode fast path.
""" """
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[virtual_engine] self_kv_cache = self.kv_cache[0]
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
......
...@@ -510,7 +510,7 @@ def bind_kv_cache( ...@@ -510,7 +510,7 @@ def bind_kv_cache(
# Bind kv_caches to forward context # Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items(): for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine. # NOTE: Keep list wrapper for layers that index kv_cache by engine slot.
forward_context[layer_name].kv_cache = [kv_cache] forward_context[layer_name].kv_cache = [kv_cache]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment