Unverified commit 4ccffe56, authored by Chenguang Zheng, committed by GitHub
Browse files

[Core] Encoder separation for Encode-Prefill-Decode Disaggregation (#25233)


Signed-off-by: n00909098 <nguyen.kha.long@huawei.com>
Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Signed-off-by: herotai214 <herotai214@gmail.com>
Signed-off-by: Khuong Le <khuong.le.manh@huawei.com>
Signed-off-by: Khuong Le <lemanhkhuong2611@gmail.com>
Co-authored-by: n00909098 <nguyen.kha.long@huawei.com>
Co-authored-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Co-authored-by: herotai214 <herotai214@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Khuong Le <khuong.le.manh@huawei.com>
Co-authored-by: Khuong Le <lemanhkhuong2611@gmail.com>
parent cbb799e3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from collections.abc import Callable
from typing import TYPE_CHECKING
# yapf: disable
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorBase,
ECConnectorRole,
)
from vllm.logger import init_logger
# yapf: enable
if TYPE_CHECKING:
from vllm.config import ECTransferConfig, VllmConfig
logger = init_logger(__name__)
class ECConnectorFactory:
    """Registry and factory for encoder-cache (EC) connector classes.

    Connectors are registered by name together with a lazy loader, so the
    module implementing a connector is imported only when that connector is
    actually requested.
    """

    # Connector name -> zero-argument loader returning the connector class.
    _registry: dict[str, Callable[[], type[ECConnectorBase]]] = {}

    @classmethod
    def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
        """Register a connector with a lazy-loading module and class name.

        Args:
            name: Key under which the connector is looked up.
            module_path: Dotted path of the module defining the connector.
            class_name: Attribute name of the connector class in that module.

        Raises:
            ValueError: If ``name`` has already been registered.
        """
        if name in cls._registry:
            raise ValueError(f"Connector '{name}' is already registered.")

        def loader() -> type[ECConnectorBase]:
            # Deferred import: executed only when the connector is requested.
            module = importlib.import_module(module_path)
            return getattr(module, class_name)

        cls._registry[name] = loader

    @classmethod
    def create_connector(
        cls,
        config: "VllmConfig",
        role: ECConnectorRole,
    ) -> ECConnectorBase:
        """Instantiate the connector selected by ``config.ec_transfer_config``.

        Args:
            config: Engine configuration; its ``ec_transfer_config`` must be set.
            role: Role passed through to the connector constructor.

        Returns:
            A new instance of the resolved connector class.

        Raises:
            ValueError: If ``config.ec_transfer_config`` is None, or if the
                connector class cannot be resolved (see `get_connector_class`).
        """
        ec_transfer_config = config.ec_transfer_config
        if ec_transfer_config is None:
            raise ValueError("ec_transfer_config must be set to create a connector")
        connector_cls = cls.get_connector_class(ec_transfer_config)
        logger.info(
            "Creating connector with name: %s and engine_id: %s",
            connector_cls.__name__,
            ec_transfer_config.engine_id,
        )
        # Connector is explicitly separated into two roles.
        # Scheduler connector:
        # - Co-locate with scheduler process
        # - Should only be used inside the Scheduler class
        # Worker connector:
        # - Co-locate with worker process
        return connector_cls(config, role)

    @classmethod
    def get_connector_class(
        cls, ec_transfer_config: "ECTransferConfig"
    ) -> type[ECConnectorBase]:
        """Get the connector class by name.

        Registered connectors take precedence. An unregistered name is
        resolved as an attribute of the module named by
        ``ec_transfer_config.ec_connector_module_path``.

        Raises:
            ValueError: If the connector name is None, or if it is not
                registered and no module path is configured.
            AttributeError: If the configured module does not define a class
                with the requested name.
        """
        connector_name = ec_transfer_config.ec_connector
        if connector_name is None:
            raise ValueError("EC connect must not be None")
        if connector_name in cls._registry:
            return cls._registry[connector_name]()

        connector_module_path = ec_transfer_config.ec_connector_module_path
        if connector_module_path is None:
            raise ValueError(f"Unsupported connector type: {connector_name}")
        connector_module = importlib.import_module(connector_module_path)
        try:
            return getattr(connector_module, connector_name)
        except AttributeError as e:
            # Same exception type as before, but say which class/module pair
            # failed so a misconfiguration is easy to diagnose.
            raise AttributeError(
                f"Class {connector_name} not found in {connector_module_path}"
            ) from e
# Built-in connectors are registered here rather than in each connector's own
# file, so that only the module backing the connector actually in use gets
# imported.
ECConnectorFactory.register_connector(
    name="ECSharedStorageConnector",
    module_path="vllm.distributed.ec_transfer.ec_connector.shared_storage_connector",
    class_name="ECSharedStorageConnector",
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING
import safetensors
from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorBase,
ECConnectorMetadata,
ECConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class MMMeta:
    """Descriptor of one multimodal item: its hash and a token count."""

    # Hash identifying the multimodal item.
    mm_hash: str
    # Token count recorded for the item.
    num_token: int

    @staticmethod
    def make_meta(mm_hash, num_token) -> "MMMeta":
        """Build an ``MMMeta`` from a hash and a token count."""
        meta = MMMeta(mm_hash, num_token)
        return meta
@dataclass
class ECSharedStorageConnectorMetadata(ECConnectorMetadata):
    """Connector metadata carrying a list of ``MMMeta`` entries."""

    # Entries accumulated through `add_mm_data`.
    mm_datas: list[MMMeta]

    def __init__(self) -> None:
        # Hand-written initializer: every instance starts with its own
        # empty list and takes no constructor arguments.
        self.mm_datas = []

    def add_mm_data(self, mm_data: MMMeta) -> None:
        """Append one ``MMMeta`` entry to ``mm_datas``."""
        self.mm_datas.append(mm_data)
class ECSharedStorageConnector(ECConnectorBase):
    """Debug EC connector that keeps encoder caches as files on disk.

    Each multimodal item is stored at
    ``<shared_storage_path>/<mm_hash>/encoder_cache.safetensors``, where
    ``shared_storage_path`` comes from the transfer config's extra config
    (default ``/tmp``).
    """

    # NOTE: This is a simple debug implementation of the EC connector.
    # It saves / loads the EC cache to / from the disk.

    def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
        """Read the storage path from ``vllm_config.ec_transfer_config``.

        Raises:
            ValueError: If ``vllm_config.ec_transfer_config`` is None.
        """
        super().__init__(vllm_config=vllm_config, role=role)
        # mm_hash -> number of encoder tokens, for the items recorded by
        # `update_state_after_alloc` since the last `build_connector_meta`.
        self._mm_datas_need_loads: dict[str, int] = {}
        transfer_config = vllm_config.ec_transfer_config
        if transfer_config is None:
            raise ValueError("ec_transfer_config must be set for ECConnectorBase")
        self._storage_path = transfer_config.get_from_extra_config(
            "shared_storage_path", "/tmp"
        )
        logger.debug(transfer_config)
        logger.debug("Shared storage path is %s", self._storage_path)

    def start_load_caches(self, encoder_cache, **kwargs) -> None:
        """
        Start loading the cache from the connector into vLLM's encoder cache.

        This method loads the encoder cache based on metadata provided by the
        scheduler. It is called before `_gather_mm_embeddings` for the EC
        Connector. Entries whose hash is already present in `encoder_cache`
        are left untouched.

        Args:
            encoder_cache (dict[str, torch.Tensor]): A dictionary mapping
                multimodal data hashes (`mm_hash`) to encoder cache tensors.
                It is updated in place.
            kwargs (dict): Additional keyword arguments for the connector
                (unused here).
        """
        # `import safetensors` alone does not load the `torch` submodule, so
        # import the function explicitly.
        from safetensors.torch import load_file

        # Get the metadata
        metadata: ECConnectorMetadata = self._get_connector_metadata()
        # Handle missing metadata *before* the isinstance assertion;
        # otherwise this branch can never be reached.
        if metadata is None:
            logger.warning(
                "In connector.start_load_caches, but the connector metadata is None"
            )
            return
        assert isinstance(metadata, ECSharedStorageConnectorMetadata)
        assert encoder_cache is not None

        # Load the EC for each mm data
        for mm_data in metadata.mm_datas:
            if mm_data.mm_hash in encoder_cache:
                continue
            # Reading must not create directories as a side effect.
            filename = self._generate_filename_debug(
                mm_data.mm_hash, create_folder=False
            )
            ec_cache = load_file(filename)["ec_cache"].cuda()
            encoder_cache[mm_data.mm_hash] = ec_cache
            logger.debug("Success load encoder cache for hash %s", mm_data.mm_hash)

    def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None:
        """
        Save the encoder cache to the connector.

        This method saves the encoder cache from the worker's local storage
        to shared storage. It is a no-op unless this connector is a producer.

        Args:
            encoder_cache (dict[str, torch.Tensor]): A dictionary mapping
                multimodal data hashes (`mm_hash`) to encoder cache tensors.
            mm_hash (str): The hash of the multimodal data whose cache is
                being saved.
            kwargs (dict): Additional keyword arguments for the connector
                (unused here).
        """
        # Only producers write caches; other instances (e.g. PD) return.
        if not self.is_producer:
            return
        # See `start_load_caches` for why the submodule is imported explicitly.
        from safetensors.torch import save_file

        filename = self._generate_filename_debug(mm_hash)
        ec_cache = encoder_cache[mm_hash]
        # Persist a detached CPU copy of the tensor.
        tensors = {"ec_cache": ec_cache.detach().cpu()}
        save_file(tensors, filename)
        logger.debug("Save cache successful for mm_hash %s", mm_hash)

    def has_caches(
        self,
        request: "Request",
    ) -> list[bool]:
        """
        Check if cache exist externally for each mm_data of request

        Args:
            request (Request): the request object.

        Returns:
            List of bool indicate that ith mm_data exist in cache or not
        """
        return [
            self._found_match_for_mm_data(feature.identifier)
            for feature in request.mm_features
        ]

    def update_state_after_alloc(
        self,
        request: "Request",
        index: int,
    ) -> None:
        """
        Update ECConnector state after encoder cache allocation.

        Records the `index`-th multimodal item of `request` so that the next
        `build_connector_meta` call includes it.
        """
        mm_hash = request.mm_features[index].identifier
        num_encoder_token = request.get_num_encoder_tokens(index)
        # Keyed by mm_hash: recording the same hash again overwrites the
        # previous entry instead of duplicating it.
        self._mm_datas_need_loads[mm_hash] = num_encoder_token

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> ECConnectorMetadata:
        """Build the connector metadata for this step.

        This function should NOT modify any fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.
        This only build for load mm_data only

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        meta = ECSharedStorageConnectorMetadata()
        for mm_hash, num_encoder_token in self._mm_datas_need_loads.items():
            meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token))
        # The pending set is consumed by this call.
        self._mm_datas_need_loads.clear()
        return meta

    # ==============================
    # Helper functions
    # ==============================

    def _found_match_for_mm_data(self, mm_hash) -> bool:
        """Check if the cache is hit for the request."""
        # A pure existence probe: do not create the folder while checking.
        filename = self._generate_filename_debug(mm_hash, create_folder=False)
        return os.path.exists(filename)

    def _generate_foldername_debug(
        self,
        mm_hash: str,
        create_folder: bool = True,
    ) -> str:
        """
        Return the folder in which the cache for this mm_hash lives.

        If `create_folder` is True (default) the directory is created
        recursively the first time it is needed.
        """
        foldername = os.path.join(self._storage_path, mm_hash)
        if create_folder:
            os.makedirs(foldername, exist_ok=True)
        return foldername

    def _generate_filename_debug(
        self,
        mm_hash: str,
        create_folder: bool = True,
    ) -> str:
        """
        Return the full path of the safetensors file for this mm_hash.

        With `create_folder=True` (default) the parent directory is created
        if missing; pass False for read-only lookups.
        """
        foldername = self._generate_foldername_debug(
            mm_hash, create_folder=create_folder
        )
        return os.path.join(foldername, "encoder_cache.safetensors")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from vllm import envs
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorBase,
ECConnectorRole,
)
from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory
if TYPE_CHECKING:
from vllm.config import VllmConfig
_EC_CONNECTOR_AGENT: ECConnectorBase | None = None
def get_ec_transfer() -> ECConnectorBase:
    """Return the module-level EC connector.

    Fails an assertion if the connector has not been set yet.
    """
    agent = _EC_CONNECTOR_AGENT
    assert agent is not None, "disaggregated EC cache is not initialized"
    return agent
def has_ec_transfer() -> bool:
    """Report whether the module-level EC connector has been set."""
    if _EC_CONNECTOR_AGENT is None:
        return False
    return True
def ensure_ec_transfer_initialized(vllm_config: "VllmConfig") -> None:
    """
    Create the module-level EC connector (worker role) if it is needed.

    Nothing happens when no EC transfer config is present, when this
    instance is not an EC transfer instance, or when the connector has
    already been created.

    Raises:
        ValueError: If a connector is needed but V1 is not enabled.
    """
    global _EC_CONNECTOR_AGENT

    ec_config = vllm_config.ec_transfer_config
    if ec_config is None:
        return
    if not ec_config.is_ec_transfer_instance or _EC_CONNECTOR_AGENT is not None:
        return
    if not envs.VLLM_USE_V1:
        raise ValueError("V0 is no longer supported")

    _EC_CONNECTOR_AGENT = ECConnectorFactory.create_connector(
        config=vllm_config, role=ECConnectorRole.WORKER
    )
...@@ -38,6 +38,7 @@ from vllm.config import ( ...@@ -38,6 +38,7 @@ from vllm.config import (
CompilationConfig, CompilationConfig,
ConfigType, ConfigType,
DeviceConfig, DeviceConfig,
ECTransferConfig,
EPLBConfig, EPLBConfig,
KVEventsConfig, KVEventsConfig,
KVTransferConfig, KVTransferConfig,
...@@ -527,6 +528,8 @@ class EngineArgs: ...@@ -527,6 +528,8 @@ class EngineArgs:
kv_transfer_config: KVTransferConfig | None = None kv_transfer_config: KVTransferConfig | None = None
kv_events_config: KVEventsConfig | None = None kv_events_config: KVEventsConfig | None = None
ec_transfer_config: ECTransferConfig | None = None
generation_config: str = ModelConfig.generation_config generation_config: str = ModelConfig.generation_config
enable_sleep_mode: bool = ModelConfig.enable_sleep_mode enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
override_generation_config: dict[str, Any] = get_field( override_generation_config: dict[str, Any] = get_field(
...@@ -1105,6 +1108,9 @@ class EngineArgs: ...@@ -1105,6 +1108,9 @@ class EngineArgs:
"--kv-transfer-config", **vllm_kwargs["kv_transfer_config"] "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
) )
vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"]) vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
vllm_group.add_argument(
"--ec-transfer-config", **vllm_kwargs["ec_transfer_config"]
)
vllm_group.add_argument( vllm_group.add_argument(
"--compilation-config", "-O", **vllm_kwargs["compilation_config"] "--compilation-config", "-O", **vllm_kwargs["compilation_config"]
) )
...@@ -1676,6 +1682,7 @@ class EngineArgs: ...@@ -1676,6 +1682,7 @@ class EngineArgs:
compilation_config=self.compilation_config, compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config, kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config, kv_events_config=self.kv_events_config,
ec_transfer_config=self.ec_transfer_config,
additional_config=self.additional_config, additional_config=self.additional_config,
) )
......
...@@ -49,10 +49,18 @@ def kernel_warmup(worker: "Worker"): ...@@ -49,10 +49,18 @@ def kernel_warmup(worker: "Worker"):
except NotImplementedError: except NotImplementedError:
return False return False
if not worker.model_runner.is_pooling_model and all( # NOTE: we add check for empty attn_groups to avoid errors when
_is_flashinfer_backend(group.backend) # deploying models such as E instances and encoder-only models.
for groups in worker.model_runner.attn_groups # As for those models, worker.model_runner.attn_groups is empty.
for group in groups # This change is made during EPD feature development.
if (
not worker.model_runner.is_pooling_model
and worker.model_runner.attn_groups
and all(
_is_flashinfer_backend(group.backend)
for groups in worker.model_runner.attn_groups
for group in groups
)
): ):
logger.info("Warming up FlashInfer attention.") logger.info("Warming up FlashInfer attention.")
# Warmup with mixed batch containing both prefill and decode tokens # Warmup with mixed batch containing both prefill and decode tokens
......
...@@ -14,6 +14,7 @@ if TYPE_CHECKING: ...@@ -14,6 +14,7 @@ if TYPE_CHECKING:
import numpy.typing as npt import numpy.typing as npt
import torch import torch
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.multimodal.inputs import MultiModalFeatureSpec
...@@ -21,6 +22,7 @@ if TYPE_CHECKING: ...@@ -21,6 +22,7 @@ if TYPE_CHECKING:
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request from vllm.v1.request import Request
else: else:
ECConnectorMetadata = object
KVConnectorMetadata = object KVConnectorMetadata = object
LoRARequest = object LoRARequest = object
MultiModalFeatureSpec = object MultiModalFeatureSpec = object
...@@ -188,6 +190,9 @@ class SchedulerOutput: ...@@ -188,6 +190,9 @@ class SchedulerOutput:
# KV Cache Connector metadata. # KV Cache Connector metadata.
kv_connector_metadata: KVConnectorMetadata | None = None kv_connector_metadata: KVConnectorMetadata | None = None
# EC Cache Connector metadata
ec_connector_metadata: ECConnectorMetadata | None = None
@dataclass @dataclass
class GrammarOutput: class GrammarOutput:
......
...@@ -7,6 +7,11 @@ from collections.abc import Iterable ...@@ -7,6 +7,11 @@ from collections.abc import Iterable
from typing import Any from typing import Any
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorMetadata,
ECConnectorRole,
)
from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import ( from vllm.distributed.kv_transfer.kv_connector.v1 import (
...@@ -14,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import ( ...@@ -14,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorRole, KVConnectorRole,
SupportsHMA, SupportsHMA,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
...@@ -104,6 +110,11 @@ class Scheduler(SchedulerInterface): ...@@ -104,6 +110,11 @@ class Scheduler(SchedulerInterface):
self.kv_events_config, self.kv_events_config,
self.parallel_config.data_parallel_rank, self.parallel_config.data_parallel_rank,
) )
self.ec_connector = None
if self.vllm_config.ec_transfer_config is not None:
self.ec_connector = ECConnectorFactory.create_connector(
config=self.vllm_config, role=ECConnectorRole.SCHEDULER
)
num_gpu_blocks = self.cache_config.num_gpu_blocks num_gpu_blocks = self.cache_config.num_gpu_blocks
assert num_gpu_blocks is not None and num_gpu_blocks > 0 assert num_gpu_blocks is not None and num_gpu_blocks > 0
...@@ -230,12 +241,14 @@ class Scheduler(SchedulerInterface): ...@@ -230,12 +241,14 @@ class Scheduler(SchedulerInterface):
# Schedule encoder inputs. # Schedule encoder inputs.
encoder_inputs_to_schedule = None encoder_inputs_to_schedule = None
external_load_encoder_input: list[int] = []
new_encoder_compute_budget = encoder_compute_budget new_encoder_compute_budget = encoder_compute_budget
if request.has_encoder_inputs: if request.has_encoder_inputs:
( (
encoder_inputs_to_schedule, encoder_inputs_to_schedule,
num_new_tokens, num_new_tokens,
new_encoder_compute_budget, new_encoder_compute_budget,
external_load_encoder_input,
) = self._try_schedule_encoder_inputs( ) = self._try_schedule_encoder_inputs(
request, request,
request.num_computed_tokens, request.num_computed_tokens,
...@@ -342,6 +355,11 @@ class Scheduler(SchedulerInterface): ...@@ -342,6 +355,11 @@ class Scheduler(SchedulerInterface):
for i in encoder_inputs_to_schedule: for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
encoder_compute_budget = new_encoder_compute_budget encoder_compute_budget = new_encoder_compute_budget
if external_load_encoder_input:
for i in external_load_encoder_input:
self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
# Record the LoRAs in scheduled_running_reqs # Record the LoRAs in scheduled_running_reqs
scheduled_loras: set[int] = set() scheduled_loras: set[int] = set()
...@@ -445,6 +463,7 @@ class Scheduler(SchedulerInterface): ...@@ -445,6 +463,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens = request.num_computed_tokens num_computed_tokens = request.num_computed_tokens
encoder_inputs_to_schedule = None encoder_inputs_to_schedule = None
external_load_encoder_input = []
new_encoder_compute_budget = encoder_compute_budget new_encoder_compute_budget = encoder_compute_budget
# KVTransfer: loading remote KV, do not allocate for new work. # KVTransfer: loading remote KV, do not allocate for new work.
...@@ -480,6 +499,7 @@ class Scheduler(SchedulerInterface): ...@@ -480,6 +499,7 @@ class Scheduler(SchedulerInterface):
encoder_inputs_to_schedule, encoder_inputs_to_schedule,
num_new_tokens, num_new_tokens,
new_encoder_compute_budget, new_encoder_compute_budget,
external_load_encoder_input,
) = self._try_schedule_encoder_inputs( ) = self._try_schedule_encoder_inputs(
request, request,
num_computed_tokens, num_computed_tokens,
...@@ -583,7 +603,12 @@ class Scheduler(SchedulerInterface): ...@@ -583,7 +603,12 @@ class Scheduler(SchedulerInterface):
for i in encoder_inputs_to_schedule: for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
encoder_compute_budget = new_encoder_compute_budget encoder_compute_budget = new_encoder_compute_budget
# Allocate for external load encoder cache
if external_load_encoder_input:
for i in external_load_encoder_input:
self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
# Put back any skipped requests at the head of the waiting queue # Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests: if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests) self.waiting.prepend_requests(skipped_waiting_requests)
...@@ -591,6 +616,7 @@ class Scheduler(SchedulerInterface): ...@@ -591,6 +616,7 @@ class Scheduler(SchedulerInterface):
# Check if the scheduling constraints are satisfied. # Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0 assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs assert len(self.running) <= self.max_num_running_reqs
# Since some requests in the RUNNING queue may not be scheduled in # Since some requests in the RUNNING queue may not be scheduled in
...@@ -653,8 +679,18 @@ class Scheduler(SchedulerInterface): ...@@ -653,8 +679,18 @@ class Scheduler(SchedulerInterface):
# 2. Wrap up all the KV cache load / save ops into an opaque object # 2. Wrap up all the KV cache load / save ops into an opaque object
# 3. Clear the internal states of the connector # 3. Clear the internal states of the connector
if self.connector is not None: if self.connector is not None:
meta = self.connector.build_connector_meta(scheduler_output) meta: KVConnectorMetadata = self.connector.build_connector_meta(
scheduler_output
)
scheduler_output.kv_connector_metadata = meta scheduler_output.kv_connector_metadata = meta
# Build the connector meta for ECConnector
if self.ec_connector is not None:
ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(
scheduler_output
)
scheduler_output.ec_connector_metadata = ec_meta
with record_function_or_nullcontext("schedule: update_after_schedule"): with record_function_or_nullcontext("schedule: update_after_schedule"):
self._update_after_schedule(scheduler_output) self._update_after_schedule(scheduler_output)
return scheduler_output return scheduler_output
...@@ -755,7 +791,7 @@ class Scheduler(SchedulerInterface): ...@@ -755,7 +791,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens: int, num_computed_tokens: int,
num_new_tokens: int, num_new_tokens: int,
encoder_compute_budget: int, encoder_compute_budget: int,
) -> tuple[list[int], int, int]: ) -> tuple[list[int], int, int, list[int]]:
""" """
Determine which encoder inputs need to be scheduled in the current step, Determine which encoder inputs need to be scheduled in the current step,
and update `num_new_tokens` and encoder token budget accordingly. and update `num_new_tokens` and encoder token budget accordingly.
...@@ -765,6 +801,7 @@ class Scheduler(SchedulerInterface): ...@@ -765,6 +801,7 @@ class Scheduler(SchedulerInterface):
in this step, i.e., in this step, i.e.,
[num_computed_tokens, num_computed_tokens + num_new_tokens). [num_computed_tokens, num_computed_tokens + num_new_tokens).
- It is not already computed and stored in the encoder cache. - It is not already computed and stored in the encoder cache.
- It is not exist on remote encoder cache (via ECConnector)
- There is sufficient encoder token budget to process it. - There is sufficient encoder token budget to process it.
- The encoder cache has space to store it. - The encoder cache has space to store it.
...@@ -776,12 +813,16 @@ class Scheduler(SchedulerInterface): ...@@ -776,12 +813,16 @@ class Scheduler(SchedulerInterface):
blocks and externally cached blocks (via KVConnector). blocks and externally cached blocks (via KVConnector).
""" """
if num_new_tokens == 0 or not request.has_encoder_inputs: if num_new_tokens == 0 or not request.has_encoder_inputs:
return [], num_new_tokens, encoder_compute_budget return [], num_new_tokens, encoder_compute_budget, []
encoder_inputs_to_schedule: list[int] = [] encoder_inputs_to_schedule: list[int] = []
mm_features = request.mm_features mm_features = request.mm_features
assert mm_features is not None assert mm_features is not None
assert len(mm_features) > 0 assert len(mm_features) > 0
external_load_encoder_input = []
# Check remote cache first
if self.ec_connector is not None:
remote_cache_has_item = self.ec_connector.has_caches(request)
# NOTE: since scheduler operates on the request level (possibly with # NOTE: since scheduler operates on the request level (possibly with
# multiple encoder inputs per request), we need to create temporary # multiple encoder inputs per request), we need to create temporary
# trackers for accounting at the encoder input level. # trackers for accounting at the encoder input level.
...@@ -862,6 +903,12 @@ class Scheduler(SchedulerInterface): ...@@ -862,6 +903,12 @@ class Scheduler(SchedulerInterface):
num_new_tokens = 0 num_new_tokens = 0
break break
if self.ec_connector is not None and remote_cache_has_item[i]:
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
external_load_encoder_input.append(i)
num_tokens_to_schedule += num_encoder_tokens
continue
num_tokens_to_schedule += num_encoder_tokens num_tokens_to_schedule += num_encoder_tokens
encoder_compute_budget -= num_encoder_tokens encoder_compute_budget -= num_encoder_tokens
mm_hashes_to_schedule.add(request.mm_features[i].identifier) mm_hashes_to_schedule.add(request.mm_features[i].identifier)
...@@ -871,6 +918,7 @@ class Scheduler(SchedulerInterface): ...@@ -871,6 +918,7 @@ class Scheduler(SchedulerInterface):
encoder_inputs_to_schedule, encoder_inputs_to_schedule,
num_new_tokens, num_new_tokens,
encoder_compute_budget, encoder_compute_budget,
external_load_encoder_input,
) )
def get_grammar_bitmask( def get_grammar_bitmask(
......
...@@ -8,6 +8,8 @@ from typing import TYPE_CHECKING, NamedTuple ...@@ -8,6 +8,8 @@ from typing import TYPE_CHECKING, NamedTuple
import numpy as np import numpy as np
import torch import torch
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
else: else:
...@@ -136,6 +138,13 @@ class KVConnectorOutput: ...@@ -136,6 +138,13 @@ class KVConnectorOutput:
) )
@dataclass
class ECConnectorOutput:
# [mm_hash]
finished_sending: set[str] | None = None
finished_recving: set[str] | None = None
# ModelRunnerOutput is serialized and sent to the scheduler process. # ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use list instead. # This is expensive for torch.Tensor so prefer to use list instead.
@dataclass @dataclass
...@@ -167,6 +176,8 @@ class ModelRunnerOutput: ...@@ -167,6 +176,8 @@ class ModelRunnerOutput:
kv_connector_output: KVConnectorOutput | None = None kv_connector_output: KVConnectorOutput | None = None
ec_connector_output: ECConnectorOutput | None = None
# req_id -> num_nans_in_logits # req_id -> num_nans_in_logits
num_nans_in_logits: dict[str, int] | None = None num_nans_in_logits: dict[str, int] | None = None
...@@ -192,6 +203,41 @@ class DraftTokenIds: ...@@ -192,6 +203,41 @@ class DraftTokenIds:
draft_token_ids: list[list[int]] draft_token_ids: list[list[int]]
def make_empty_encoder_model_runner_output(
    scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
    """
    Build a placeholder ModelRunnerOutput covering every scheduled request.

    Returns EMPTY_MODEL_RUNNER_OUTPUT when nothing is scheduled. Otherwise
    each scheduled request gets an index, a single placeholder token id (0)
    and a None pooler output; logprob and connector fields are left unset.
    """
    scheduled = scheduler_output.num_scheduled_tokens
    if not scheduled:
        return EMPTY_MODEL_RUNNER_OUTPUT

    # Snapshot the request ids as a list: ordered and indexable.
    req_ids: list[str] = list(scheduled.keys())
    num_reqs = len(req_ids)

    return ModelRunnerOutput(
        req_ids=req_ids,
        # Position of each request id within `req_ids`.
        req_id_to_index=dict(zip(req_ids, range(num_reqs))),
        # One single-element list [0] per request (each a distinct list).
        sampled_token_ids=[[0] for _ in range(num_reqs)],
        logprobs=None,
        prompt_logprobs_dict={},
        # No pooler outputs are produced here: None placeholder per request.
        pooler_output=[None] * num_reqs,
        kv_connector_output=None,
        ec_connector_output=None,
        num_nans_in_logits=None,
    )
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[], req_ids=[],
req_id_to_index={}, req_id_to_index={},
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Define EC connector functionality mixin for model runners.
"""
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import (
TYPE_CHECKING, # noqa: UP035
)
import torch
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorBase
from vllm.logger import init_logger
from vllm.v1.outputs import ECConnectorOutput
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)
# Defined as a EC connector functionality mixin for ModelRunner (GPU, TPU)
class ECConnectorModelRunnerMixin:
    """Static helpers through which a model runner drives the EC connector."""

    @staticmethod
    def maybe_save_ec_to_connector(
        encoder_cache: dict[str, torch.Tensor],
        mm_hash: str,
    ):
        """Save the cache entry for ``mm_hash`` via the connector, if any.

        When no EC connector is set, only a debug message is logged.
        """
        if has_ec_transfer():
            get_ec_transfer().save_caches(encoder_cache=encoder_cache, mm_hash=mm_hash)
            return
        logger.debug("Not have ec transfer please check")

    @staticmethod
    def get_finished_ec_transfers(
        scheduler_output: "SchedulerOutput",
    ) -> tuple[set[str] | None, set[str] | None]:
        """Return the connector's ``get_finished`` result for this step.

        Returns ``(None, None)`` when no EC connector is set.
        """
        if not has_ec_transfer():
            return None, None
        return get_ec_transfer().get_finished(scheduler_output.finished_req_ids)

    @staticmethod
    def maybe_get_ec_connector_output(
        scheduler_output: "SchedulerOutput",
        encoder_cache: dict[str, torch.Tensor],
        **kwargs,
    ) -> AbstractContextManager[ECConnectorOutput | None]:
        """Return the connector-output context manager, or a null context.

        Without an EC connector the returned context manager yields None.
        """
        if not has_ec_transfer():
            return nullcontext()
        return ECConnectorModelRunnerMixin._get_ec_connector_output(
            scheduler_output, encoder_cache, **kwargs
        )

    # This context manager must be used within an active forward context.
    # It encapsulates the entire EC connector lifecycle within execute_model
    @staticmethod
    @contextmanager
    def _get_ec_connector_output(
        scheduler_output: "SchedulerOutput",
        encoder_cache: dict[str, torch.Tensor],
        **kwargs,
    ) -> Generator[ECConnectorOutput, None, None]:
        """Bind metadata, optionally load caches, yield, then collect results.

        The yielded ``ECConnectorOutput`` is filled in with the connector's
        finished sets when the ``with`` body exits, after which the bound
        metadata is cleared.
        """
        result = ECConnectorOutput()

        connector = get_ec_transfer()
        assert isinstance(connector, ECConnectorBase)
        assert scheduler_output.ec_connector_metadata is not None
        connector.bind_connector_metadata(scheduler_output.ec_connector_metadata)

        # Only non-producer connectors load caches.
        if not connector.is_producer:
            connector.start_load_caches(encoder_cache, **kwargs)

        try:
            yield result
        finally:
            finished = connector.get_finished(scheduler_output.finished_req_ids)
            result.finished_sending, result.finished_recving = finished
            connector.clear_connector_metadata()
...@@ -35,6 +35,7 @@ from vllm.config import ( ...@@ -35,6 +35,7 @@ from vllm.config import (
get_layers_from_vllm_config, get_layers_from_vllm_config,
update_config, update_config,
) )
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
...@@ -114,12 +115,14 @@ from vllm.v1.outputs import ( ...@@ -114,12 +115,14 @@ from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT, EMPTY_MODEL_RUNNER_OUTPUT,
AsyncModelRunnerOutput, AsyncModelRunnerOutput,
DraftTokenIds, DraftTokenIds,
ECConnectorOutput,
KVConnectorOutput, KVConnectorOutput,
LogprobsLists, LogprobsLists,
LogprobsTensors, LogprobsTensors,
ModelRunnerOutput, ModelRunnerOutput,
PoolerOutput, PoolerOutput,
SamplerOutput, SamplerOutput,
make_empty_encoder_model_runner_output,
) )
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
...@@ -134,6 +137,7 @@ from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer ...@@ -134,6 +137,7 @@ from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
...@@ -237,9 +241,12 @@ class ExecuteModelState(NamedTuple): ...@@ -237,9 +241,12 @@ class ExecuteModelState(NamedTuple):
sample_hidden_states: torch.Tensor sample_hidden_states: torch.Tensor
aux_hidden_states: list[torch.Tensor] | None aux_hidden_states: list[torch.Tensor] | None
kv_connector_output: KVConnectorOutput | None kv_connector_output: KVConnectorOutput | None
ec_connector_output: ECConnectorOutput | None
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): class GPUModelRunner(
LoRAModelRunnerMixin, KVConnectorModelRunnerMixin, ECConnectorModelRunnerMixin
):
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
...@@ -1873,6 +1880,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1873,6 +1880,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
output, output,
is_embed=pos_info.is_embed, is_embed=pos_info.is_embed,
) )
logger.debug("Finish execute for mm hash %s", mm_hash)
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
def _gather_mm_embeddings( def _gather_mm_embeddings(
self, self,
...@@ -2191,20 +2200,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2191,20 +2200,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
torch.Tensor, torch.Tensor,
IntermediateTensors | None, IntermediateTensors | None,
dict[str, Any], dict[str, Any],
ECConnectorOutput | None,
]: ]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
is_first_rank = get_pp_group().is_first_rank is_first_rank = get_pp_group().is_first_rank
# _prepare_inputs may reorder the batch, so we must gather multi # _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order # modal outputs after that to ensure the correct order
ec_connector_output = None
if ( if (
self.supports_mm_inputs self.supports_mm_inputs
and is_first_rank and is_first_rank
and not self.model_config.is_encoder_decoder and not self.model_config.is_encoder_decoder
): ):
# Run the multimodal encoder if any. # Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output) with self.maybe_get_ec_connector_output(
mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output) scheduler_output,
encoder_cache=self.encoder_cache,
) as ec_connector_output:
self._execute_mm_encoder(scheduler_output)
mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output)
# NOTE(woosuk): To unify token ids and soft tokens (vision # NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids) # embeddings), we always use embeddings (rather than token ids)
...@@ -2284,6 +2300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2284,6 +2300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions, positions,
intermediate_tensors, intermediate_tensors,
model_kwargs, model_kwargs,
ec_connector_output,
) )
def _sample( def _sample(
...@@ -2508,6 +2525,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2508,6 +2525,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Update persistent batch states. # Update persistent batch states.
self._update_states(scheduler_output) self._update_states(scheduler_output)
if has_ec_transfer() and get_ec_transfer().is_producer:
with self.maybe_get_ec_connector_output(
scheduler_output,
encoder_cache=self.encoder_cache,
) as ec_connector_output:
self._execute_mm_encoder(scheduler_output)
return make_empty_encoder_model_runner_output(scheduler_output)
if not num_scheduled_tokens: if not num_scheduled_tokens:
if not has_kv_transfer_group(): if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if no work to do. # Return empty ModelRunnerOutput if no work to do.
...@@ -2583,6 +2608,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2583,6 +2608,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions, positions,
intermediate_tensors, intermediate_tensors,
model_kwargs, model_kwargs,
ec_connector_output,
) = self._preprocess( ) = self._preprocess(
scheduler_output, num_input_tokens, intermediate_tensors scheduler_output, num_input_tokens, intermediate_tensors
) )
...@@ -2699,6 +2725,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2699,6 +2725,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sample_hidden_states, sample_hidden_states,
aux_hidden_states, aux_hidden_states,
kv_connector_output, kv_connector_output,
ec_connector_output,
) )
return None return None
...@@ -2720,6 +2747,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2720,6 +2747,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sample_hidden_states, sample_hidden_states,
aux_hidden_states, aux_hidden_states,
kv_connector_output, kv_connector_output,
ec_connector_output,
) = self.execute_model_state ) = self.execute_model_state
# Clear ephemeral state. # Clear ephemeral state.
self.execute_model_state = None self.execute_model_state = None
...@@ -2811,6 +2839,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2811,6 +2839,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
prompt_logprobs_dict=prompt_logprobs_dict, prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[], pooler_output=[],
kv_connector_output=kv_connector_output, kv_connector_output=kv_connector_output,
ec_connector_output=ec_connector_output
if self.supports_mm_inputs
else None,
num_nans_in_logits=num_nans_in_logits, num_nans_in_logits=num_nans_in_logits,
) )
...@@ -4797,7 +4828,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -4797,7 +4828,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
KVCacheSpec: A dictionary mapping layer names to their KV cache KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included. format. Layers that do not need KV cache are not included.
""" """
if has_ec_transfer() and get_ec_transfer().is_producer:
return {}
kv_cache_spec: dict[str, KVCacheSpec] = {} kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
for layer_name, attn_module in attn_layers.items(): for layer_name, attn_module in attn_layers.items():
......
...@@ -20,6 +20,7 @@ from vllm.distributed import ( ...@@ -20,6 +20,7 @@ from vllm.distributed import (
init_distributed_environment, init_distributed_environment,
set_custom_all_reduce, set_custom_all_reduce,
) )
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
from vllm.distributed.kv_transfer import ( from vllm.distributed.kv_transfer import (
ensure_kv_transfer_initialized, ensure_kv_transfer_initialized,
get_kv_transfer_group, get_kv_transfer_group,
...@@ -887,3 +888,7 @@ def init_worker_distributed_environment( ...@@ -887,3 +888,7 @@ def init_worker_distributed_environment(
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size,
parallel_config.decode_context_parallel_size, parallel_config.decode_context_parallel_size,
) )
# Init EC connector here, before KV caches are initialized
# NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
ensure_ec_transfer_initialized(vllm_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment