Commit f4fd3a96 authored by yangshj1's avatar yangshj1
Browse files

add pd lmcache

parent 4a80b456
......@@ -7,11 +7,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
get_lmcache_connector,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
......@@ -22,7 +24,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import validate_kv_sharing_target
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
class Attention(nn.Module):
"""Attention layer.
......@@ -376,6 +378,9 @@ def wait_for_kv_layer_from_connector(layer_name: str):
assert isinstance(attn_metadata, dict)
connector.wait_for_layer_load(layer_name)
get_lmcache_connector().wait_for_layer_load(layer_name)
def maybe_save_kv_layer_to_connector(
layer_name: str,
......@@ -393,6 +398,9 @@ def maybe_save_kv_layer_to_connector(
assert isinstance(attn_metadata, dict)
connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name])
get_lmcache_connector().save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name])
def unified_attention(
......
......@@ -3,24 +3,32 @@
from typing import TYPE_CHECKING, Optional
from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole)
from vllm.distributed.parallel_state import get_world_group
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
if TYPE_CHECKING:
from vllm.config import VllmConfig
_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None
_KV_LMCACHE_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None
def get_kv_transfer_group() -> KVConnectorBaseType:
assert _KV_CONNECTOR_AGENT is not None, (
"disaggregated KV cache transfer parallel group is not initialized")
return _KV_CONNECTOR_AGENT
def get_lmcache_connector() -> LMCacheConnectorV1:
assert _KV_LMCACHE_CONNECTOR_AGENT is not None, (
"LM cache transfer parallel group is not initialized")
return _KV_LMCACHE_CONNECTOR_AGENT
def has_kv_transfer_group() -> bool:
return _KV_CONNECTOR_AGENT is not None
......@@ -54,6 +62,7 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
"""
global _KV_CONNECTOR_AGENT
global _KV_LMCACHE_CONNECTOR_AGENT
if vllm_config.kv_transfer_config is None:
return
......@@ -63,6 +72,13 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
if envs.VLLM_USE_V1:
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1(
config=vllm_config, role=KVConnectorRole.WORKER)
lmcache_config = vllm_config
lmcache_config.kv_transfer_config.kv_role = "kv_both"
lmcache_connector: LMCacheConnectorV1 = LMCacheConnectorV1(
lmcache_config, role=KVConnectorRole.WORKER)
_KV_LMCACHE_CONNECTOR_AGENT = lmcache_connector
else:
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0(
rank=get_world_group().rank,
......
......@@ -34,6 +34,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
logger = init_logger(__name__)
......@@ -86,6 +87,11 @@ class Scheduler(SchedulerInterface):
self.connector = KVConnectorFactory.create_connector_v1(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
lmcache_config = self.vllm_config
lmcache_config.kv_transfer_config.kv_role = "kv_both"
self.lmcache_connector: LMCacheConnectorV1 = LMCacheConnectorV1(
lmcache_config, role=KVConnectorRole.SCHEDULER)
self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config,
self.parallel_config.data_parallel_rank,
......@@ -389,6 +395,12 @@ class Scheduler(SchedulerInterface):
self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
if self.lmcache_connector is not None:
num_external_computed_tokens, load_kv_async = (
self.lmcache_connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
# Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens)
......@@ -463,6 +475,13 @@ class Scheduler(SchedulerInterface):
num_external_computed_tokens,
)
if self.lmcache_connector is not None:
self.lmcache_connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_computed_tokens,
)
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
request = self.waiting.pop_request()
......@@ -578,6 +597,10 @@ class Scheduler(SchedulerInterface):
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
if self.lmcache_connector is not None:
meta = self.lmcache_connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
events = self.kv_cache_manager.take_events()
if events:
batch = KVEventBatch(ts=time.time(), events=events)
......@@ -675,6 +698,11 @@ class Scheduler(SchedulerInterface):
num_external_computed_tokens, load_kv_async = (
self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
if self.lmcache_connector is not None:
num_external_computed_tokens, load_kv_async = (
self.lmcache_connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
# Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens +
......@@ -750,6 +778,13 @@ class Scheduler(SchedulerInterface):
num_external_computed_tokens,
)
if self.lmcache_connector is not None:
self.lmcache_connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_computed_tokens,
)
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
request = self.waiting.pop_request()
......@@ -994,6 +1029,10 @@ class Scheduler(SchedulerInterface):
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
if self.lmcache_connector is not None:
meta = self.lmcache_connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
events = self.kv_cache_manager.take_events()
if events:
batch = KVEventBatch(ts=time.time(), events=events)
......
......@@ -23,6 +23,7 @@ from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
get_lmcache_connector,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (
......@@ -73,6 +74,9 @@ from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute
from ..sample.logits_processor import LogitsProcessorManager
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
if TYPE_CHECKING:
import xgrammar as xgr
......@@ -1573,6 +1577,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
get_lmcache_connector().clear_connector_metadata()
self.eplb_step()
return ModelRunnerOutput(
......@@ -1736,11 +1742,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())
lmcache_connector = get_lmcache_connector()
lmcache_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata)
lmcache_connector.start_load_kv(get_forward_context())
@staticmethod
def maybe_wait_for_kv_save() -> None:
if has_kv_transfer_group():
get_kv_transfer_group().wait_for_save()
lmcache_connector = get_lmcache_connector()
lmcache_connector.wait_for_save()
@staticmethod
def get_finished_kv_transfers(
scheduler_output: "SchedulerOutput",
......@@ -2690,6 +2704,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
get_lmcache_connector().register_kv_caches(kv_caches)
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
......
......@@ -3,7 +3,7 @@ from typing import Any, Optional, Union
import torch
import numpy as np
from vllm import envs
from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group, get_lmcache_connector
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context
from vllm.sequence import IntermediateTensors
......@@ -727,6 +727,8 @@ class V1ZeroModelRunner(GPUModelRunner):
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
get_lmcache_connector().clear_connector_metadata()
self.eplb_step()
model_output = ZeroV1ModelRunnerOutput(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment