Unverified Commit 4d302ab6 authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: [vLLM] add remote embedding loader to prefill/decode worker. Port E/P/D usage (#7507)


Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com>
Signed-off-by: GuanLuo <41310872+GuanLuo@users.noreply.github.com>
parent 9ba5f828
...@@ -13,7 +13,7 @@ import time ...@@ -13,7 +13,7 @@ import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, Final, Generic, TypeVar from typing import Any, AsyncIterator, Dict, Final, Generic, Optional, TypeVar
import torch import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -24,6 +24,14 @@ from vllm.sampling_params import SamplingParams, StructuredOutputsParams ...@@ -24,6 +24,14 @@ from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.engine.exceptions import EngineDeadError
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver,
NixlReadEmbeddingReceiver,
NixlWriteEmbeddingReceiver,
)
from dynamo.common.multimodal.image_loader import ImageLoader from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.common.utils.engine_response import normalize_finish_reason from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.input_params import InputParamManager from dynamo.common.utils.input_params import InputParamManager
...@@ -40,8 +48,12 @@ from dynamo.llm import ( ...@@ -40,8 +48,12 @@ from dynamo.llm import (
from dynamo.runtime import Client from dynamo.runtime import Client
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from .args import Config
from .constants import EmbeddingTransferMode
from .engine_monitor import VllmEngineMonitor from .engine_monitor import VllmEngineMonitor
from .multimodal_utils.hash_utils import compute_mm_uuids_from_images from .multimodal_utils.hash_utils import compute_mm_uuids_from_images
from .multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model
from .multimodal_utils.prefill_worker_utils import MultiModalEmbeddingLoader
# Multimodal data dictionary keys # Multimodal data dictionary keys
IMAGE_URL_KEY: Final = "image_url" IMAGE_URL_KEY: Final = "image_url"
...@@ -110,6 +122,11 @@ def _compute_mm_uuids( ...@@ -110,6 +122,11 @@ def _compute_mm_uuids(
if not multi_modal_data or "image" not in multi_modal_data: if not multi_modal_data or "image" not in multi_modal_data:
return None return None
images = multi_modal_data["image"] images = multi_modal_data["image"]
# [gluo FIXME] A dict is returned when the mm data has already been processed;
# in this case, we skip computing mm_uuids for now until we better understand
# what info should be hashed on.
if isinstance(images, dict):
return None
if not isinstance(images, list): if not isinstance(images, list):
images = [images] images = [images]
if not images: if not images:
...@@ -338,15 +355,16 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -338,15 +355,16 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
def __init__( def __init__(
self, self,
runtime, runtime,
config: Config,
engine, engine,
default_sampling_params, default_sampling_params,
model_max_len: int | None = None, model_max_len: int | None = None,
enable_multimodal: bool = False, enable_multimodal: bool = False,
generate_endpoint=None, generate_endpoint=None,
config=None,
use_vllm_tokenizer: bool = False, use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None, shutdown_event: asyncio.Event | None = None,
enable_frontend_decoding: bool = False, enable_frontend_decoding: bool = False,
encode_worker_client: Optional[Client] = None,
): ):
self.runtime = runtime self.runtime = runtime
self.engine_client = engine self.engine_client = engine
...@@ -369,6 +387,7 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -369,6 +387,7 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
self.image_loader = ImageLoader( self.image_loader = ImageLoader(
enable_frontend_decoding=enable_frontend_decoding enable_frontend_decoding=enable_frontend_decoding
) )
self.embedding_loader = self.init_embedding_loader(config, encode_worker_client)
self.use_vllm_tokenizer = use_vllm_tokenizer self.use_vllm_tokenizer = use_vllm_tokenizer
...@@ -385,6 +404,52 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -385,6 +404,52 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
# Store shutdown event for graceful shutdown monitoring # Store shutdown event for graceful shutdown monitoring
self.shutdown_event = shutdown_event self.shutdown_event = shutdown_event
def init_embedding_loader(
    self, config: Config, encode_worker_client: Optional[Client] = None
) -> Optional[MultiModalEmbeddingLoader]:
    """Build the multimodal embedding loader for this worker, if applicable.

    When ``encode_worker_client`` is ``None`` no loader is created and ``None``
    is returned. Otherwise this method stores ``encode_worker_client``,
    ``embedding_receiver`` and ``embedding_cache_manager`` on the instance and
    returns a ``MultiModalEmbeddingLoader`` wired to them.

    Args:
        config: Worker configuration; ``embedding_transfer_mode`` and
            ``multimodal_embedding_cache_capacity_gb`` are read from it.
        encode_worker_client: Client of the remote encode worker, or ``None``.

    Returns:
        The configured loader, or ``None`` when no encode worker client is given.

    Raises:
        ValueError: If ``config.embedding_transfer_mode`` is not one of the
            ``EmbeddingTransferMode`` values handled below.
    """
    if encode_worker_client is None:
        # Without encode worker, the embedding will be generated internally by vLLM.
        return None

    # The embedding loader consists of two main components:
    # 1) A remote encode worker client and matching embedding receiver,
    #    which can request remote encode and handle the transfer of embeddings
    #    from the encode worker to this prefill worker.
    # 2) A local embedding cache manager, which can store previously fetched
    #    embeddings and is used to determine whether remote encode is necessary
    #    for a given mm data.
    self.encode_worker_client = encode_worker_client

    transfer_mode = config.embedding_transfer_mode
    if transfer_mode == EmbeddingTransferMode.LOCAL:
        self.embedding_receiver = LocalEmbeddingReceiver()  # type: ignore
    elif transfer_mode == EmbeddingTransferMode.NIXL_WRITE:
        self.embedding_receiver = NixlWriteEmbeddingReceiver()  # type: ignore
    elif transfer_mode == EmbeddingTransferMode.NIXL_READ:
        # [gluo FIXME] can't use pre-registered tensor as NIXL requires descriptors
        # to be at matching size, need to overwrite nixl connect library
        self.embedding_receiver = NixlReadEmbeddingReceiver(max_items=0)  # type: ignore
    else:
        raise ValueError(
            f"Invalid embedding transfer mode: {config.embedding_transfer_mode}"
        )

    # [gluo FIXME/NOTE] This embedding cache manager is purely used for caching
    # embedding results from the encode worker, but
    # 'config.multimodal_embedding_cache_capacity_gb' is also used to configure the
    # DynamoMultimodalEmbeddingCacheConnector within vLLM. This results in
    # duplication of memory; ideally we should have a single cache manager which
    # can be used both by vLLM internally and here. Then we can explore asynchronous
    # embedding transfer, as we can process and block until the embedding is
    # actually used within vLLM.
    self.embedding_cache_manager: MultimodalEmbeddingCacheManager | None = None
    capacity_gb = config.multimodal_embedding_cache_capacity_gb
    if capacity_gb > 0:
        # Capacity is configured in GiB; the cache manager is given bytes.
        self.embedding_cache_manager = MultimodalEmbeddingCacheManager(
            int(capacity_gb * 1024**3)
        )

    return MultiModalEmbeddingLoader(
        encode_worker_client=self.encode_worker_client,  # type: ignore
        receiver=self.embedding_receiver,
        embedding_cache_manager=self.embedding_cache_manager,
    )
async def sleep(self, body: dict) -> dict: async def sleep(self, body: dict) -> dict:
"""Sleep the engine to release GPU memory and unregister from discovery. """Sleep the engine to release GPU memory and unregister from discovery.
...@@ -696,7 +761,7 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -696,7 +761,7 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
# Publish LoRA as a ModelDeploymentCard with format: # Publish LoRA as a ModelDeploymentCard with format:
# v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}/{lora_slug} # v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}/{lora_slug}
# This allows the frontend to discover it and route correctly to the worker instance # This allows the frontend to discover it and route correctly to the worker instance
if self.generate_endpoint is not None and self.config is not None: if self.generate_endpoint is not None:
logger.debug( logger.debug(
f"Publishing LoRA '{lora_name}' ModelDeploymentCard to {self.generate_endpoint}" f"Publishing LoRA '{lora_name}' ModelDeploymentCard to {self.generate_endpoint}"
) )
...@@ -999,7 +1064,7 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -999,7 +1064,7 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
return prompt, sequence_length, embeddings_tensor return prompt, sequence_length, embeddings_tensor
async def _extract_multimodal_data( async def _extract_multimodal_data(
self, request: Dict[str, Any] self, request: Dict[str, Any], request_id: str, context
) -> Dict[str, Any] | None: ) -> Dict[str, Any] | None:
""" """
Extract and decode multimodal data from PreprocessedRequest. Extract and decode multimodal data from PreprocessedRequest.
...@@ -1015,8 +1080,31 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -1015,8 +1080,31 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
) )
mm_map = request["multi_modal_data"] mm_map = request["multi_modal_data"]
vllm_mm_data = {}
# [gluo NOTE] If embedding loader is configured, currently we unconditionally
# fetch from the embedding loader.
if self.embedding_loader is not None:
# [gluo FIXME] couldn't simply pass 'mm_map.get(IMAGE_URL_KEY, [])' like below
# as currently the encode worker is using 'ImageLoader.load_image()' which doesn't
# support 'Decoded' variant. Need to update the encode worker to unify handling
image_urls = []
supported = True
for item in mm_map.get(IMAGE_URL_KEY, []):
if isinstance(item, dict) and "Url" in item:
image_urls.append(item["Url"])
elif isinstance(item, dict) and "Decoded" in item:
supported = False
if supported:
vllm_mm_data = await self.embedding_loader.load_multimodal_embeddings(
image_urls, request_id, model=self.config.model, context=context
)
logger.debug(
f"Fetched multimodal embeddings for {len(vllm_mm_data)} items"
)
return vllm_mm_data if vllm_mm_data else None
# Fallback that the vLLM engine will perform encoding internally.
vllm_mm_data = {}
# Process image_url entries # Process image_url entries
images = await self.image_loader.load_image_batch( images = await self.image_loader.load_image_batch(
mm_map.get(IMAGE_URL_KEY, []), mm_map.get(IMAGE_URL_KEY, []),
...@@ -1327,33 +1415,29 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1327,33 +1415,29 @@ class DecodeWorkerHandler(BaseWorkerHandler):
def __init__( def __init__(
self, self,
runtime, runtime,
config: Config,
engine, engine,
default_sampling_params, default_sampling_params,
model_max_len: int | None = None, model_max_len: int | None = None,
enable_multimodal: bool = False, enable_multimodal: bool = False,
generate_endpoint=None, generate_endpoint=None,
config=None,
use_vllm_tokenizer: bool = False, use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None, shutdown_event: asyncio.Event | None = None,
enable_frontend_decoding: bool = False, enable_frontend_decoding: bool = False,
encode_worker_client: Client | None = None, encode_worker_client: Client | None = None,
): ):
if encode_worker_client is not None:
raise NotImplementedError(
"'encode_worker_client' is provided which indicates remote "
"multimodal encode is configured, this is not currently supported."
)
super().__init__( super().__init__(
runtime, runtime,
config,
engine, engine,
default_sampling_params, default_sampling_params,
model_max_len, model_max_len,
enable_multimodal, enable_multimodal,
generate_endpoint, generate_endpoint,
config,
use_vllm_tokenizer, use_vllm_tokenizer,
shutdown_event, shutdown_event,
enable_frontend_decoding, enable_frontend_decoding,
encode_worker_client,
) )
async def generate(self, request, context): async def generate(self, request, context):
...@@ -1379,8 +1463,33 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1379,8 +1463,33 @@ class DecodeWorkerHandler(BaseWorkerHandler):
async def _generate_token_mode(self, request, context, request_id): async def _generate_token_mode(self, request, context, request_id):
"""Generate tokens using internal protocol format (token-in-token-out).""" """Generate tokens using internal protocol format (token-in-token-out)."""
# Extract and decode multimodal data if present # Firstly extract disaggregated params from prefill result if available
multi_modal_data = await self._extract_multimodal_data(request) prefill_result = request.get("prefill_result")
if prefill_result and isinstance(prefill_result, dict):
kv_params = prefill_result.get("disaggregated_params", {}).get(
"kv_transfer_params"
)
embedding_params = prefill_result.get("disaggregated_params", {}).get(
"embedding_params"
)
else:
kv_params = None
embedding_params = None
multi_modal_data = None
# The decode worker is handling disaggregated requests, the mm embedding will be synthetic
if prefill_result is not None and embedding_params is not None:
if is_qwen_vl_model(self.config.model):
multi_modal_data = construct_qwen_decode_mm_data(
embedding_params["image_grid_thw"],
embedding_params["embeddings_shape"],
request_id,
)
else:
# Extract and decode multimodal data if present
multi_modal_data = await self._extract_multimodal_data(
request, request_id, context
)
# Build prompt from request (handles both prompt_embeds and token_ids) # Build prompt from request (handles both prompt_embeds and token_ids)
prompt, embedding_sequence_length, error = self._build_prompt_from_request( prompt, embedding_sequence_length, error = self._build_prompt_from_request(
...@@ -1395,14 +1504,6 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1395,14 +1504,6 @@ class DecodeWorkerHandler(BaseWorkerHandler):
request, self.default_sampling_params, self.model_max_len request, self.default_sampling_params, self.model_max_len
) )
prefill_result = request.get("prefill_result")
if prefill_result and isinstance(prefill_result, dict):
kv_params = prefill_result.get("disaggregated_params", {}).get(
"kv_transfer_params"
)
else:
kv_params = None
if kv_params is not None: if kv_params is not None:
if sampling_params.extra_args is None: if sampling_params.extra_args is None:
sampling_params.extra_args = {} sampling_params.extra_args = {}
...@@ -1548,33 +1649,29 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1548,33 +1649,29 @@ class PrefillWorkerHandler(BaseWorkerHandler):
def __init__( def __init__(
self, self,
runtime, runtime,
config: Config,
engine, engine,
default_sampling_params, default_sampling_params,
model_max_len: int | None = None, model_max_len: int | None = None,
enable_multimodal: bool = False, enable_multimodal: bool = False,
generate_endpoint=None, generate_endpoint=None,
config=None,
use_vllm_tokenizer: bool = False, use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None, shutdown_event: asyncio.Event | None = None,
enable_frontend_decoding: bool = False, enable_frontend_decoding: bool = False,
encode_worker_client: Client | None = None, encode_worker_client: Client | None = None,
): ):
if encode_worker_client is not None:
raise NotImplementedError(
"'encode_worker_client' is provided which indicates remote "
"multimodal encode is configured, this is not currently supported."
)
super().__init__( super().__init__(
runtime, runtime,
config,
engine, engine,
default_sampling_params, default_sampling_params,
model_max_len, model_max_len,
enable_multimodal, enable_multimodal,
generate_endpoint, generate_endpoint,
config,
use_vllm_tokenizer, use_vllm_tokenizer,
shutdown_event, shutdown_event,
enable_frontend_decoding, enable_frontend_decoding,
encode_worker_client,
) )
async def generate(self, request, context): async def generate(self, request, context):
...@@ -1590,7 +1687,10 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1590,7 +1687,10 @@ class PrefillWorkerHandler(BaseWorkerHandler):
async def _generate_token_mode(self, request, context, request_id): async def _generate_token_mode(self, request, context, request_id):
"""Generate prefill using internal protocol format (token-in-token-out).""" """Generate prefill using internal protocol format (token-in-token-out)."""
# Extract and decode multimodal data if present # Extract and decode multimodal data if present
multi_modal_data = await self._extract_multimodal_data(request) multi_modal_data = await self._extract_multimodal_data(
request, request_id, context
)
embedding_params = self._build_embedding_params(multi_modal_data or {})
# Build prompt from request (handles both prompt_embeds and token_ids) # Build prompt from request (handles both prompt_embeds and token_ids)
prompt, embedding_sequence_length, error = self._build_prompt_from_request( prompt, embedding_sequence_length, error = self._build_prompt_from_request(
...@@ -1670,10 +1770,8 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1670,10 +1770,8 @@ class PrefillWorkerHandler(BaseWorkerHandler):
output: Dict[str, Any] = { output: Dict[str, Any] = {
"token_ids": list(token_ids), "token_ids": list(token_ids),
"disaggregated_params": ( "disaggregated_params": self._build_disaggregated_params(
{"kv_transfer_params": res.kv_transfer_params} res.kv_transfer_params, embedding_params
if res.kv_transfer_params
else None
), ),
"completion_usage": BaseWorkerHandler._build_completion_usage( "completion_usage": BaseWorkerHandler._build_completion_usage(
request_output=res, request_output=res,
...@@ -1693,3 +1791,38 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1693,3 +1791,38 @@ class PrefillWorkerHandler(BaseWorkerHandler):
) )
yield output yield output
def _build_disaggregated_params(self, kv_transfer_params, embedding_params=None):
disaggregated_params = {}
if kv_transfer_params is not None:
disaggregated_params["kv_transfer_params"] = kv_transfer_params
if embedding_params is not None:
disaggregated_params["embedding_params"] = embedding_params
return disaggregated_params if disaggregated_params else None
def _build_embedding_params(
    self, multi_modal_data: dict[str, Any]
) -> Dict[str, Any] | None:
    """
    Build the mm embedding metadata that is handed over to the decode worker.

    Typically the decode worker doesn't require any metadata for mm embedding as
    the content has been consumed by prefill. However, especially for Qwen models,
    vLLM's processor will try to expand image tokens in the prompt, which requires
    such metadata to pass through the processor.

    Returns a dict with ``image_grid_thw`` and/or ``embeddings_shape`` when the
    model is a Qwen VL model and ``multi_modal_data["image"]`` is a dict holding
    the corresponding entries; otherwise ``None``.
    """
    if not is_qwen_vl_model(self.config.model):
        return None

    image_entry = multi_modal_data.get("image")
    if not isinstance(image_entry, dict):
        return None

    grid_thw = image_entry.get("image_grid_thw")
    embeds = image_entry.get("image_embeds")

    params: Dict[str, Any] = {}
    if grid_thw is not None:
        # Tensors are converted to plain nested lists; other values pass through.
        if isinstance(grid_thw, torch.Tensor):
            params["image_grid_thw"] = grid_thw.tolist()
        else:
            params["image_grid_thw"] = grid_thw
    if embeds is not None:
        # Only the shape is forwarded, not the embedding content itself.
        params["embeddings_shape"] = list(embeds.shape)

    return params or None
...@@ -144,26 +144,21 @@ async def worker() -> None: ...@@ -144,26 +144,21 @@ async def worker() -> None:
# there # there
install_signal_handlers(loop, runtime, shutdown_endpoints, shutdown_event) install_signal_handlers(loop, runtime, shutdown_endpoints, shutdown_event)
# Route to appropriate initialization based on config flags # Use WorkerFactory to appropriate initialize worker based on config flags
if WorkerFactory.handles(config): factory = WorkerFactory(
# Create worker factory with setup functions setup_vllm_engine_fn=setup_vllm_engine,
factory = WorkerFactory( setup_kv_event_publisher_fn=setup_kv_event_publisher,
setup_vllm_engine_fn=setup_vllm_engine, register_vllm_model_fn=register_vllm_model,
setup_kv_event_publisher_fn=setup_kv_event_publisher, setup_fpm_relay_fn=setup_fpm_relay,
register_vllm_model_fn=register_vllm_model, setup_metrics_collection_fn=setup_metrics_collection,
setup_fpm_relay_fn=setup_fpm_relay, )
setup_metrics_collection_fn=setup_metrics_collection, await factory.create(
) runtime,
await factory.create( config,
runtime, shutdown_event,
config, shutdown_endpoints,
shutdown_event, snapshot_engine=snapshot_engine,
shutdown_endpoints, )
snapshot_engine=snapshot_engine,
)
logger.debug("worker init completed")
else:
raise ValueError("Unsupported worker configuration")
logger.debug("Worker function completed, exiting...") logger.debug("Worker function completed, exiting...")
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import copy import copy
import logging import logging
import uuid import uuid
from typing import Any from typing import Any, Optional
import torch import torch
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
...@@ -14,7 +14,6 @@ from dynamo.common.memory.multimodal_embedding_cache_manager import ( ...@@ -14,7 +14,6 @@ from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager, MultimodalEmbeddingCacheManager,
) )
from dynamo.common.multimodal.embedding_transfer import ( from dynamo.common.multimodal.embedding_transfer import (
AbstractEmbeddingReceiver,
LocalEmbeddingReceiver, LocalEmbeddingReceiver,
NixlReadEmbeddingReceiver, NixlReadEmbeddingReceiver,
NixlWriteEmbeddingReceiver, NixlWriteEmbeddingReceiver,
...@@ -48,8 +47,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler[dict, dict]): ...@@ -48,8 +47,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler[dict, dict]):
runtime, runtime,
engine_client: AsyncLLM, engine_client: AsyncLLM,
config: Config, config: Config,
encode_worker_client: Client | None = None, encode_worker_client: Optional[Client] = None,
decode_worker_client: Client | None = None, decode_worker_client: Optional[Client] = None,
shutdown_event=None, shutdown_event=None,
generate_endpoint=None, generate_endpoint=None,
): ):
...@@ -61,11 +60,11 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler[dict, dict]): ...@@ -61,11 +60,11 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler[dict, dict]):
# Call BaseWorkerHandler.__init__ with proper parameters # Call BaseWorkerHandler.__init__ with proper parameters
super().__init__( super().__init__(
runtime, runtime,
config,
engine_client, engine_client,
default_sampling_params, default_sampling_params,
enable_multimodal=config.enable_multimodal, enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint, generate_endpoint=generate_endpoint,
config=config,
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
) )
...@@ -76,28 +75,21 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler[dict, dict]): ...@@ -76,28 +75,21 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler[dict, dict]):
# Initialize multimodal-specific components # Initialize multimodal-specific components
logger.info("Multimodal PD Worker startup started.") logger.info("Multimodal PD Worker startup started.")
if "video" in self.config.model.lower():
self.EMBEDDINGS_DTYPE = torch.uint8
else:
self.EMBEDDINGS_DTYPE = torch.float16
# Embedding loader consist of two main components: # Embedding loader consist of two main components:
# 1) An remote encode worker client and matching embedding receiver, # 1) An remote encode worker client and matching embedding receiver,
# which can request remote encode and handle the transfer of embeddings # which can request remote encode and handle the transfer of embeddings
# from the encode worker to this prefill worker. # from the encode worker to this prefill worker.
# 2) A local embedding cache manager, which can store previously fetched embeddings # 2) A local embedding cache manager, which can store previously fetched embeddings
# and used to determine whether remote encode is necessary for a given mm data. # and used to determine whether remote encode is necessary for a given mm data.
self.encode_worker_client = encode_worker_client self.encode_worker_client = encode_worker_client # type: ignore
if config.embedding_transfer_mode == EmbeddingTransferMode.LOCAL: if config.embedding_transfer_mode == EmbeddingTransferMode.LOCAL:
self.embedding_receiver: AbstractEmbeddingReceiver = ( self.embedding_receiver = LocalEmbeddingReceiver() # type: ignore
LocalEmbeddingReceiver()
)
elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_WRITE: elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_WRITE:
self.embedding_receiver = NixlWriteEmbeddingReceiver() self.embedding_receiver = NixlWriteEmbeddingReceiver() # type: ignore
elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_READ: elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_READ:
# [gluo FIXME] can't use pre-registered tensor as NIXL requires descriptors # [gluo FIXME] can't use pre-registered tensor as NIXL requires descriptors
# to be at matching size, need to overwrite nixl connect library # to be at matching size, need to overwrite nixl connect library
self.embedding_receiver = NixlReadEmbeddingReceiver(max_items=0) self.embedding_receiver = NixlReadEmbeddingReceiver(max_items=0) # type: ignore
else: else:
raise ValueError( raise ValueError(
f"Invalid embedding transfer mode: {config.embedding_transfer_mode}" f"Invalid embedding transfer mode: {config.embedding_transfer_mode}"
...@@ -110,7 +102,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler[dict, dict]): ...@@ -110,7 +102,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler[dict, dict]):
self.embedding_cache_manager = MultimodalEmbeddingCacheManager( self.embedding_cache_manager = MultimodalEmbeddingCacheManager(
capacity_bytes capacity_bytes
) )
self.embedding_loader = MultiModalEmbeddingLoader( self.embedding_loader: MultiModalEmbeddingLoader = MultiModalEmbeddingLoader(
encode_worker_client=self.encode_worker_client, # type: ignore encode_worker_client=self.encode_worker_client, # type: ignore
receiver=self.embedding_receiver, receiver=self.embedding_receiver,
embedding_cache_manager=self.embedding_cache_manager, embedding_cache_manager=self.embedding_cache_manager,
...@@ -173,7 +165,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler[dict, dict]): ...@@ -173,7 +165,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler[dict, dict]):
image_urls, image_urls,
request_id, request_id,
model=self.config.model, model=self.config.model,
embeddings_dtype=self.EMBEDDINGS_DTYPE,
context=context, context=context,
) )
......
...@@ -40,11 +40,11 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler[vLLMMultimodalRequest, str ...@@ -40,11 +40,11 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler[vLLMMultimodalRequest, str
# Call BaseWorkerHandler.__init__ with proper parameters # Call BaseWorkerHandler.__init__ with proper parameters
super().__init__( super().__init__(
runtime, runtime,
config,
engine_client, engine_client,
default_sampling_params, default_sampling_params,
enable_multimodal=config.enable_multimodal, enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint, generate_endpoint=generate_endpoint,
config=config,
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
) )
......
...@@ -306,7 +306,6 @@ class MultiModalEmbeddingLoader: ...@@ -306,7 +306,6 @@ class MultiModalEmbeddingLoader:
request_id: str, request_id: str,
*, *,
model: str, model: str,
embeddings_dtype: torch.dtype,
context=None, context=None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Fetch embeddings and build engine-ready ``multi_modal_data``. """Fetch embeddings and build engine-ready ``multi_modal_data``.
...@@ -316,7 +315,7 @@ class MultiModalEmbeddingLoader: ...@@ -316,7 +315,7 @@ class MultiModalEmbeddingLoader:
Returns a dict suitable for passing to ``TokensPrompt(multi_modal_data=...)``. Returns a dict suitable for passing to ``TokensPrompt(multi_modal_data=...)``.
""" """
if not self._encode_worker_client or not image_urls: if self._encode_worker_client is None or not image_urls:
return {} return {}
groups, pending = await _fetch_embeddings( groups, pending = await _fetch_embeddings(
...@@ -337,7 +336,7 @@ class MultiModalEmbeddingLoader: ...@@ -337,7 +336,7 @@ class MultiModalEmbeddingLoader:
_accumulate_embeddings( _accumulate_embeddings(
multi_modal_data, multi_modal_data,
model, model,
embeddings_dtype, group.loaded_embedding.dtype,
group.loaded_embedding, group.loaded_embedding,
group.image_grid_thw, group.image_grid_thw,
) )
......
...@@ -193,8 +193,8 @@ class TestLoadMultimodalData: ...@@ -193,8 +193,8 @@ class TestLoadMultimodalData:
assert result is fake_mm_data assert result is fake_mm_data
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_passes_model_and_dtype(self): async def test_passes_model(self):
"""Model name and embeddings dtype are forwarded.""" """Model name is forwarded."""
mock_client = MagicMock() mock_client = MagicMock()
handler = _make_handler(encode_worker_client=mock_client) handler = _make_handler(encode_worker_client=mock_client)
...@@ -207,9 +207,6 @@ class TestLoadMultimodalData: ...@@ -207,9 +207,6 @@ class TestLoadMultimodalData:
await handler._load_multimodal_data(["http://img.png"], "req-1") await handler._load_multimodal_data(["http://img.png"], "req-1")
assert mock_load.call_args.kwargs["model"] == handler.config.model assert mock_load.call_args.kwargs["model"] == handler.config.model
assert (
mock_load.call_args.kwargs["embeddings_dtype"] == handler.EMBEDDINGS_DTYPE
)
class TestGenerateAgg: class TestGenerateAgg:
......
...@@ -47,7 +47,6 @@ class TestMultimodalEmbeddingsLoader: ...@@ -47,7 +47,6 @@ class TestMultimodalEmbeddingsLoader:
[url], [url],
"req-1", "req-1",
model=MODEL, model=MODEL,
embeddings_dtype=DTYPE,
) )
mock_fetch.assert_not_awaited() mock_fetch.assert_not_awaited()
...@@ -76,7 +75,6 @@ class TestMultimodalEmbeddingsLoader: ...@@ -76,7 +75,6 @@ class TestMultimodalEmbeddingsLoader:
[url], [url],
"req-1", "req-1",
model=MODEL, model=MODEL,
embeddings_dtype=DTYPE,
) )
mock_fetch.assert_awaited_once() mock_fetch.assert_awaited_once()
...@@ -108,7 +106,6 @@ class TestMultimodalEmbeddingsLoader: ...@@ -108,7 +106,6 @@ class TestMultimodalEmbeddingsLoader:
[url], [url],
"req-1", "req-1",
model=MODEL, model=MODEL,
embeddings_dtype=DTYPE,
) )
mock_fetch.assert_awaited_once() mock_fetch.assert_awaited_once()
...@@ -144,7 +141,6 @@ class TestMultimodalEmbeddingsLoader: ...@@ -144,7 +141,6 @@ class TestMultimodalEmbeddingsLoader:
[url_cached, url_miss], [url_cached, url_miss],
"req-1", "req-1",
model=MODEL, model=MODEL,
embeddings_dtype=DTYPE,
) )
mock_fetch.assert_awaited_once() mock_fetch.assert_awaited_once()
......
...@@ -69,11 +69,7 @@ class TestHandles: ...@@ -69,11 +69,7 @@ class TestHandles:
disaggregation_mode=DisaggregationMode.PREFILL, disaggregation_mode=DisaggregationMode.PREFILL,
route_to_encoder=route_to_encode, route_to_encoder=route_to_encode,
) )
# [gluo NOTE] due to current limitation, see 'WorkerFactory._validate_config()'. assert WorkerFactory.handles(config)
if route_to_encode:
assert not WorkerFactory.handles(config)
else:
assert WorkerFactory.handles(config)
@pytest.mark.parametrize("route_to_encode", [True, False]) @pytest.mark.parametrize("route_to_encode", [True, False])
def test_decode(self, route_to_encode: bool) -> None: def test_decode(self, route_to_encode: bool) -> None:
......
...@@ -64,7 +64,7 @@ class WorkerFactory: ...@@ -64,7 +64,7 @@ class WorkerFactory:
WorkerFactory._validate_config(config) WorkerFactory._validate_config(config)
return True return True
except (ValueError, NotImplementedError) as e: except (ValueError, NotImplementedError) as e:
logger.debug( logger.error(
f"WorkerFactory cannot handle config: {e}, provided config: {WorkerFactory._config_str(config)}" f"WorkerFactory cannot handle config: {e}, provided config: {WorkerFactory._config_str(config)}"
) )
return False return False
...@@ -115,21 +115,6 @@ class WorkerFactory: ...@@ -115,21 +115,6 @@ class WorkerFactory:
raise ValueError( raise ValueError(
"Multimodal worker with DECODE disaggregation mode is not supported." "Multimodal worker with DECODE disaggregation mode is not supported."
) )
# [gluo FIXME]
# 'route_to_encoder' hints standalone encode worker is used
# 'legacy_multimodal_llm_worker == False' hints Dynamo runtime will orchestrate
# P/D disagg and base P/D worker class should be used.
# In such a case, we can't use factory for P/D disaggregation modes because
# the current Dynamo runtime orchestrator is not aware of the extra mm data
# passing between P and D, P/D classes can't consume it properly until
# the protocol is updated.
elif (
config.route_to_encoder
and config.disaggregation_mode == DisaggregationMode.PREFILL
):
raise NotImplementedError(
"Dynamo orchestrated disaggregated prefill worker, combined with remote encode worker is not supported."
)
async def create( async def create(
self, self,
...@@ -140,6 +125,7 @@ class WorkerFactory: ...@@ -140,6 +125,7 @@ class WorkerFactory:
snapshot_engine: Optional[EngineSetupResult] = None, snapshot_engine: Optional[EngineSetupResult] = None,
) -> None: ) -> None:
"""Create the appropriate multimodal worker based on config flags.""" """Create the appropriate multimodal worker based on config flags."""
WorkerFactory._validate_config(config)
# Standalone encode worker # Standalone encode worker
if config.multimodal_encode_worker: if config.multimodal_encode_worker:
...@@ -463,12 +449,12 @@ class WorkerFactory: ...@@ -463,12 +449,12 @@ class WorkerFactory:
handler = DecodeWorkerHandler( handler = DecodeWorkerHandler(
runtime, runtime,
config,
engine_client, engine_client,
default_sampling_params, default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None), getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
enable_multimodal=config.enable_multimodal, enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint, generate_endpoint=generate_endpoint,
config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer, use_vllm_tokenizer=config.use_vllm_tokenizer,
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
enable_frontend_decoding=config.frontend_decoding, enable_frontend_decoding=config.frontend_decoding,
...@@ -646,12 +632,12 @@ class WorkerFactory: ...@@ -646,12 +632,12 @@ class WorkerFactory:
handler = PrefillWorkerHandler( handler = PrefillWorkerHandler(
runtime, runtime,
config,
engine_client, engine_client,
default_sampling_params, default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None), getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
enable_multimodal=config.enable_multimodal, enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint, generate_endpoint=generate_endpoint,
config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer, use_vllm_tokenizer=config.use_vllm_tokenizer,
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
enable_frontend_decoding=config.frontend_decoding, enable_frontend_decoding=config.frontend_decoding,
......
...@@ -75,20 +75,14 @@ python -m dynamo.frontend & ...@@ -75,20 +75,14 @@ python -m dynamo.frontend &
EXTRA_ARGS="" EXTRA_ARGS=""
# Embedding transfer:
# "local" = local file (safetensors),
# "nixl-write" = NIXL WRITE transfer
# "nixl-read" = NIXL READ transfer (default: "local")
export DYN_VLLM_EMBEDDING_TRANSFER_MODE=${DYN_VLLM_EMBEDDING_TRANSFER_MODE:-"local"}
# GPU assignments (override via environment variables) # GPU assignments (override via environment variables)
# TODO: use build_gpu_mem_args to measure VRAM instead of hardcoded fractions # TODO: use build_gpu_mem_args to measure VRAM instead of hardcoded fractions
# In single-GPU mode both workers share the same GPU. # In single-GPU mode both workers share the same GPU.
if [[ "$SINGLE_GPU" == "true" ]]; then if [[ "$SINGLE_GPU" == "true" ]]; then
DYN_ENCODE_WORKER_GPU=${DYN_ENCODE_WORKER_GPU:-0} DYN_ENCODE_WORKER_GPU=${DYN_ENCODE_WORKER_GPU:-0}
DYN_PD_WORKER_GPU=${DYN_PD_WORKER_GPU:-0} DYN_PD_WORKER_GPU=${DYN_PD_WORKER_GPU:-0}
DYN_ENCODE_GPU_MEM=${DYN_ENCODE_GPU_MEM:-0.4} DYN_ENCODE_GPU_MEM=${DYN_ENCODE_GPU_MEM:-0.1}
DYN_PD_GPU_MEM=${DYN_PD_GPU_MEM:-0.4} DYN_PD_GPU_MEM=${DYN_PD_GPU_MEM:-0.7}
EXTRA_ARGS="--enforce-eager" EXTRA_ARGS="--enforce-eager"
else else
DYN_ENCODE_WORKER_GPU=${DYN_ENCODE_WORKER_GPU:-1} DYN_ENCODE_WORKER_GPU=${DYN_ENCODE_WORKER_GPU:-1}
...@@ -112,7 +106,6 @@ echo "Starting PD worker on GPU $DYN_PD_WORKER_GPU (GPU mem: $DYN_PD_GPU_MEM)... ...@@ -112,7 +106,6 @@ echo "Starting PD worker on GPU $DYN_PD_WORKER_GPU (GPU mem: $DYN_PD_GPU_MEM)...
CUDA_VISIBLE_DEVICES=$DYN_PD_WORKER_GPU \ CUDA_VISIBLE_DEVICES=$DYN_PD_WORKER_GPU \
python -m dynamo.vllm \ python -m dynamo.vllm \
--route-to-encoder \ --route-to-encoder \
--multimodal-worker \
--enable-multimodal \ --enable-multimodal \
--enable-mm-embeds \ --enable-mm-embeds \
--model "$MODEL_NAME" \ --model "$MODEL_NAME" \
......
...@@ -115,12 +115,12 @@ VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=$DYN_ENCODE_WORKER_GPU py ...@@ -115,12 +115,12 @@ VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=$DYN_ENCODE_WORKER_GPU py
# Start prefill worker (also handles encode routing via --route-to-encoder) # Start prefill worker (also handles encode routing via --route-to-encoder)
echo "Starting prefill worker on GPU $DYN_PREFILL_WORKER_GPU (GPU mem: $DYN_PREFILL_GPU_MEM)..." echo "Starting prefill worker on GPU $DYN_PREFILL_WORKER_GPU (GPU mem: $DYN_PREFILL_GPU_MEM)..."
VLLM_NIXL_SIDE_CHANNEL_PORT=20098 \ VLLM_NIXL_SIDE_CHANNEL_PORT=20098 \
CUDA_VISIBLE_DEVICES=$DYN_PREFILL_WORKER_GPU python -m dynamo.vllm --multimodal-worker --route-to-encoder --disaggregation-mode prefill --enable-multimodal --enable-mm-embeds --model $MODEL_NAME --gpu-memory-utilization $DYN_PREFILL_GPU_MEM $EXTRA_ARGS $PD_EXTRA_ARGS --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081"}' & CUDA_VISIBLE_DEVICES=$DYN_PREFILL_WORKER_GPU python -m dynamo.vllm --route-to-encoder --disaggregation-mode prefill --enable-multimodal --enable-mm-embeds --model $MODEL_NAME --gpu-memory-utilization $DYN_PREFILL_GPU_MEM $EXTRA_ARGS $PD_EXTRA_ARGS --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081"}' &
# Start decode worker # Start decode worker
echo "Starting decode worker on GPU $DYN_DECODE_WORKER_GPU (GPU mem: $DYN_DECODE_GPU_MEM)..." echo "Starting decode worker on GPU $DYN_DECODE_WORKER_GPU (GPU mem: $DYN_DECODE_GPU_MEM)..."
VLLM_NIXL_SIDE_CHANNEL_PORT=20099 \ VLLM_NIXL_SIDE_CHANNEL_PORT=20099 \
CUDA_VISIBLE_DEVICES=$DYN_DECODE_WORKER_GPU python -m dynamo.vllm --multimodal-decode-worker --enable-multimodal --enable-mm-embeds --model $MODEL_NAME --gpu-memory-utilization $DYN_DECODE_GPU_MEM $EXTRA_ARGS $PD_EXTRA_ARGS --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' & CUDA_VISIBLE_DEVICES=$DYN_DECODE_WORKER_GPU python -m dynamo.vllm --disaggregation-mode decode --enable-multimodal --enable-mm-embeds --model $MODEL_NAME --gpu-memory-utilization $DYN_DECODE_GPU_MEM $EXTRA_ARGS $PD_EXTRA_ARGS --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' &
echo "==================================================" echo "=================================================="
echo "All components started. Waiting for initialization..." echo "All components started. Waiting for initialization..."
......
...@@ -311,9 +311,6 @@ vllm_configs = { ...@@ -311,9 +311,6 @@ vllm_configs = {
], ],
model="Qwen/Qwen3-VL-2B-Instruct", model="Qwen/Qwen3-VL-2B-Instruct",
script_args=["--model", "Qwen/Qwen3-VL-2B-Instruct", "--single-gpu"], script_args=["--model", "Qwen/Qwen3-VL-2B-Instruct", "--single-gpu"],
env={
"DYN_VLLM_EMBEDDING_TRANSFER_MODE": "nixl-write",
},
request_payloads=[ request_payloads=[
chat_payload( chat_payload(
[ [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment