Commit 8c10743e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-lmcache-pd' into 'v0.9.2-dev'

lmcache support pd

See merge request dcutoolkit/deeplearing/vllm!219
parents 8c0143db 4bb8c6af
......@@ -8,11 +8,13 @@ import torch.nn as nn
import torch.nn.functional as F
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_maybe_save_kv_layer_to_connector
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
get_lmcache_connector,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
......@@ -23,7 +25,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import validate_kv_sharing_target
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
class Attention(nn.Module):
"""Attention layer.
......@@ -378,6 +380,9 @@ def wait_for_kv_layer_from_connector(layer_name: str):
assert isinstance(attn_metadata, dict)
connector.wait_for_layer_load(layer_name)
lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.wait_for_layer_load(layer_name)
def maybe_save_kv_layer_to_connector(
layer_name: str,
......@@ -395,7 +400,11 @@ def maybe_save_kv_layer_to_connector(
assert isinstance(attn_metadata, dict)
connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name])
lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name])
def unified_attention(
query: torch.Tensor,
......
......@@ -3,10 +3,10 @@
from vllm.distributed.kv_transfer.kv_transfer_state import (
KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group,
has_kv_transfer_group, is_v1_kv_transfer_group)
get_lmcache_connector, has_kv_transfer_group, is_v1_kv_transfer_group)
__all__ = [
"get_kv_transfer_group", "has_kv_transfer_group",
"get_kv_transfer_group", "get_lmcache_connector", "has_kv_transfer_group",
"is_v1_kv_transfer_group", "ensure_kv_transfer_initialized",
"KVConnectorBaseType"
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional
import copy
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole)
from vllm.distributed.parallel_state import get_world_group
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
if TYPE_CHECKING:
from vllm.config import VllmConfig
_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None
_KV_LMCACHE_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None
def get_kv_transfer_group() -> KVConnectorBaseType:
assert _KV_CONNECTOR_AGENT is not None, (
"disaggregated KV cache transfer parallel group is not initialized")
return _KV_CONNECTOR_AGENT
def get_lmcache_connector() -> KVConnectorBaseType:
return _KV_LMCACHE_CONNECTOR_AGENT
def has_kv_transfer_group() -> bool:
return _KV_CONNECTOR_AGENT is not None
......@@ -54,6 +60,16 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
"""
global _KV_CONNECTOR_AGENT
global _KV_LMCACHE_CONNECTOR_AGENT
if envs.VLLM_LMCACHE_ENABLE and _KV_LMCACHE_CONNECTOR_AGENT is None:
lmcache_config = copy.deepcopy(vllm_config)
from vllm.config import KVTransferConfig
lmcache_config.kv_transfer_config = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")
lmcache_config.kv_transfer_config.engine_id = "lmcache_engine_id"
lmcache_connector: LMCacheConnectorV1 = LMCacheConnectorV1(lmcache_config, role=KVConnectorRole.WORKER)
_KV_LMCACHE_CONNECTOR_AGENT = lmcache_connector
if vllm_config.kv_transfer_config is None:
return
......
......@@ -169,6 +169,7 @@ if TYPE_CHECKING:
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_LMCACHE_ENABLE: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1114,6 +1115,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in
("true", "1")),
"VLLM_LMCACHE_ENABLE":
lambda: (os.environ.get("VLLM_LMCACHE_ENABLE", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -155,3 +155,4 @@ class SchedulerOutput:
# KV Cache Connector metadata.
kv_connector_metadata: Optional[KVConnectorMetadata] = None
kv_lmcache_connector_metadata: Optional[KVConnectorMetadata] = None
......@@ -5,11 +5,13 @@ from __future__ import annotations
import itertools
import time
import copy
from collections import defaultdict
from collections.abc import Iterable
from typing import Any, Optional, Union
from vllm.config import VllmConfig
import vllm.envs as envs
from vllm.config import KVTransferConfig,VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
......@@ -34,6 +36,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
logger = init_logger(__name__)
......@@ -86,6 +89,14 @@ class Scheduler(SchedulerInterface):
self.connector = KVConnectorFactory.create_connector_v1(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
self.lmcache_connector = None
if envs.VLLM_LMCACHE_ENABLE:
lmcache_config = copy.deepcopy(self.vllm_config)
lmcache_config.kv_transfer_config = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")
lmcache_config.kv_transfer_config.engine_id = "lmcache_engine_id"
self.lmcache_connector: LMCacheConnectorV1 = LMCacheConnectorV1(
lmcache_config, role=KVConnectorRole.SCHEDULER)
self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config,
self.parallel_config.data_parallel_rank,
......@@ -389,6 +400,11 @@ class Scheduler(SchedulerInterface):
self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
if self.lmcache_connector is not None:
num_external_computed_tokens, load_kv_async = (
self.lmcache_connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
# Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens)
......@@ -463,6 +479,13 @@ class Scheduler(SchedulerInterface):
num_external_computed_tokens,
)
if self.lmcache_connector is not None:
self.lmcache_connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_computed_tokens,
)
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
request = self.waiting.pop_request()
......@@ -578,6 +601,10 @@ class Scheduler(SchedulerInterface):
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
if self.lmcache_connector is not None:
meta = self.lmcache_connector.build_connector_meta(scheduler_output)
scheduler_output.kv_lmcache_connector_metadata = meta
events = self.kv_cache_manager.take_events()
if events:
batch = KVEventBatch(ts=time.time(), events=events)
......@@ -676,6 +703,11 @@ class Scheduler(SchedulerInterface):
self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
if self.lmcache_connector is not None:
num_external_computed_tokens, load_kv_async = (
self.lmcache_connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
# Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens)
......@@ -750,6 +782,13 @@ class Scheduler(SchedulerInterface):
num_external_computed_tokens,
)
if self.lmcache_connector is not None:
self.lmcache_connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_computed_tokens,
)
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
request = self.waiting.pop_request()
......@@ -994,6 +1033,10 @@ class Scheduler(SchedulerInterface):
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
if self.lmcache_connector is not None:
meta = self.lmcache_connector.build_connector_meta(scheduler_output)
scheduler_output.kv_lmcache_connector_metadata = meta
events = self.kv_cache_manager.take_events()
if events:
batch = KVEventBatch(ts=time.time(), events=events)
......
......@@ -864,7 +864,7 @@ class DPEngineCoreProc(EngineCoreProc):
vllm_config.kv_transfer_config.engine_id = (
f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
)
logger.debug("Setting kv_transfer_config.engine_id to %s",
logger.info("Setting kv_transfer_config.engine_id to %s",
vllm_config.kv_transfer_config.engine_id)
from vllm.platforms import current_platform
......
......@@ -23,6 +23,7 @@ from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
get_lmcache_connector,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (
......@@ -73,6 +74,9 @@ from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute
from ..sample.logits_processor import LogitsProcessorManager
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
if TYPE_CHECKING:
import xgrammar as xgr
......@@ -1573,6 +1577,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.clear_connector_metadata()
self.eplb_step()
return ModelRunnerOutput(
......@@ -1736,11 +1744,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())
lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.bind_connector_metadata(
scheduler_output.kv_lmcache_connector_metadata)
lmcache_connector.start_load_kv(get_forward_context())
@staticmethod
def maybe_wait_for_kv_save() -> None:
if has_kv_transfer_group():
get_kv_transfer_group().wait_for_save()
lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.wait_for_save()
@staticmethod
def get_finished_kv_transfers(
scheduler_output: "SchedulerOutput",
......@@ -2690,6 +2708,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.register_kv_caches(kv_caches)
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
......
......@@ -3,7 +3,7 @@ from typing import Any, Optional, Union
import torch
import numpy as np
from vllm import envs
from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group, get_lmcache_connector
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context
from vllm.sequence import IntermediateTensors
......@@ -727,6 +727,10 @@ class V1ZeroModelRunner(GPUModelRunner):
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.clear_connector_metadata()
self.eplb_step()
model_output = ZeroV1ModelRunnerOutput(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment