Unverified Commit 4ccffe56 authored by Chenguang Zheng's avatar Chenguang Zheng Committed by GitHub
Browse files

[Core] Encoder separation for Encode-Prefill-Decode Disaggregation (#25233)


Signed-off-by: default avatarn00909098 <nguyen.kha.long@huawei.com>
Signed-off-by: default avatarknlnguyen1802 <knlnguyen1802@gmail.com>
Signed-off-by: default avatarherotai214 <herotai214@gmail.com>
Signed-off-by: default avatarKhuong Le <khuong.le.manh@huawei.com>
Signed-off-by: default avatarKhuong Le <lemanhkhuong2611@gmail.com>
Co-authored-by: default avatarn00909098 <nguyen.kha.long@huawei.com>
Co-authored-by: default avatarknlnguyen1802 <knlnguyen1802@gmail.com>
Co-authored-by: default avatarherotai214 <herotai214@gmail.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: default avatarKhuong Le <khuong.le.manh@huawei.com>
Co-authored-by: default avatarKhuong Le <lemanhkhuong2611@gmail.com>
parent cbb799e3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from collections.abc import Callable
from typing import TYPE_CHECKING
# yapf: disable
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorBase,
ECConnectorRole,
)
from vllm.logger import init_logger
# yapf: enable
if TYPE_CHECKING:
from vllm.config import ECTransferConfig, VllmConfig
logger = init_logger(__name__)
class ECConnectorFactory:
_registry: dict[str, Callable[[], type[ECConnectorBase]]] = {}
@classmethod
def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
"""Register a connector with a lazy-loading module and class name."""
if name in cls._registry:
raise ValueError(f"Connector '{name}' is already registered.")
def loader() -> type[ECConnectorBase]:
module = importlib.import_module(module_path)
return getattr(module, class_name)
cls._registry[name] = loader
@classmethod
def create_connector(
cls,
config: "VllmConfig",
role: ECConnectorRole,
) -> ECConnectorBase:
ec_transfer_config = config.ec_transfer_config
if ec_transfer_config is None:
raise ValueError("ec_transfer_config must be set to create a connector")
connector_cls = cls.get_connector_class(ec_transfer_config)
logger.info(
"Creating connector with name: %s and engine_id: %s",
connector_cls.__name__,
ec_transfer_config.engine_id,
)
# Connector is explicitly separated into two roles.
# Scheduler connector:
# - Co-locate with scheduler process
# - Should only be used inside the Scheduler class
# Worker connector:
# - Co-locate with worker process
return connector_cls(config, role)
@classmethod
def get_connector_class(
cls, ec_transfer_config: "ECTransferConfig"
) -> type[ECConnectorBase]:
"""Get the connector class by name."""
connector_name = ec_transfer_config.ec_connector
if connector_name is None:
raise ValueError("EC connect must not be None")
elif connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
connector_module_path = ec_transfer_config.ec_connector_module_path
if connector_module_path is None:
raise ValueError(f"Unsupported connector type: {connector_name}")
connector_module = importlib.import_module(connector_module_path)
connector_cls = getattr(connector_module, connector_name)
return connector_cls
# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# only load the files corresponding to the current connector.
ECConnectorFactory.register_connector(
"ECSharedStorageConnector",
"vllm.distributed.ec_transfer.ec_connector.shared_storage_connector",
"ECSharedStorageConnector",
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING
import safetensors
from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorBase,
ECConnectorMetadata,
ECConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class MMMeta:
mm_hash: str
num_token: int
@staticmethod
def make_meta(mm_hash, num_token) -> "MMMeta":
return MMMeta(mm_hash=mm_hash, num_token=num_token)
@dataclass
class ECSharedStorageConnectorMetadata(ECConnectorMetadata):
mm_datas: list[MMMeta]
def __init__(self):
self.mm_datas = []
def add_mm_data(self, mm_data: MMMeta):
self.mm_datas.append(mm_data)
class ECSharedStorageConnector(ECConnectorBase):
# NOTE: This is Simple debug implementation of the EC connector.
# It save / load the EC cache to / from the disk.
def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
# req_id -> index
self._mm_datas_need_loads: dict[str, int] = {}
transfer_config = vllm_config.ec_transfer_config
if transfer_config is not None:
self._storage_path = transfer_config.get_from_extra_config(
"shared_storage_path", "/tmp"
)
logger.debug(transfer_config)
logger.debug("Shared storage path is %s", self._storage_path)
else:
raise ValueError("ec_transfer_config must be set for ECConnectorBase")
def start_load_caches(self, encoder_cache, **kwargs) -> None:
"""
Start loading the cache from the connector into vLLM's encoder cache.
This method loads the encoder cache based on metadata provided by the scheduler.
It is called before `_gather_mm_embeddings` for the EC Connector. For EC,
the `encoder_cache` and `mm_hash` are stored in `kwargs`.
Args:
encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
data hashes (`mm_hash`) to encoder cache tensors.
kwargs (dict): Additional keyword arguments for the connector.
"""
# Get the metadata
metadata: ECConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, ECSharedStorageConnectorMetadata)
assert encoder_cache is not None
if metadata is None:
logger.warning(
(
"In connector.start_load_caches, ",
"but the connector metadata is None",
)
)
return
# Load the EC for each mm data
for mm_data in metadata.mm_datas:
if mm_data.mm_hash in encoder_cache:
continue
filename = self._generate_filename_debug(mm_data.mm_hash)
ec_cache = safetensors.torch.load_file(filename)["ec_cache"].cuda()
encoder_cache[mm_data.mm_hash] = ec_cache
logger.debug("Success load encoder cache for hash %s", mm_data.mm_hash)
def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None:
"""
Save the encoder cache to the connector.
This method saves the encoder cache from the worker's local storage
to shared storage or another external connector.
Args:
encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
data hashes (`mm_hash`) to encoder cache tensors.
mm_hash (str): The hash of the multimodal data whose cache is being saved.
kwargs (dict): Additional keyword arguments for the connector.
"""
# Return if it is PD Instance
if not self.is_producer:
return
filename = self._generate_filename_debug(mm_hash)
ec_cache = encoder_cache[mm_hash]
tensors = {"ec_cache": ec_cache.detach().cpu()}
safetensors.torch.save_file(tensors, filename)
logger.debug("Save cache successful for mm_hash %s", mm_hash)
def has_caches(
self,
request: "Request",
) -> list[bool]:
"""
Check if cache exist externally for each mm_data of request
Args:
request (Request): the request object.
Returns:
List of bool indicate that ith mm_data exist in cache or not
"""
result = []
for feature in request.mm_features:
result.append(self._found_match_for_mm_data(feature.identifier))
return result
def update_state_after_alloc(
self,
request: "Request",
index: int,
) -> None:
"""
Update ECConnector state after encoder cache allocation.
"""
mm_hash = request.mm_features[index].identifier
num_encoder_token = request.get_num_encoder_tokens(index)
# Insert mm_hash only if this block has not been recorded yet.
self._mm_datas_need_loads[mm_hash] = num_encoder_token
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> ECConnectorMetadata:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
This only build for load mm_data only
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = ECSharedStorageConnectorMetadata()
for mm_hash, num_encoder_token in self._mm_datas_need_loads.items():
meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token))
self._mm_datas_need_loads.clear()
return meta
# ==============================
# Helper functions
# ==============================
def _found_match_for_mm_data(self, mm_hash) -> bool:
"""Check if the cache is hit for the request."""
filename = self._generate_filename_debug(mm_hash)
return os.path.exists(filename)
def _generate_foldername_debug(
self,
mm_hash: str,
create_folder: bool = True, # <- now defaults to True
) -> str:
"""
Return the folder in which the cache for this mm_hash lives.
If `create_folder` is True (default) the directory is created
recursively the first time it is needed.
"""
foldername = os.path.join(self._storage_path, mm_hash)
if create_folder:
os.makedirs(foldername, exist_ok=True)
return foldername
def _generate_filename_debug(self, mm_hash: str) -> str:
"""
Return the full path of the safetensors file for this mm_hash.
Ensures the parent directory exists because
`_generate_foldername_debug` is called with its default
(`create_folder=True`).
"""
foldername = self._generate_foldername_debug(mm_hash) # <- folder auto-created
return os.path.join(foldername, "encoder_cache.safetensors")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from vllm import envs
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorBase,
ECConnectorRole,
)
from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory
if TYPE_CHECKING:
from vllm.config import VllmConfig
_EC_CONNECTOR_AGENT: ECConnectorBase | None = None
def get_ec_transfer() -> ECConnectorBase:
assert _EC_CONNECTOR_AGENT is not None, "disaggregated EC cache is not initialized"
return _EC_CONNECTOR_AGENT
def has_ec_transfer() -> bool:
return _EC_CONNECTOR_AGENT is not None
def ensure_ec_transfer_initialized(vllm_config: "VllmConfig") -> None:
"""
Initialize EC cache connector.
"""
global _EC_CONNECTOR_AGENT
if vllm_config.ec_transfer_config is None:
return
if (
vllm_config.ec_transfer_config.is_ec_transfer_instance
and _EC_CONNECTOR_AGENT is None
):
if envs.VLLM_USE_V1:
_EC_CONNECTOR_AGENT = ECConnectorFactory.create_connector(
config=vllm_config, role=ECConnectorRole.WORKER
)
else:
raise ValueError("V0 is no longer supported")
......@@ -38,6 +38,7 @@ from vllm.config import (
CompilationConfig,
ConfigType,
DeviceConfig,
ECTransferConfig,
EPLBConfig,
KVEventsConfig,
KVTransferConfig,
......@@ -527,6 +528,8 @@ class EngineArgs:
kv_transfer_config: KVTransferConfig | None = None
kv_events_config: KVEventsConfig | None = None
ec_transfer_config: ECTransferConfig | None = None
generation_config: str = ModelConfig.generation_config
enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
override_generation_config: dict[str, Any] = get_field(
......@@ -1105,6 +1108,9 @@ class EngineArgs:
"--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
)
vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
vllm_group.add_argument(
"--ec-transfer-config", **vllm_kwargs["ec_transfer_config"]
)
vllm_group.add_argument(
"--compilation-config", "-O", **vllm_kwargs["compilation_config"]
)
......@@ -1676,6 +1682,7 @@ class EngineArgs:
compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
ec_transfer_config=self.ec_transfer_config,
additional_config=self.additional_config,
)
......
......@@ -49,10 +49,18 @@ def kernel_warmup(worker: "Worker"):
except NotImplementedError:
return False
if not worker.model_runner.is_pooling_model and all(
# NOTE: we add check for empty attn_groups to avoid errors when
# deploying models such as E instances and encoder-only models.
# As for those models, worker.model_runner.attn_groups is empty.
# This change is made during EPD feature development.
if (
not worker.model_runner.is_pooling_model
and worker.model_runner.attn_groups
and all(
_is_flashinfer_backend(group.backend)
for groups in worker.model_runner.attn_groups
for group in groups
)
):
logger.info("Warming up FlashInfer attention.")
# Warmup with mixed batch containing both prefill and decode tokens
......
......@@ -14,6 +14,7 @@ if TYPE_CHECKING:
import numpy.typing as npt
import torch
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec
......@@ -21,6 +22,7 @@ if TYPE_CHECKING:
from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request
else:
ECConnectorMetadata = object
KVConnectorMetadata = object
LoRARequest = object
MultiModalFeatureSpec = object
......@@ -188,6 +190,9 @@ class SchedulerOutput:
# KV Cache Connector metadata.
kv_connector_metadata: KVConnectorMetadata | None = None
# EC Cache Connector metadata
ec_connector_metadata: ECConnectorMetadata | None = None
@dataclass
class GrammarOutput:
......
......@@ -7,6 +7,11 @@ from collections.abc import Iterable
from typing import Any
from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorMetadata,
ECConnectorRole,
)
from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
......@@ -14,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorRole,
SupportsHMA,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
......@@ -104,6 +110,11 @@ class Scheduler(SchedulerInterface):
self.kv_events_config,
self.parallel_config.data_parallel_rank,
)
self.ec_connector = None
if self.vllm_config.ec_transfer_config is not None:
self.ec_connector = ECConnectorFactory.create_connector(
config=self.vllm_config, role=ECConnectorRole.SCHEDULER
)
num_gpu_blocks = self.cache_config.num_gpu_blocks
assert num_gpu_blocks is not None and num_gpu_blocks > 0
......@@ -230,12 +241,14 @@ class Scheduler(SchedulerInterface):
# Schedule encoder inputs.
encoder_inputs_to_schedule = None
external_load_encoder_input: list[int] = []
new_encoder_compute_budget = encoder_compute_budget
if request.has_encoder_inputs:
(
encoder_inputs_to_schedule,
num_new_tokens,
new_encoder_compute_budget,
external_load_encoder_input,
) = self._try_schedule_encoder_inputs(
request,
request.num_computed_tokens,
......@@ -342,6 +355,11 @@ class Scheduler(SchedulerInterface):
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_compute_budget = new_encoder_compute_budget
if external_load_encoder_input:
for i in external_load_encoder_input:
self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
# Record the LoRAs in scheduled_running_reqs
scheduled_loras: set[int] = set()
......@@ -445,6 +463,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens = request.num_computed_tokens
encoder_inputs_to_schedule = None
external_load_encoder_input = []
new_encoder_compute_budget = encoder_compute_budget
# KVTransfer: loading remote KV, do not allocate for new work.
......@@ -480,6 +499,7 @@ class Scheduler(SchedulerInterface):
encoder_inputs_to_schedule,
num_new_tokens,
new_encoder_compute_budget,
external_load_encoder_input,
) = self._try_schedule_encoder_inputs(
request,
num_computed_tokens,
......@@ -583,7 +603,12 @@ class Scheduler(SchedulerInterface):
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_compute_budget = new_encoder_compute_budget
# Allocate for external load encoder cache
if external_load_encoder_input:
for i in external_load_encoder_input:
self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)
......@@ -591,6 +616,7 @@ class Scheduler(SchedulerInterface):
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
# Since some requests in the RUNNING queue may not be scheduled in
......@@ -653,8 +679,18 @@ class Scheduler(SchedulerInterface):
# 2. Wrap up all the KV cache load / save ops into an opaque object
# 3. Clear the internal states of the connector
if self.connector is not None:
meta = self.connector.build_connector_meta(scheduler_output)
meta: KVConnectorMetadata = self.connector.build_connector_meta(
scheduler_output
)
scheduler_output.kv_connector_metadata = meta
# Build the connector meta for ECConnector
if self.ec_connector is not None:
ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(
scheduler_output
)
scheduler_output.ec_connector_metadata = ec_meta
with record_function_or_nullcontext("schedule: update_after_schedule"):
self._update_after_schedule(scheduler_output)
return scheduler_output
......@@ -755,7 +791,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens: int,
num_new_tokens: int,
encoder_compute_budget: int,
) -> tuple[list[int], int, int]:
) -> tuple[list[int], int, int, list[int]]:
"""
Determine which encoder inputs need to be scheduled in the current step,
and update `num_new_tokens` and encoder token budget accordingly.
......@@ -765,6 +801,7 @@ class Scheduler(SchedulerInterface):
in this step, i.e.,
[num_computed_tokens, num_computed_tokens + num_new_tokens).
- It is not already computed and stored in the encoder cache.
- It is not exist on remote encoder cache (via ECConnector)
- There is sufficient encoder token budget to process it.
- The encoder cache has space to store it.
......@@ -776,12 +813,16 @@ class Scheduler(SchedulerInterface):
blocks and externally cached blocks (via KVConnector).
"""
if num_new_tokens == 0 or not request.has_encoder_inputs:
return [], num_new_tokens, encoder_compute_budget
return [], num_new_tokens, encoder_compute_budget, []
encoder_inputs_to_schedule: list[int] = []
mm_features = request.mm_features
assert mm_features is not None
assert len(mm_features) > 0
external_load_encoder_input = []
# Check remote cache first
if self.ec_connector is not None:
remote_cache_has_item = self.ec_connector.has_caches(request)
# NOTE: since scheduler operates on the request level (possibly with
# multiple encoder inputs per request), we need to create temporary
# trackers for accounting at the encoder input level.
......@@ -862,6 +903,12 @@ class Scheduler(SchedulerInterface):
num_new_tokens = 0
break
if self.ec_connector is not None and remote_cache_has_item[i]:
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
external_load_encoder_input.append(i)
num_tokens_to_schedule += num_encoder_tokens
continue
num_tokens_to_schedule += num_encoder_tokens
encoder_compute_budget -= num_encoder_tokens
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
......@@ -871,6 +918,7 @@ class Scheduler(SchedulerInterface):
encoder_inputs_to_schedule,
num_new_tokens,
encoder_compute_budget,
external_load_encoder_input,
)
def get_grammar_bitmask(
......
......@@ -8,6 +8,8 @@ from typing import TYPE_CHECKING, NamedTuple
import numpy as np
import torch
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
else:
......@@ -136,6 +138,13 @@ class KVConnectorOutput:
)
@dataclass
class ECConnectorOutput:
# [mm_hash]
finished_sending: set[str] | None = None
finished_recving: set[str] | None = None
# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use list instead.
@dataclass
......@@ -167,6 +176,8 @@ class ModelRunnerOutput:
kv_connector_output: KVConnectorOutput | None = None
ec_connector_output: ECConnectorOutput | None = None
# req_id -> num_nans_in_logits
num_nans_in_logits: dict[str, int] | None = None
......@@ -192,6 +203,41 @@ class DraftTokenIds:
draft_token_ids: list[list[int]]
def make_empty_encoder_model_runner_output(
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
"""
Create a ModelRunnerOutput stub that contains the correct
per-request bookkeeping but no generated data yet.
"""
if not scheduler_output.num_scheduled_tokens:
return EMPTY_MODEL_RUNNER_OUTPUT
# Convert to list so we get a deterministic, indexable sequence
req_ids: list[str] = list(scheduler_output.num_scheduled_tokens.keys())
# Give every request its own contiguous index
req_id_to_index: dict[str, int] = {rid: idx for idx, rid in enumerate(req_ids)}
# No tokens generated yet ⇒ one empty list per request
sampled_token_ids: list[list[int]] = [[0] for _ in req_ids]
# Pooler outputs are not available yet ⇒ use None placeholders
pooler_output: list[torch.Tensor | None] = [None for _ in req_ids]
return ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=pooler_output,
kv_connector_output=None,
ec_connector_output=None,
num_nans_in_logits=None,
)
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[],
req_id_to_index={},
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Define EC connector functionality mixin for model runners.
"""
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import (
TYPE_CHECKING, # noqa: UP035
)
import torch
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorBase
from vllm.logger import init_logger
from vllm.v1.outputs import ECConnectorOutput
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)
# Defined as a EC connector functionality mixin for ModelRunner (GPU, TPU)
class ECConnectorModelRunnerMixin:
@staticmethod
def maybe_save_ec_to_connector(
encoder_cache: dict[str, torch.Tensor],
mm_hash: str,
):
if not has_ec_transfer():
logger.debug("Not have ec transfer please check")
return
connector = get_ec_transfer()
connector.save_caches(encoder_cache=encoder_cache, mm_hash=mm_hash)
@staticmethod
def get_finished_ec_transfers(
scheduler_output: "SchedulerOutput",
) -> tuple[set[str] | None, set[str] | None]:
if has_ec_transfer():
return get_ec_transfer().get_finished(scheduler_output.finished_req_ids)
return None, None
@staticmethod
def maybe_get_ec_connector_output(
scheduler_output: "SchedulerOutput",
encoder_cache: dict[str, torch.Tensor],
**kwargs,
) -> AbstractContextManager[ECConnectorOutput | None]:
return (
ECConnectorModelRunnerMixin._get_ec_connector_output(
scheduler_output, encoder_cache, **kwargs
)
if has_ec_transfer()
else nullcontext()
)
# This context manager must be used within an active forward context.
# It encapsulates the entire EC conector lifecycle within execute_model
@staticmethod
@contextmanager
def _get_ec_connector_output(
scheduler_output: "SchedulerOutput",
encoder_cache: dict[str, torch.Tensor],
**kwargs,
) -> Generator[ECConnectorOutput, None, None]:
output = ECConnectorOutput()
ec_connector = get_ec_transfer()
assert isinstance(ec_connector, ECConnectorBase)
assert scheduler_output.ec_connector_metadata is not None
ec_connector.bind_connector_metadata(scheduler_output.ec_connector_metadata)
if not ec_connector.is_producer:
ec_connector.start_load_caches(encoder_cache, **kwargs)
try:
yield output
finally:
output.finished_sending, output.finished_recving = (
ec_connector.get_finished(scheduler_output.finished_req_ids)
)
ec_connector.clear_connector_metadata()
......@@ -35,6 +35,7 @@ from vllm.config import (
get_layers_from_vllm_config,
update_config,
)
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
......@@ -114,12 +115,14 @@ from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
AsyncModelRunnerOutput,
DraftTokenIds,
ECConnectorOutput,
KVConnectorOutput,
LogprobsLists,
LogprobsTensors,
ModelRunnerOutput,
PoolerOutput,
SamplerOutput,
make_empty_encoder_model_runner_output,
)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
......@@ -134,6 +137,7 @@ from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
......@@ -237,9 +241,12 @@ class ExecuteModelState(NamedTuple):
sample_hidden_states: torch.Tensor
aux_hidden_states: list[torch.Tensor] | None
kv_connector_output: KVConnectorOutput | None
ec_connector_output: ECConnectorOutput | None
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
class GPUModelRunner(
LoRAModelRunnerMixin, KVConnectorModelRunnerMixin, ECConnectorModelRunnerMixin
):
def __init__(
self,
vllm_config: VllmConfig,
......@@ -1873,6 +1880,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
output,
is_embed=pos_info.is_embed,
)
logger.debug("Finish execute for mm hash %s", mm_hash)
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
def _gather_mm_embeddings(
self,
......@@ -2191,18 +2200,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
torch.Tensor,
IntermediateTensors | None,
dict[str, Any],
ECConnectorOutput | None,
]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
is_first_rank = get_pp_group().is_first_rank
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
ec_connector_output = None
if (
self.supports_mm_inputs
and is_first_rank
and not self.model_config.is_encoder_decoder
):
# Run the multimodal encoder if any.
with self.maybe_get_ec_connector_output(
scheduler_output,
encoder_cache=self.encoder_cache,
) as ec_connector_output:
self._execute_mm_encoder(scheduler_output)
mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output)
......@@ -2284,6 +2300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions,
intermediate_tensors,
model_kwargs,
ec_connector_output,
)
def _sample(
......@@ -2508,6 +2525,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Update persistent batch states.
self._update_states(scheduler_output)
if has_ec_transfer() and get_ec_transfer().is_producer:
with self.maybe_get_ec_connector_output(
scheduler_output,
encoder_cache=self.encoder_cache,
) as ec_connector_output:
self._execute_mm_encoder(scheduler_output)
return make_empty_encoder_model_runner_output(scheduler_output)
if not num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if no work to do.
......@@ -2583,6 +2608,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions,
intermediate_tensors,
model_kwargs,
ec_connector_output,
) = self._preprocess(
scheduler_output, num_input_tokens, intermediate_tensors
)
......@@ -2699,6 +2725,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sample_hidden_states,
aux_hidden_states,
kv_connector_output,
ec_connector_output,
)
return None
......@@ -2720,6 +2747,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sample_hidden_states,
aux_hidden_states,
kv_connector_output,
ec_connector_output,
) = self.execute_model_state
# Clear ephemeral state.
self.execute_model_state = None
......@@ -2811,6 +2839,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=kv_connector_output,
ec_connector_output=ec_connector_output
if self.supports_mm_inputs
else None,
num_nans_in_logits=num_nans_in_logits,
)
......@@ -4797,7 +4828,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""
if has_ec_transfer() and get_ec_transfer().is_producer:
return {}
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
for layer_name, attn_module in attn_layers.items():
......
......@@ -20,6 +20,7 @@ from vllm.distributed import (
init_distributed_environment,
set_custom_all_reduce,
)
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
from vllm.distributed.kv_transfer import (
ensure_kv_transfer_initialized,
get_kv_transfer_group,
......@@ -887,3 +888,7 @@ def init_worker_distributed_environment(
parallel_config.pipeline_parallel_size,
parallel_config.decode_context_parallel_size,
)
# Init ec connector here before KV caches caches init
# NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
ensure_ec_transfer_initialized(vllm_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment