Commit f6fcc8ff authored by xuxz's avatar xuxz
Browse files

add engine id

parent 9129c728
...@@ -398,7 +398,6 @@ def maybe_save_kv_layer_to_connector( ...@@ -398,7 +398,6 @@ def maybe_save_kv_layer_to_connector(
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
connector.save_kv_layer(layer_name, kv_cache_layer, connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name]) attn_metadata[layer_name])
get_lmcache_connector().save_kv_layer(layer_name, kv_cache_layer, get_lmcache_connector().save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name]) attn_metadata[layer_name])
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
import copy
from vllm import envs from vllm import envs
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
...@@ -64,6 +64,15 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: ...@@ -64,6 +64,15 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
global _KV_CONNECTOR_AGENT global _KV_CONNECTOR_AGENT
global _KV_LMCACHE_CONNECTOR_AGENT global _KV_LMCACHE_CONNECTOR_AGENT
if _KV_LMCACHE_CONNECTOR_AGENT is None:
lmcache_config = copy.deepcopy(vllm_config)
from vllm.config import KVTransferConfig
lmcache_config.kv_transfer_config = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")
lmcache_config.kv_transfer_config.engine_id = "ed9e943a-e455-4ed6-b88c-09ae6263f0c9"
lmcache_connector: LMCacheConnectorV1 = LMCacheConnectorV1(lmcache_config, role=KVConnectorRole.WORKER)
_KV_LMCACHE_CONNECTOR_AGENT = lmcache_connector
if vllm_config.kv_transfer_config is None: if vllm_config.kv_transfer_config is None:
return return
...@@ -72,13 +81,6 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: ...@@ -72,13 +81,6 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1( _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1(
config=vllm_config, role=KVConnectorRole.WORKER) config=vllm_config, role=KVConnectorRole.WORKER)
lmcache_config = vllm_config
lmcache_config.kv_transfer_config.kv_role = "kv_both"
lmcache_connector: LMCacheConnectorV1 = LMCacheConnectorV1(
lmcache_config, role=KVConnectorRole.WORKER)
_KV_LMCACHE_CONNECTOR_AGENT = lmcache_connector
else: else:
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0( _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0(
rank=get_world_group().rank, rank=get_world_group().rank,
......
...@@ -5,11 +5,12 @@ from __future__ import annotations ...@@ -5,11 +5,12 @@ from __future__ import annotations
import itertools import itertools
import time import time
import copy
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any, Optional, Union from typing import Any, Optional, Union
from vllm.config import VllmConfig from vllm.config import KVTransferConfig,VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import ( from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory) KVConnectorFactory)
...@@ -87,8 +88,9 @@ class Scheduler(SchedulerInterface): ...@@ -87,8 +88,9 @@ class Scheduler(SchedulerInterface):
self.connector = KVConnectorFactory.create_connector_v1( self.connector = KVConnectorFactory.create_connector_v1(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER) config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
lmcache_config = self.vllm_config lmcache_config = copy.deepcopy(self.vllm_config)
lmcache_config.kv_transfer_config.kv_role = "kv_both" lmcache_config.kv_transfer_config = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")
lmcache_config.kv_transfer_config.engine_id = "ed9e943a-e455-4ed6-b88c-09ae6263f0c9"
self.lmcache_connector: LMCacheConnectorV1 = LMCacheConnectorV1( self.lmcache_connector: LMCacheConnectorV1 = LMCacheConnectorV1(
lmcache_config, role=KVConnectorRole.SCHEDULER) lmcache_config, role=KVConnectorRole.SCHEDULER)
...@@ -400,7 +402,6 @@ class Scheduler(SchedulerInterface): ...@@ -400,7 +402,6 @@ class Scheduler(SchedulerInterface):
self.lmcache_connector.get_num_new_matched_tokens( self.lmcache_connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens)) request, num_new_local_computed_tokens))
# Total computed tokens (local + external). # Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens + num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens) num_external_computed_tokens)
...@@ -698,7 +699,7 @@ class Scheduler(SchedulerInterface): ...@@ -698,7 +699,7 @@ class Scheduler(SchedulerInterface):
num_external_computed_tokens, load_kv_async = ( num_external_computed_tokens, load_kv_async = (
self.connector.get_num_new_matched_tokens( self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens)) request, num_new_local_computed_tokens))
if self.lmcache_connector is not None: if self.lmcache_connector is not None:
num_external_computed_tokens, load_kv_async = ( num_external_computed_tokens, load_kv_async = (
self.lmcache_connector.get_num_new_matched_tokens( self.lmcache_connector.get_num_new_matched_tokens(
......
...@@ -864,7 +864,7 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -864,7 +864,7 @@ class DPEngineCoreProc(EngineCoreProc):
vllm_config.kv_transfer_config.engine_id = ( vllm_config.kv_transfer_config.engine_id = (
f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}" f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
) )
logger.debug("Setting kv_transfer_config.engine_id to %s", logger.info("Setting kv_transfer_config.engine_id to %s",
vllm_config.kv_transfer_config.engine_id) vllm_config.kv_transfer_config.engine_id)
from vllm.platforms import current_platform from vllm.platforms import current_platform
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment