Unverified Commit 0d81a1fe authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[V0 Deprecation] Deprecate virtual engine (#37195)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent 6ae4c8d6
...@@ -295,7 +295,7 @@ def test_rope_kvcache_fusion( ...@@ -295,7 +295,7 @@ def test_rope_kvcache_fusion(
} }
q_unfused, k_unfused, v_unfused, dummy = model(qkv_unfused, pos_unfused) q_unfused, k_unfused, v_unfused, dummy = model(qkv_unfused, pos_unfused)
attn_layer = forward_context.no_compile_layers[model.layer_name] attn_layer = forward_context.no_compile_layers[model.layer_name]
kv_cache_unfused = attn_layer.kv_cache[forward_context.virtual_engine] kv_cache_unfused = attn_layer.kv_cache[0]
del dummy del dummy
torch._dynamo.mark_dynamic(qkv, 0) torch._dynamo.mark_dynamic(qkv, 0)
...@@ -309,7 +309,7 @@ def test_rope_kvcache_fusion( ...@@ -309,7 +309,7 @@ def test_rope_kvcache_fusion(
} }
q_fused, k_fused, v_fused, dummy = model_fused(qkv, pos) q_fused, k_fused, v_fused, dummy = model_fused(qkv, pos)
attn_layer = forward_context.no_compile_layers[model.layer_name] attn_layer = forward_context.no_compile_layers[model.layer_name]
kv_cache_fused = attn_layer.kv_cache[forward_context.virtual_engine] kv_cache_fused = attn_layer.kv_cache[0]
del dummy del dummy
assert fusion_pass.matched_count == 1 assert fusion_pass.matched_count == 1
......
...@@ -86,7 +86,7 @@ class DecodeBenchTestRunner: ...@@ -86,7 +86,7 @@ class DecodeBenchTestRunner:
self._block_hasher = get_request_block_hasher(block_size, sha256) self._block_hasher = get_request_block_hasher(block_size, sha256)
self._dummy_ctx: ForwardContext = ForwardContext( self._dummy_ctx: ForwardContext = ForwardContext(
no_compile_layers={}, attn_metadata={}, virtual_engine=0, slot_mapping={} no_compile_layers={}, attn_metadata={}, slot_mapping={}
) )
def new_request(self, token_ids: list[int]) -> Request: def new_request(self, token_ids: list[int]) -> Request:
......
...@@ -211,7 +211,6 @@ def test_forward_context_interface(): ...@@ -211,7 +211,6 @@ def test_forward_context_interface():
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
assumes(ForwardContext, "no_compile_layers", is_instance_of=dict) assumes(ForwardContext, "no_compile_layers", is_instance_of=dict)
assumes(ForwardContext, "virtual_engine")
assumes(ForwardContext, "attn_metadata") assumes(ForwardContext, "attn_metadata")
......
...@@ -599,7 +599,6 @@ class TestNixlHandshake: ...@@ -599,7 +599,6 @@ class TestNixlHandshake:
dummy_ctx = ForwardContext( dummy_ctx = ForwardContext(
no_compile_layers={}, no_compile_layers={},
attn_metadata={}, attn_metadata={},
virtual_engine=0,
slot_mapping={}, slot_mapping={},
) )
_before_load = time.perf_counter() _before_load = time.perf_counter()
...@@ -672,7 +671,6 @@ class TestNixlHandshake: ...@@ -672,7 +671,6 @@ class TestNixlHandshake:
dummy_ctx = ForwardContext( dummy_ctx = ForwardContext(
no_compile_layers={}, no_compile_layers={},
attn_metadata={}, attn_metadata={},
virtual_engine=0,
slot_mapping={}, slot_mapping={},
) )
_before_load = time.perf_counter() _before_load = time.perf_counter()
...@@ -908,7 +906,6 @@ class TestNixlHandshake: ...@@ -908,7 +906,6 @@ class TestNixlHandshake:
dummy_ctx = ForwardContext( dummy_ctx = ForwardContext(
no_compile_layers={}, no_compile_layers={},
attn_metadata={}, attn_metadata={},
virtual_engine=0,
slot_mapping={}, slot_mapping={},
) )
_before_load = time.perf_counter() _before_load = time.perf_counter()
...@@ -1079,7 +1076,6 @@ def test_kv_connector_stats(default_vllm_config, dist_init): ...@@ -1079,7 +1076,6 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
dummy_ctx = ForwardContext( dummy_ctx = ForwardContext(
no_compile_layers={}, no_compile_layers={},
attn_metadata={}, attn_metadata={},
virtual_engine=0,
slot_mapping={}, slot_mapping={},
) )
connector.start_load_kv(dummy_ctx) connector.start_load_kv(dummy_ctx)
...@@ -1890,7 +1886,6 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_ ...@@ -1890,7 +1886,6 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_
dummy_ctx = ForwardContext( dummy_ctx = ForwardContext(
no_compile_layers={}, no_compile_layers={},
attn_metadata={}, attn_metadata={},
virtual_engine=0,
slot_mapping={}, slot_mapping={},
) )
connector.start_load_kv(dummy_ctx) connector.start_load_kv(dummy_ctx)
...@@ -2059,7 +2054,6 @@ def test_transfer_failure_logging( ...@@ -2059,7 +2054,6 @@ def test_transfer_failure_logging(
dummy_ctx = ForwardContext( dummy_ctx = ForwardContext(
no_compile_layers={}, no_compile_layers={},
attn_metadata={}, attn_metadata={},
virtual_engine=0,
slot_mapping={}, slot_mapping={},
) )
...@@ -2162,7 +2156,6 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init): ...@@ -2162,7 +2156,6 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
dummy_ctx = ForwardContext( dummy_ctx = ForwardContext(
no_compile_layers={}, no_compile_layers={},
attn_metadata={}, attn_metadata={},
virtual_engine=0,
slot_mapping={}, slot_mapping={},
) )
connector.start_load_kv(dummy_ctx) connector.start_load_kv(dummy_ctx)
...@@ -2215,7 +2208,6 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init) ...@@ -2215,7 +2208,6 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
dummy_ctx = ForwardContext( dummy_ctx = ForwardContext(
no_compile_layers={}, no_compile_layers={},
attn_metadata={}, attn_metadata={},
virtual_engine=0,
slot_mapping={}, slot_mapping={},
) )
connector.start_load_kv(dummy_ctx) connector.start_load_kv(dummy_ctx)
......
...@@ -261,7 +261,6 @@ class RequestRunner: ...@@ -261,7 +261,6 @@ class RequestRunner:
self._dummy_ctx: ForwardContext = ForwardContext( self._dummy_ctx: ForwardContext = ForwardContext(
no_compile_layers={}, no_compile_layers={},
attn_metadata={}, attn_metadata={},
virtual_engine=0,
slot_mapping={}, slot_mapping={},
) )
......
...@@ -185,7 +185,7 @@ class ExampleConnector(KVConnectorBase_V1): ...@@ -185,7 +185,7 @@ class ExampleConnector(KVConnectorBase_V1):
if kv_cache_attr is None: if kv_cache_attr is None:
continue continue
kv_cache_layer = kv_cache_attr[forward_context.virtual_engine] kv_cache_layer = kv_cache_attr[0]
filename = self._generate_filename_debug( filename = self._generate_filename_debug(
layer_name, request.token_ids, request.mm_hashes layer_name, request.token_ids, request.mm_hashes
......
...@@ -778,9 +778,7 @@ class LMCacheConnectorV1Impl: ...@@ -778,9 +778,7 @@ class LMCacheConnectorV1Impl:
continue continue
if layer_name not in self.kv_caches: if layer_name not in self.kv_caches:
self.kv_caches[layer_name] = attn_layer.kv_cache[ self.kv_caches[layer_name] = attn_layer.kv_cache[0]
forward_context.virtual_engine
]
#################### ####################
# Worker side APIs # Worker side APIs
......
...@@ -214,7 +214,7 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -214,7 +214,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
if kv_cache is None: if kv_cache is None:
continue continue
layer = kv_cache[forward_context.virtual_engine] layer = kv_cache[0]
kv_cache = self.p2p_nccl_engine.recv_tensor( kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name, remote_address request.request_id + "#" + layer_name, remote_address
......
...@@ -197,8 +197,6 @@ class ForwardContext: ...@@ -197,8 +197,6 @@ class ForwardContext:
for each microbatch. for each microbatch.
Set dynamically for each forward pass Set dynamically for each forward pass
""" """
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass # set dynamically for each forward pass
dp_metadata: DPMetadata | None = None dp_metadata: DPMetadata | None = None
# determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE. # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
...@@ -265,7 +263,6 @@ def is_forward_context_available() -> bool: ...@@ -265,7 +263,6 @@ def is_forward_context_available() -> bool:
def create_forward_context( def create_forward_context(
attn_metadata: Any, attn_metadata: Any,
vllm_config: VllmConfig, vllm_config: VllmConfig,
virtual_engine: int = 0,
dp_metadata: DPMetadata | None = None, dp_metadata: DPMetadata | None = None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None, batch_descriptor: BatchDescriptor | None = None,
...@@ -282,7 +279,6 @@ def create_forward_context( ...@@ -282,7 +279,6 @@ def create_forward_context(
return ForwardContext( return ForwardContext(
no_compile_layers=vllm_config.compilation_config.static_forward_context, no_compile_layers=vllm_config.compilation_config.static_forward_context,
all_moe_layers=all_moe_layers, all_moe_layers=all_moe_layers,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
slot_mapping=slot_mapping or {}, slot_mapping=slot_mapping or {},
dp_metadata=dp_metadata, dp_metadata=dp_metadata,
...@@ -313,7 +309,6 @@ def override_forward_context(forward_context: ForwardContext | None): ...@@ -313,7 +309,6 @@ def override_forward_context(forward_context: ForwardContext | None):
def set_forward_context( def set_forward_context(
attn_metadata: Any, attn_metadata: Any,
vllm_config: VllmConfig, vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: int | None = None, num_tokens: int | None = None,
num_tokens_across_dp: torch.Tensor | None = None, num_tokens_across_dp: torch.Tensor | None = None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
...@@ -362,7 +357,6 @@ def set_forward_context( ...@@ -362,7 +357,6 @@ def set_forward_context(
additional_kwargs = current_platform.set_additional_forward_context( additional_kwargs = current_platform.set_additional_forward_context(
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
vllm_config=vllm_config, vllm_config=vllm_config,
virtual_engine=virtual_engine,
dp_metadata=dp_metadata, dp_metadata=dp_metadata,
num_tokens=num_tokens, num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
...@@ -374,7 +368,6 @@ def set_forward_context( ...@@ -374,7 +368,6 @@ def set_forward_context(
forward_context = create_forward_context( forward_context = create_forward_context(
attn_metadata, attn_metadata,
vllm_config, vllm_config,
virtual_engine,
dp_metadata, dp_metadata,
cudagraph_runtime_mode, cudagraph_runtime_mode,
batch_descriptor, batch_descriptor,
......
...@@ -589,7 +589,7 @@ def get_attention_context( ...@@ -589,7 +589,7 @@ def get_attention_context(
- attn_metadata: Attention metadata for this specific layer, or None if - attn_metadata: Attention metadata for this specific layer, or None if
no metadata available no metadata available
- attn_layer: The attention layer instance (Attention or MLAAttention) - attn_layer: The attention layer instance (Attention or MLAAttention)
- kv_cache: The KV cache tensor for current virtual engine - kv_cache: The KV cache tensor for current forward pass
- slot_mapping: The slot mapping for this specific layer - slot_mapping: The slot mapping for this specific layer
Note: attn_metadata may be None, but attn_layer and kv_cache are always Note: attn_metadata may be None, but attn_layer and kv_cache are always
...@@ -600,7 +600,7 @@ def get_attention_context( ...@@ -600,7 +600,7 @@ def get_attention_context(
if isinstance(attn_metadata, dict): if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name] attn_metadata = attn_metadata[layer_name]
attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name] attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine] kv_cache = attn_layer.kv_cache[0]
slot_mapping = forward_context.slot_mapping slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), ( assert isinstance(slot_mapping, dict), (
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. " f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
......
...@@ -480,7 +480,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -480,7 +480,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict): if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name] attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine] self_kv_cache = self.kv_cache[0]
slot_mapping = forward_context.slot_mapping slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), ( assert isinstance(slot_mapping, dict), (
...@@ -940,7 +940,7 @@ def unified_mla_kv_cache_update( ...@@ -940,7 +940,7 @@ def unified_mla_kv_cache_update(
return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype) return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
attn_layer = forward_context.no_compile_layers[layer_name] attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine] kv_cache = attn_layer.kv_cache[0]
slot_mapping = forward_context.slot_mapping slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), ( assert isinstance(slot_mapping, dict), (
......
...@@ -168,8 +168,7 @@ class StaticSinkAttention(Attention, CustomOp): ...@@ -168,8 +168,7 @@ class StaticSinkAttention(Attention, CustomOp):
"sink_key and sink_value have not been prepared" "sink_key and sink_value have not been prepared"
) )
if not self.sink_populated: if not self.sink_populated:
forward_context: ForwardContext = get_forward_context() self_kv_cache = self.kv_cache[0]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name) torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)
return super().forward(query, key, value, output_shape) return super().forward(query, key, value, output_shape)
......
...@@ -306,7 +306,7 @@ class KimiDeltaAttention(nn.Module, MambaBase): ...@@ -306,7 +306,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
constant_caches = self.kv_cache[forward_context.virtual_engine] constant_caches = self.kv_cache[0]
q_proj_states = q_proj_states[:num_actual_tokens] q_proj_states = q_proj_states[:num_actual_tokens]
k_proj_states = k_proj_states[:num_actual_tokens] k_proj_states = k_proj_states[:num_actual_tokens]
......
...@@ -413,7 +413,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): ...@@ -413,7 +413,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
if attn_metadata is not None: if attn_metadata is not None:
kv_cache = self.kv_cache[forward_context.virtual_engine][0] kv_cache = self.kv_cache[0][0]
state_indices_tensor = attn_metadata.state_indices_tensor state_indices_tensor = attn_metadata.state_indices_tensor
clear_linear_attention_cache_for_new_sequences( clear_linear_attention_cache_for_new_sequences(
kv_cache, state_indices_tensor, attn_metadata kv_cache, state_indices_tensor, attn_metadata
......
...@@ -267,7 +267,7 @@ class MambaMixer(MambaBase, PluggableLayer): ...@@ -267,7 +267,7 @@ class MambaMixer(MambaBase, PluggableLayer):
query_start_loc_p = attn_metadata.query_start_loc_p query_start_loc_p = attn_metadata.query_start_loc_p
state_indices_tensor_p = attn_metadata.state_indices_tensor_p state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d state_indices_tensor_d = attn_metadata.state_indices_tensor_d
self_kv_cache = self.kv_cache[forward_context.virtual_engine] self_kv_cache = self.kv_cache[0]
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p has_initial_states_p = attn_metadata.has_initial_states_p
......
...@@ -575,7 +575,7 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -575,7 +575,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata) assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine] self_kv_cache = self.kv_cache[0]
# conv_state = (..., dim, width-1) yet contiguous along 'dim' # conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
......
...@@ -117,7 +117,7 @@ class ShortConv(MambaBase, CustomOp): ...@@ -117,7 +117,7 @@ class ShortConv(MambaBase, CustomOp):
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, ShortConvAttentionMetadata) assert isinstance(attn_metadata, ShortConvAttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine] self_kv_cache = self.kv_cache[0]
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
state_indices_tensor_p = attn_metadata.state_indices_tensor_p state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d state_indices_tensor_d = attn_metadata.state_indices_tensor_d
......
...@@ -709,7 +709,7 @@ class BailingMoELinearAttention(nn.Module, MambaBase): ...@@ -709,7 +709,7 @@ class BailingMoELinearAttention(nn.Module, MambaBase):
# Get KV cache and state indices # Get KV cache and state indices
if attn_metadata is not None: if attn_metadata is not None:
kv_cache = self.kv_cache[forward_context.virtual_engine][0] kv_cache = self.kv_cache[0][0]
state_indices_tensor = attn_metadata.state_indices_tensor state_indices_tensor = attn_metadata.state_indices_tensor
clear_linear_attention_cache_for_new_sequences( clear_linear_attention_cache_for_new_sequences(
kv_cache, state_indices_tensor, attn_metadata kv_cache, state_indices_tensor, attn_metadata
......
...@@ -51,7 +51,7 @@ def unified_kv_cache_update( ...@@ -51,7 +51,7 @@ def unified_kv_cache_update(
""" """
forward_context = get_forward_context() forward_context = get_forward_context()
attn_layer = forward_context.no_compile_layers[layer_name] attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine] kv_cache = attn_layer.kv_cache[0]
slot_mapping = forward_context.slot_mapping slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), ( assert isinstance(slot_mapping, dict), (
......
...@@ -428,7 +428,7 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase): ...@@ -428,7 +428,7 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase):
non_spec_token_indx = attn_metadata.non_spec_token_indx non_spec_token_indx = attn_metadata.non_spec_token_indx
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor
self_kv_cache = self.kv_cache[forward_context.virtual_engine] self_kv_cache = self.kv_cache[0]
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment