Unverified Commit 4ccffe56 authored by Chenguang Zheng's avatar Chenguang Zheng Committed by GitHub
Browse files

[Core] Encoder separation for Encode-Prefill-Decode Disaggregation (#25233)


Signed-off-by: default avatarn00909098 <nguyen.kha.long@huawei.com>
Signed-off-by: default avatarknlnguyen1802 <knlnguyen1802@gmail.com>
Signed-off-by: default avatarherotai214 <herotai214@gmail.com>
Signed-off-by: default avatarKhuong Le <khuong.le.manh@huawei.com>
Signed-off-by: default avatarKhuong Le <lemanhkhuong2611@gmail.com>
Co-authored-by: default avatarn00909098 <nguyen.kha.long@huawei.com>
Co-authored-by: default avatarknlnguyen1802 <knlnguyen1802@gmail.com>
Co-authored-by: default avatarherotai214 <herotai214@gmail.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: default avatarKhuong Le <khuong.le.manh@huawei.com>
Co-authored-by: default avatarKhuong Le <lemanhkhuong2611@gmail.com>
parent cbb799e3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from collections.abc import Callable
from typing import TYPE_CHECKING
# yapf: disable
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorBase,
ECConnectorRole,
)
from vllm.logger import init_logger
# yapf: enable
if TYPE_CHECKING:
from vllm.config import ECTransferConfig, VllmConfig
logger = init_logger(__name__)
class ECConnectorFactory:
_registry: dict[str, Callable[[], type[ECConnectorBase]]] = {}
@classmethod
def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
"""Register a connector with a lazy-loading module and class name."""
if name in cls._registry:
raise ValueError(f"Connector '{name}' is already registered.")
def loader() -> type[ECConnectorBase]:
module = importlib.import_module(module_path)
return getattr(module, class_name)
cls._registry[name] = loader
@classmethod
def create_connector(
cls,
config: "VllmConfig",
role: ECConnectorRole,
) -> ECConnectorBase:
ec_transfer_config = config.ec_transfer_config
if ec_transfer_config is None:
raise ValueError("ec_transfer_config must be set to create a connector")
connector_cls = cls.get_connector_class(ec_transfer_config)
logger.info(
"Creating connector with name: %s and engine_id: %s",
connector_cls.__name__,
ec_transfer_config.engine_id,
)
# Connector is explicitly separated into two roles.
# Scheduler connector:
# - Co-locate with scheduler process
# - Should only be used inside the Scheduler class
# Worker connector:
# - Co-locate with worker process
return connector_cls(config, role)
@classmethod
def get_connector_class(
cls, ec_transfer_config: "ECTransferConfig"
) -> type[ECConnectorBase]:
"""Get the connector class by name."""
connector_name = ec_transfer_config.ec_connector
if connector_name is None:
raise ValueError("EC connect must not be None")
elif connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
connector_module_path = ec_transfer_config.ec_connector_module_path
if connector_module_path is None:
raise ValueError(f"Unsupported connector type: {connector_name}")
connector_module = importlib.import_module(connector_module_path)
connector_cls = getattr(connector_module, connector_name)
return connector_cls
# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# only load the files corresponding to the current connector.
ECConnectorFactory.register_connector(
"ECSharedStorageConnector",
"vllm.distributed.ec_transfer.ec_connector.shared_storage_connector",
"ECSharedStorageConnector",
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING
import safetensors
from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorBase,
ECConnectorMetadata,
ECConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class MMMeta:
mm_hash: str
num_token: int
@staticmethod
def make_meta(mm_hash, num_token) -> "MMMeta":
return MMMeta(mm_hash=mm_hash, num_token=num_token)
@dataclass
class ECSharedStorageConnectorMetadata(ECConnectorMetadata):
mm_datas: list[MMMeta]
def __init__(self):
self.mm_datas = []
def add_mm_data(self, mm_data: MMMeta):
self.mm_datas.append(mm_data)
class ECSharedStorageConnector(ECConnectorBase):
# NOTE: This is Simple debug implementation of the EC connector.
# It save / load the EC cache to / from the disk.
def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
# req_id -> index
self._mm_datas_need_loads: dict[str, int] = {}
transfer_config = vllm_config.ec_transfer_config
if transfer_config is not None:
self._storage_path = transfer_config.get_from_extra_config(
"shared_storage_path", "/tmp"
)
logger.debug(transfer_config)
logger.debug("Shared storage path is %s", self._storage_path)
else:
raise ValueError("ec_transfer_config must be set for ECConnectorBase")
def start_load_caches(self, encoder_cache, **kwargs) -> None:
"""
Start loading the cache from the connector into vLLM's encoder cache.
This method loads the encoder cache based on metadata provided by the scheduler.
It is called before `_gather_mm_embeddings` for the EC Connector. For EC,
the `encoder_cache` and `mm_hash` are stored in `kwargs`.
Args:
encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
data hashes (`mm_hash`) to encoder cache tensors.
kwargs (dict): Additional keyword arguments for the connector.
"""
# Get the metadata
metadata: ECConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, ECSharedStorageConnectorMetadata)
assert encoder_cache is not None
if metadata is None:
logger.warning(
(
"In connector.start_load_caches, ",
"but the connector metadata is None",
)
)
return
# Load the EC for each mm data
for mm_data in metadata.mm_datas:
if mm_data.mm_hash in encoder_cache:
continue
filename = self._generate_filename_debug(mm_data.mm_hash)
ec_cache = safetensors.torch.load_file(filename)["ec_cache"].cuda()
encoder_cache[mm_data.mm_hash] = ec_cache
logger.debug("Success load encoder cache for hash %s", mm_data.mm_hash)
def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None:
"""
Save the encoder cache to the connector.
This method saves the encoder cache from the worker's local storage
to shared storage or another external connector.
Args:
encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
data hashes (`mm_hash`) to encoder cache tensors.
mm_hash (str): The hash of the multimodal data whose cache is being saved.
kwargs (dict): Additional keyword arguments for the connector.
"""
# Return if it is PD Instance
if not self.is_producer:
return
filename = self._generate_filename_debug(mm_hash)
ec_cache = encoder_cache[mm_hash]
tensors = {"ec_cache": ec_cache.detach().cpu()}
safetensors.torch.save_file(tensors, filename)
logger.debug("Save cache successful for mm_hash %s", mm_hash)
def has_caches(
self,
request: "Request",
) -> list[bool]:
"""
Check if cache exist externally for each mm_data of request
Args:
request (Request): the request object.
Returns:
List of bool indicate that ith mm_data exist in cache or not
"""
result = []
for feature in request.mm_features:
result.append(self._found_match_for_mm_data(feature.identifier))
return result
def update_state_after_alloc(
self,
request: "Request",
index: int,
) -> None:
"""
Update ECConnector state after encoder cache allocation.
"""
mm_hash = request.mm_features[index].identifier
num_encoder_token = request.get_num_encoder_tokens(index)
# Insert mm_hash only if this block has not been recorded yet.
self._mm_datas_need_loads[mm_hash] = num_encoder_token
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> ECConnectorMetadata:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
This only build for load mm_data only
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = ECSharedStorageConnectorMetadata()
for mm_hash, num_encoder_token in self._mm_datas_need_loads.items():
meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token))
self._mm_datas_need_loads.clear()
return meta
# ==============================
# Helper functions
# ==============================
def _found_match_for_mm_data(self, mm_hash) -> bool:
"""Check if the cache is hit for the request."""
filename = self._generate_filename_debug(mm_hash)
return os.path.exists(filename)
def _generate_foldername_debug(
self,
mm_hash: str,
create_folder: bool = True, # <- now defaults to True
) -> str:
"""
Return the folder in which the cache for this mm_hash lives.
If `create_folder` is True (default) the directory is created
recursively the first time it is needed.
"""
foldername = os.path.join(self._storage_path, mm_hash)
if create_folder:
os.makedirs(foldername, exist_ok=True)
return foldername
def _generate_filename_debug(self, mm_hash: str) -> str:
"""
Return the full path of the safetensors file for this mm_hash.
Ensures the parent directory exists because
`_generate_foldername_debug` is called with its default
(`create_folder=True`).
"""
foldername = self._generate_foldername_debug(mm_hash) # <- folder auto-created
return os.path.join(foldername, "encoder_cache.safetensors")
This diff is collapsed.
...@@ -38,6 +38,7 @@ from vllm.config import ( ...@@ -38,6 +38,7 @@ from vllm.config import (
CompilationConfig, CompilationConfig,
ConfigType, ConfigType,
DeviceConfig, DeviceConfig,
ECTransferConfig,
EPLBConfig, EPLBConfig,
KVEventsConfig, KVEventsConfig,
KVTransferConfig, KVTransferConfig,
...@@ -527,6 +528,8 @@ class EngineArgs: ...@@ -527,6 +528,8 @@ class EngineArgs:
kv_transfer_config: KVTransferConfig | None = None kv_transfer_config: KVTransferConfig | None = None
kv_events_config: KVEventsConfig | None = None kv_events_config: KVEventsConfig | None = None
ec_transfer_config: ECTransferConfig | None = None
generation_config: str = ModelConfig.generation_config generation_config: str = ModelConfig.generation_config
enable_sleep_mode: bool = ModelConfig.enable_sleep_mode enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
override_generation_config: dict[str, Any] = get_field( override_generation_config: dict[str, Any] = get_field(
...@@ -1105,6 +1108,9 @@ class EngineArgs: ...@@ -1105,6 +1108,9 @@ class EngineArgs:
"--kv-transfer-config", **vllm_kwargs["kv_transfer_config"] "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
) )
vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"]) vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
vllm_group.add_argument(
"--ec-transfer-config", **vllm_kwargs["ec_transfer_config"]
)
vllm_group.add_argument( vllm_group.add_argument(
"--compilation-config", "-O", **vllm_kwargs["compilation_config"] "--compilation-config", "-O", **vllm_kwargs["compilation_config"]
) )
...@@ -1676,6 +1682,7 @@ class EngineArgs: ...@@ -1676,6 +1682,7 @@ class EngineArgs:
compilation_config=self.compilation_config, compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config, kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config, kv_events_config=self.kv_events_config,
ec_transfer_config=self.ec_transfer_config,
additional_config=self.additional_config, additional_config=self.additional_config,
) )
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment