Commit c6b7a44b authored by yangshj1's avatar yangshj1
Browse files

add env: introduce VLLM_LMCACHE_ENABLE to gate creation and use of the LMCache connector

parent 5db0e637
...@@ -380,9 +380,9 @@ def wait_for_kv_layer_from_connector(layer_name: str): ...@@ -380,9 +380,9 @@ def wait_for_kv_layer_from_connector(layer_name: str):
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
connector.wait_for_layer_load(layer_name) connector.wait_for_layer_load(layer_name)
get_lmcache_connector().wait_for_layer_load(layer_name) lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.wait_for_layer_load(layer_name)
def maybe_save_kv_layer_to_connector( def maybe_save_kv_layer_to_connector(
layer_name: str, layer_name: str,
...@@ -400,9 +400,11 @@ def maybe_save_kv_layer_to_connector( ...@@ -400,9 +400,11 @@ def maybe_save_kv_layer_to_connector(
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
connector.save_kv_layer(layer_name, kv_cache_layer, connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name]) attn_metadata[layer_name])
get_lmcache_connector().save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name])
lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name])
def unified_attention( def unified_attention(
query: torch.Tensor, query: torch.Tensor,
......
...@@ -25,8 +25,6 @@ def get_kv_transfer_group() -> KVConnectorBaseType: ...@@ -25,8 +25,6 @@ def get_kv_transfer_group() -> KVConnectorBaseType:
return _KV_CONNECTOR_AGENT return _KV_CONNECTOR_AGENT
def get_lmcache_connector() -> KVConnectorBaseType: def get_lmcache_connector() -> KVConnectorBaseType:
assert _KV_LMCACHE_CONNECTOR_AGENT is not None, (
"LM cache transfer parallel group is not initialized")
return _KV_LMCACHE_CONNECTOR_AGENT return _KV_LMCACHE_CONNECTOR_AGENT
...@@ -64,7 +62,7 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: ...@@ -64,7 +62,7 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
global _KV_CONNECTOR_AGENT global _KV_CONNECTOR_AGENT
global _KV_LMCACHE_CONNECTOR_AGENT global _KV_LMCACHE_CONNECTOR_AGENT
if _KV_LMCACHE_CONNECTOR_AGENT is None: if envs.VLLM_LMCACHE_ENABLE and _KV_LMCACHE_CONNECTOR_AGENT is None:
lmcache_config = copy.deepcopy(vllm_config) lmcache_config = copy.deepcopy(vllm_config)
from vllm.config import KVTransferConfig from vllm.config import KVTransferConfig
lmcache_config.kv_transfer_config = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both") lmcache_config.kv_transfer_config = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")
......
...@@ -169,6 +169,7 @@ if TYPE_CHECKING: ...@@ -169,6 +169,7 @@ if TYPE_CHECKING:
USE_FUSED_RMS_QUANT: bool = False USE_FUSED_RMS_QUANT: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_LMCACHE_ENABLE: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1114,6 +1115,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1114,6 +1115,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"USE_FUSED_SILU_MUL_QUANT": "USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in
("true", "1")), ("true", "1")),
"VLLM_LMCACHE_ENABLE":
lambda: (os.environ.get("VLLM_LMCACHE_ENABLE", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -10,6 +10,7 @@ from collections import defaultdict ...@@ -10,6 +10,7 @@ from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any, Optional, Union from typing import Any, Optional, Union
import vllm.envs as envs
from vllm.config import KVTransferConfig,VllmConfig from vllm.config import KVTransferConfig,VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import ( from vllm.distributed.kv_transfer.kv_connector.factory import (
...@@ -88,6 +89,7 @@ class Scheduler(SchedulerInterface): ...@@ -88,6 +89,7 @@ class Scheduler(SchedulerInterface):
self.connector = KVConnectorFactory.create_connector_v1( self.connector = KVConnectorFactory.create_connector_v1(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER) config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
if envs.VLLM_LMCACHE_ENABLE:
lmcache_config = copy.deepcopy(self.vllm_config) lmcache_config = copy.deepcopy(self.vllm_config)
lmcache_config.kv_transfer_config = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both") lmcache_config.kv_transfer_config = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")
lmcache_config.kv_transfer_config.engine_id = "ed9e943a-e455-4ed6-b88c-09ae6263f0c9" lmcache_config.kv_transfer_config.engine_id = "ed9e943a-e455-4ed6-b88c-09ae6263f0c9"
......
...@@ -1577,7 +1577,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1577,7 +1577,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata() get_kv_transfer_group().clear_connector_metadata()
get_lmcache_connector().clear_connector_metadata() lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.clear_connector_metadata()
self.eplb_step() self.eplb_step()
...@@ -1743,6 +1745,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1743,6 +1745,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_connector.start_load_kv(get_forward_context()) kv_connector.start_load_kv(get_forward_context())
lmcache_connector = get_lmcache_connector() lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.bind_connector_metadata( lmcache_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata) scheduler_output.kv_connector_metadata)
lmcache_connector.start_load_kv(get_forward_context()) lmcache_connector.start_load_kv(get_forward_context())
...@@ -1753,6 +1756,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1753,6 +1756,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
get_kv_transfer_group().wait_for_save() get_kv_transfer_group().wait_for_save()
lmcache_connector = get_lmcache_connector() lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.wait_for_save() lmcache_connector.wait_for_save()
@staticmethod @staticmethod
...@@ -2704,7 +2708,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2704,7 +2708,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches) get_kv_transfer_group().register_kv_caches(kv_caches)
get_lmcache_connector().register_kv_caches(kv_caches) lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.register_kv_caches(kv_caches)
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
""" """
......
...@@ -727,7 +727,9 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -727,7 +727,9 @@ class V1ZeroModelRunner(GPUModelRunner):
if has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata() get_kv_transfer_group().clear_connector_metadata()
get_lmcache_connector().clear_connector_metadata() lmcache_connector = get_lmcache_connector()
if lmcache_connector is not None:
lmcache_connector.clear_connector_metadata()
self.eplb_step() self.eplb_step()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment