Commit f4fd3a96 authored by yangshj1's avatar yangshj1
Browse files

add pd lmcache

parent 4a80b456
...@@ -7,11 +7,13 @@ import torch ...@@ -7,11 +7,13 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionType from vllm.attention import AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig, get_current_vllm_config from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group, from vllm.distributed.kv_transfer import (get_kv_transfer_group,
get_lmcache_connector,
has_kv_transfer_group, has_kv_transfer_group,
is_v1_kv_transfer_group) is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
...@@ -22,7 +24,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod ...@@ -22,7 +24,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import _Backend, current_platform from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import validate_kv_sharing_target from vllm.v1.attention.backends.utils import validate_kv_sharing_target
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
class Attention(nn.Module): class Attention(nn.Module):
"""Attention layer. """Attention layer.
...@@ -376,6 +378,9 @@ def wait_for_kv_layer_from_connector(layer_name: str): ...@@ -376,6 +378,9 @@ def wait_for_kv_layer_from_connector(layer_name: str):
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
connector.wait_for_layer_load(layer_name) connector.wait_for_layer_load(layer_name)
get_lmcache_connector().wait_for_layer_load(layer_name)
def maybe_save_kv_layer_to_connector( def maybe_save_kv_layer_to_connector(
layer_name: str, layer_name: str,
...@@ -394,6 +399,9 @@ def maybe_save_kv_layer_to_connector( ...@@ -394,6 +399,9 @@ def maybe_save_kv_layer_to_connector(
connector.save_kv_layer(layer_name, kv_cache_layer, connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name]) attn_metadata[layer_name])
get_lmcache_connector().save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name])
def unified_attention( def unified_attention(
query: torch.Tensor, query: torch.Tensor,
......
...@@ -3,24 +3,32 @@ ...@@ -3,24 +3,32 @@
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from vllm import envs from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import ( from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory) KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole) KVConnectorRole)
from vllm.distributed.parallel_state import get_world_group from vllm.distributed.parallel_state import get_world_group
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None _KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None
_KV_LMCACHE_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None
def get_kv_transfer_group() -> KVConnectorBaseType: def get_kv_transfer_group() -> KVConnectorBaseType:
assert _KV_CONNECTOR_AGENT is not None, ( assert _KV_CONNECTOR_AGENT is not None, (
"disaggregated KV cache transfer parallel group is not initialized") "disaggregated KV cache transfer parallel group is not initialized")
return _KV_CONNECTOR_AGENT return _KV_CONNECTOR_AGENT
def get_lmcache_connector() -> LMCacheConnectorV1:
assert _KV_LMCACHE_CONNECTOR_AGENT is not None, (
"LM cache transfer parallel group is not initialized")
return _KV_LMCACHE_CONNECTOR_AGENT
def has_kv_transfer_group() -> bool: def has_kv_transfer_group() -> bool:
return _KV_CONNECTOR_AGENT is not None return _KV_CONNECTOR_AGENT is not None
...@@ -54,6 +62,7 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: ...@@ -54,6 +62,7 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
""" """
global _KV_CONNECTOR_AGENT global _KV_CONNECTOR_AGENT
global _KV_LMCACHE_CONNECTOR_AGENT
if vllm_config.kv_transfer_config is None: if vllm_config.kv_transfer_config is None:
return return
...@@ -63,6 +72,13 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: ...@@ -63,6 +72,13 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1( _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1(
config=vllm_config, role=KVConnectorRole.WORKER) config=vllm_config, role=KVConnectorRole.WORKER)
lmcache_config = vllm_config
lmcache_config.kv_transfer_config.kv_role = "kv_both"
lmcache_connector: LMCacheConnectorV1 = LMCacheConnectorV1(
lmcache_config, role=KVConnectorRole.WORKER)
_KV_LMCACHE_CONNECTOR_AGENT = lmcache_connector
else: else:
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0( _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0(
rank=get_world_group().rank, rank=get_world_group().rank,
......
...@@ -34,6 +34,7 @@ from vllm.v1.outputs import ModelRunnerOutput ...@@ -34,6 +34,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -86,6 +87,11 @@ class Scheduler(SchedulerInterface): ...@@ -86,6 +87,11 @@ class Scheduler(SchedulerInterface):
self.connector = KVConnectorFactory.create_connector_v1( self.connector = KVConnectorFactory.create_connector_v1(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER) config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
lmcache_config = self.vllm_config
lmcache_config.kv_transfer_config.kv_role = "kv_both"
self.lmcache_connector: LMCacheConnectorV1 = LMCacheConnectorV1(
lmcache_config, role=KVConnectorRole.SCHEDULER)
self.kv_event_publisher = EventPublisherFactory.create( self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config, self.kv_events_config,
self.parallel_config.data_parallel_rank, self.parallel_config.data_parallel_rank,
...@@ -389,6 +395,12 @@ class Scheduler(SchedulerInterface): ...@@ -389,6 +395,12 @@ class Scheduler(SchedulerInterface):
self.connector.get_num_new_matched_tokens( self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens)) request, num_new_local_computed_tokens))
if self.lmcache_connector is not None:
num_external_computed_tokens, load_kv_async = (
self.lmcache_connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
# Total computed tokens (local + external). # Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens + num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens) num_external_computed_tokens)
...@@ -463,6 +475,13 @@ class Scheduler(SchedulerInterface): ...@@ -463,6 +475,13 @@ class Scheduler(SchedulerInterface):
num_external_computed_tokens, num_external_computed_tokens,
) )
if self.lmcache_connector is not None:
self.lmcache_connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_computed_tokens,
)
# Request was already popped from self.waiting # Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None. # unless it was re-added above due to new_blocks being None.
request = self.waiting.pop_request() request = self.waiting.pop_request()
...@@ -578,6 +597,10 @@ class Scheduler(SchedulerInterface): ...@@ -578,6 +597,10 @@ class Scheduler(SchedulerInterface):
meta = self.connector.build_connector_meta(scheduler_output) meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta scheduler_output.kv_connector_metadata = meta
if self.lmcache_connector is not None:
meta = self.lmcache_connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
events = self.kv_cache_manager.take_events() events = self.kv_cache_manager.take_events()
if events: if events:
batch = KVEventBatch(ts=time.time(), events=events) batch = KVEventBatch(ts=time.time(), events=events)
...@@ -676,6 +699,11 @@ class Scheduler(SchedulerInterface): ...@@ -676,6 +699,11 @@ class Scheduler(SchedulerInterface):
self.connector.get_num_new_matched_tokens( self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens)) request, num_new_local_computed_tokens))
if self.lmcache_connector is not None:
num_external_computed_tokens, load_kv_async = (
self.lmcache_connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
# Total computed tokens (local + external). # Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens + num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens) num_external_computed_tokens)
...@@ -750,6 +778,13 @@ class Scheduler(SchedulerInterface): ...@@ -750,6 +778,13 @@ class Scheduler(SchedulerInterface):
num_external_computed_tokens, num_external_computed_tokens,
) )
if self.lmcache_connector is not None:
self.lmcache_connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_computed_tokens,
)
# Request was already popped from self.waiting # Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None. # unless it was re-added above due to new_blocks being None.
request = self.waiting.pop_request() request = self.waiting.pop_request()
...@@ -994,6 +1029,10 @@ class Scheduler(SchedulerInterface): ...@@ -994,6 +1029,10 @@ class Scheduler(SchedulerInterface):
meta = self.connector.build_connector_meta(scheduler_output) meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta scheduler_output.kv_connector_metadata = meta
if self.lmcache_connector is not None:
meta = self.lmcache_connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
events = self.kv_cache_manager.take_events() events = self.kv_cache_manager.take_events()
if events: if events:
batch = KVEventBatch(ts=time.time(), events=events) batch = KVEventBatch(ts=time.time(), events=events)
......
...@@ -23,6 +23,7 @@ from vllm.config import (CompilationLevel, VllmConfig, ...@@ -23,6 +23,7 @@ from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config) get_layers_from_vllm_config)
from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group, from vllm.distributed.kv_transfer import (get_kv_transfer_group,
get_lmcache_connector,
has_kv_transfer_group) has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
...@@ -73,6 +74,9 @@ from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute ...@@ -73,6 +74,9 @@ from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute
from ..sample.logits_processor import LogitsProcessorManager from ..sample.logits_processor import LogitsProcessorManager
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders) sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
if TYPE_CHECKING: if TYPE_CHECKING:
import xgrammar as xgr import xgrammar as xgr
...@@ -1573,6 +1577,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1573,6 +1577,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata() get_kv_transfer_group().clear_connector_metadata()
get_lmcache_connector().clear_connector_metadata()
self.eplb_step() self.eplb_step()
return ModelRunnerOutput( return ModelRunnerOutput(
...@@ -1736,11 +1742,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1736,11 +1742,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Do this here to save a collective_rpc. # Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context()) kv_connector.start_load_kv(get_forward_context())
lmcache_connector = get_lmcache_connector()
lmcache_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata)
lmcache_connector.start_load_kv(get_forward_context())
@staticmethod @staticmethod
def maybe_wait_for_kv_save() -> None: def maybe_wait_for_kv_save() -> None:
if has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().wait_for_save() get_kv_transfer_group().wait_for_save()
lmcache_connector = get_lmcache_connector()
lmcache_connector.wait_for_save()
@staticmethod @staticmethod
def get_finished_kv_transfers( def get_finished_kv_transfers(
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
...@@ -2690,6 +2704,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2690,6 +2704,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches) get_kv_transfer_group().register_kv_caches(kv_caches)
get_lmcache_connector().register_kv_caches(kv_caches)
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
""" """
Generates the KVCacheSpec by parsing the kv cache format from each Generates the KVCacheSpec by parsing the kv cache format from each
......
...@@ -3,7 +3,7 @@ from typing import Any, Optional, Union ...@@ -3,7 +3,7 @@ from typing import Any, Optional, Union
import torch import torch
import numpy as np import numpy as np
from vllm import envs from vllm import envs
from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group, get_lmcache_connector
from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -727,6 +727,8 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -727,6 +727,8 @@ class V1ZeroModelRunner(GPUModelRunner):
if has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata() get_kv_transfer_group().clear_connector_metadata()
get_lmcache_connector().clear_connector_metadata()
self.eplb_step() self.eplb_step()
model_output = ZeroV1ModelRunnerOutput( model_output = ZeroV1ModelRunnerOutput(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment