Commit c6b7a44b authored by yangshj1
Browse files

add env

parent 5db0e637
......@@ -380,9 +380,9 @@ def wait_for_kv_layer_from_connector(layer_name: str):
assert isinstance(attn_metadata, dict)
connector.wait_for_layer_load(layer_name)
get_lmcache_connector().wait_for_layer_load(layer_name)
lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.wait_for_layer_load(layer_name)
def maybe_save_kv_layer_to_connector(
layer_name: str,
......@@ -400,10 +400,12 @@ def maybe_save_kv_layer_to_connector(
assert isinstance(attn_metadata, dict)
connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name])
get_lmcache_connector().save_kv_layer(layer_name, kv_cache_layer,
lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name])
def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
......
......@@ -25,8 +25,6 @@ def get_kv_transfer_group() -> KVConnectorBaseType:
return _KV_CONNECTOR_AGENT
def get_lmcache_connector() -> "KVConnectorBaseType | None":
    """Return the module-level LMCache connector agent, if one exists.

    The agent is only created by ``ensure_kv_transfer_initialized`` when
    ``envs.VLLM_LMCACHE_ENABLE`` is set, so an unset agent is a normal
    state rather than an initialization error. This accessor therefore
    does not assert; it hands back whatever is stored.

    Returns:
        The value of ``_KV_LMCACHE_CONNECTOR_AGENT``, or ``None`` when no
        LMCache connector has been created. Callers must check for
        ``None`` before invoking connector methods.
    """
    # The annotation is quoted so the union syntax is never evaluated at
    # runtime and no additional typing import is required.
    return _KV_LMCACHE_CONNECTOR_AGENT
......@@ -64,7 +62,7 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
global _KV_CONNECTOR_AGENT
global _KV_LMCACHE_CONNECTOR_AGENT
if _KV_LMCACHE_CONNECTOR_AGENT is None:
if envs.VLLM_LMCACHE_ENABLE and _KV_LMCACHE_CONNECTOR_AGENT is None:
lmcache_config = copy.deepcopy(vllm_config)
from vllm.config import KVTransferConfig
lmcache_config.kv_transfer_config = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")
......
......@@ -169,6 +169,7 @@ if TYPE_CHECKING:
USE_FUSED_RMS_QUANT: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_LMCACHE_ENABLE: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1114,6 +1115,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in
("true", "1")),
"VLLM_LMCACHE_ENABLE":
lambda: (os.environ.get("VLLM_LMCACHE_ENABLE", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -10,6 +10,7 @@ from collections import defaultdict
from collections.abc import Iterable
from typing import Any, Optional, Union
import vllm.envs as envs
from vllm.config import KVTransferConfig,VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import (
......@@ -88,11 +89,12 @@ class Scheduler(SchedulerInterface):
self.connector = KVConnectorFactory.create_connector_v1(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
lmcache_config = copy.deepcopy(self.vllm_config)
lmcache_config.kv_transfer_config = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")
lmcache_config.kv_transfer_config.engine_id = "ed9e943a-e455-4ed6-b88c-09ae6263f0c9"
self.lmcache_connector: LMCacheConnectorV1 = LMCacheConnectorV1(
lmcache_config, role=KVConnectorRole.SCHEDULER)
if envs.VLLM_LMCACHE_ENABLE:
lmcache_config = copy.deepcopy(self.vllm_config)
lmcache_config.kv_transfer_config = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")
lmcache_config.kv_transfer_config.engine_id = "ed9e943a-e455-4ed6-b88c-09ae6263f0c9"
self.lmcache_connector: LMCacheConnectorV1 = LMCacheConnectorV1(
lmcache_config, role=KVConnectorRole.SCHEDULER)
self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config,
......
......@@ -1577,7 +1577,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
get_lmcache_connector().clear_connector_metadata()
lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.clear_connector_metadata()
self.eplb_step()
......@@ -1743,9 +1745,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_connector.start_load_kv(get_forward_context())
lmcache_connector = get_lmcache_connector()
lmcache_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata)
lmcache_connector.start_load_kv(get_forward_context())
if lmcache_connector is not None:
lmcache_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata)
lmcache_connector.start_load_kv(get_forward_context())
@staticmethod
def maybe_wait_for_kv_save() -> None:
......@@ -1753,7 +1756,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
get_kv_transfer_group().wait_for_save()
lmcache_connector = get_lmcache_connector()
lmcache_connector.wait_for_save()
if lmcache_connector is not None:
lmcache_connector.wait_for_save()
@staticmethod
def get_finished_kv_transfers(
......@@ -2704,7 +2708,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
get_lmcache_connector().register_kv_caches(kv_caches)
lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.register_kv_caches(kv_caches)
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
......
......@@ -727,7 +727,9 @@ class V1ZeroModelRunner(GPUModelRunner):
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
get_lmcache_connector().clear_connector_metadata()
lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.clear_connector_metadata()
self.eplb_step()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment