"lib/vscode:/vscode.git/clone" did not exist on "63fbf4988ec4f19e06bc480521462fe61a5dc95d"
Unverified Commit 4d302ab6 authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: [vLLM] add remote embedding loader to prefill/decode worker. Port E/P/D usage (#7507)


Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
Signed-off-by: default avatarGuanLuo <41310872+GuanLuo@users.noreply.github.com>
parent 9ba5f828
......@@ -13,7 +13,7 @@ import time
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, Final, Generic, TypeVar
from typing import Any, AsyncIterator, Dict, Final, Generic, Optional, TypeVar
import torch
from vllm.config import VllmConfig
......@@ -24,6 +24,14 @@ from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.engine.exceptions import EngineDeadError
from dynamo._core import Context
from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver,
NixlReadEmbeddingReceiver,
NixlWriteEmbeddingReceiver,
)
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.input_params import InputParamManager
......@@ -40,8 +48,12 @@ from dynamo.llm import (
from dynamo.runtime import Client
from dynamo.runtime.logging import configure_dynamo_logging
from .args import Config
from .constants import EmbeddingTransferMode
from .engine_monitor import VllmEngineMonitor
from .multimodal_utils.hash_utils import compute_mm_uuids_from_images
from .multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model
from .multimodal_utils.prefill_worker_utils import MultiModalEmbeddingLoader
# Multimodal data dictionary keys
IMAGE_URL_KEY: Final = "image_url"
......@@ -110,6 +122,11 @@ def _compute_mm_uuids(
if not multi_modal_data or "image" not in multi_modal_data:
return None
images = multi_modal_data["image"]
# [gluo FIXME] A dict is returned when the mm data has already been processed;
# in this case, we skip computing mm_uuids for now until we better understand
# what info should be hashed.
if isinstance(images, dict):
return None
if not isinstance(images, list):
images = [images]
if not images:
......@@ -338,15 +355,16 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
def __init__(
self,
runtime,
config: Config,
engine,
default_sampling_params,
model_max_len: int | None = None,
enable_multimodal: bool = False,
generate_endpoint=None,
config=None,
use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None,
enable_frontend_decoding: bool = False,
encode_worker_client: Optional[Client] = None,
):
self.runtime = runtime
self.engine_client = engine
......@@ -369,6 +387,7 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
self.image_loader = ImageLoader(
enable_frontend_decoding=enable_frontend_decoding
)
self.embedding_loader = self.init_embedding_loader(config, encode_worker_client)
self.use_vllm_tokenizer = use_vllm_tokenizer
......@@ -385,6 +404,52 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
# Store shutdown event for graceful shutdown monitoring
self.shutdown_event = shutdown_event
def init_embedding_loader(
    self, config: Config, encode_worker_client: Optional[Client] = None
) -> Optional[MultiModalEmbeddingLoader]:
    """Build the loader that fetches embeddings from a remote encode worker.

    Args:
        config: Worker configuration. ``embedding_transfer_mode`` selects the
            receiver and ``multimodal_embedding_cache_capacity_gb`` sizes the
            optional local cache.
        encode_worker_client: Client of the remote encode worker, or ``None``
            when no remote encode worker is configured.

    Returns:
        ``None`` when ``encode_worker_client`` is ``None`` (the embedding is
        then generated internally by vLLM); otherwise a
        ``MultiModalEmbeddingLoader`` wired to the client, the receiver and
        the optional cache manager.

    Raises:
        ValueError: If ``config.embedding_transfer_mode`` is not one of the
            transfer modes handled below.

    Side effects:
        When a client is given, stores ``encode_worker_client``,
        ``embedding_receiver`` and ``embedding_cache_manager`` on ``self``.
    """
    # No encode worker: nothing to load remotely, so no loader is created.
    if encode_worker_client is None:
        return None

    # The loader is made of two parts:
    # 1) A remote encode worker client with a matching embedding receiver,
    #    used to request remote encoding and to transfer the embeddings
    #    from the encode worker to this worker.
    # 2) A local embedding cache manager that stores previously fetched
    #    embeddings and is used to decide whether a remote encode is needed
    #    for a given piece of mm data.
    self.encode_worker_client = encode_worker_client

    transfer_mode = config.embedding_transfer_mode
    if transfer_mode == EmbeddingTransferMode.LOCAL:
        self.embedding_receiver = LocalEmbeddingReceiver()  # type: ignore
    elif transfer_mode == EmbeddingTransferMode.NIXL_WRITE:
        self.embedding_receiver = NixlWriteEmbeddingReceiver()  # type: ignore
    elif transfer_mode == EmbeddingTransferMode.NIXL_READ:
        # [gluo FIXME] can't use a pre-registered tensor as NIXL requires
        # descriptors of matching size; need to overwrite the nixl connect
        # library.
        self.embedding_receiver = NixlReadEmbeddingReceiver(max_items=0)  # type: ignore
    else:
        raise ValueError(f"Invalid embedding transfer mode: {transfer_mode}")

    # [gluo FIXME/NOTE] This cache manager only caches embedding results from
    # the encode worker, but 'config.multimodal_embedding_cache_capacity_gb'
    # also configures the DynamoMultimodalEmbeddingCacheConnector within vLLM.
    # That duplicates memory; ideally a single cache manager would be shared
    # by vLLM internals and this loader. Asynchronous embedding transfer could
    # then be explored, blocking only when the embedding is actually used
    # within vLLM.
    self.embedding_cache_manager: MultimodalEmbeddingCacheManager | None = None
    capacity_gb = config.multimodal_embedding_cache_capacity_gb
    if capacity_gb > 0:
        # Capacity is configured in GiB; the cache manager is given bytes.
        self.embedding_cache_manager = MultimodalEmbeddingCacheManager(
            int(capacity_gb * 1024**3)
        )

    loader_kwargs = {
        "encode_worker_client": self.encode_worker_client,
        "receiver": self.embedding_receiver,
        "embedding_cache_manager": self.embedding_cache_manager,
    }
    return MultiModalEmbeddingLoader(**loader_kwargs)  # type: ignore
async def sleep(self, body: dict) -> dict:
"""Sleep the engine to release GPU memory and unregister from discovery.
......@@ -696,7 +761,7 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
# Publish LoRA as a ModelDeploymentCard with format:
# v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}/{lora_slug}
# This allows the frontend to discover it and route correctly to the worker instance
if self.generate_endpoint is not None and self.config is not None:
if self.generate_endpoint is not None:
logger.debug(
f"Publishing LoRA '{lora_name}' ModelDeploymentCard to {self.generate_endpoint}"
)
......@@ -999,7 +1064,7 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
return prompt, sequence_length, embeddings_tensor
async def _extract_multimodal_data(
self, request: Dict[str, Any]
self, request: Dict[str, Any], request_id: str, context
) -> Dict[str, Any] | None:
"""
Extract and decode multimodal data from PreprocessedRequest.
......@@ -1015,8 +1080,31 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
)
mm_map = request["multi_modal_data"]
vllm_mm_data = {}
# [gluo NOTE] If embedding loader is configured, currently we unconditionally
# fetch from the embedding loader.
if self.embedding_loader is not None:
# [gluo FIXME] couldn't simply pass 'mm_map.get(IMAGE_URL_KEY, [])' like below
# as currently the encode worker is using 'ImageLoader.load_image()' which doesn't
# support 'Decoded' variant. Need to update the encode worker to unify handling
image_urls = []
supported = True
for item in mm_map.get(IMAGE_URL_KEY, []):
if isinstance(item, dict) and "Url" in item:
image_urls.append(item["Url"])
elif isinstance(item, dict) and "Decoded" in item:
supported = False
if supported:
vllm_mm_data = await self.embedding_loader.load_multimodal_embeddings(
image_urls, request_id, model=self.config.model, context=context
)
logger.debug(
f"Fetched multimodal embeddings for {len(vllm_mm_data)} items"
)
return vllm_mm_data if vllm_mm_data else None
# Fallback that the vLLM engine will perform encoding internally.
vllm_mm_data = {}
# Process image_url entries
images = await self.image_loader.load_image_batch(
mm_map.get(IMAGE_URL_KEY, []),
......@@ -1327,33 +1415,29 @@ class DecodeWorkerHandler(BaseWorkerHandler):
def __init__(
self,
runtime,
config: Config,
engine,
default_sampling_params,
model_max_len: int | None = None,
enable_multimodal: bool = False,
generate_endpoint=None,
config=None,
use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None,
enable_frontend_decoding: bool = False,
encode_worker_client: Client | None = None,
):
if encode_worker_client is not None:
raise NotImplementedError(
"'encode_worker_client' is provided which indicates remote "
"multimodal encode is configured, this is not currently supported."
)
super().__init__(
runtime,
config,
engine,
default_sampling_params,
model_max_len,
enable_multimodal,
generate_endpoint,
config,
use_vllm_tokenizer,
shutdown_event,
enable_frontend_decoding,
encode_worker_client,
)
async def generate(self, request, context):
......@@ -1379,8 +1463,33 @@ class DecodeWorkerHandler(BaseWorkerHandler):
async def _generate_token_mode(self, request, context, request_id):
"""Generate tokens using internal protocol format (token-in-token-out)."""
# Firstly extract disaggregated params from prefill result if available
prefill_result = request.get("prefill_result")
if prefill_result and isinstance(prefill_result, dict):
kv_params = prefill_result.get("disaggregated_params", {}).get(
"kv_transfer_params"
)
embedding_params = prefill_result.get("disaggregated_params", {}).get(
"embedding_params"
)
else:
kv_params = None
embedding_params = None
multi_modal_data = None
# The decode worker is handling disaggregated requests, the mm embedding will be synthetic
if prefill_result is not None and embedding_params is not None:
if is_qwen_vl_model(self.config.model):
multi_modal_data = construct_qwen_decode_mm_data(
embedding_params["image_grid_thw"],
embedding_params["embeddings_shape"],
request_id,
)
else:
# Extract and decode multimodal data if present
multi_modal_data = await self._extract_multimodal_data(request)
multi_modal_data = await self._extract_multimodal_data(
request, request_id, context
)
# Build prompt from request (handles both prompt_embeds and token_ids)
prompt, embedding_sequence_length, error = self._build_prompt_from_request(
......@@ -1395,14 +1504,6 @@ class DecodeWorkerHandler(BaseWorkerHandler):
request, self.default_sampling_params, self.model_max_len
)
prefill_result = request.get("prefill_result")
if prefill_result and isinstance(prefill_result, dict):
kv_params = prefill_result.get("disaggregated_params", {}).get(
"kv_transfer_params"
)
else:
kv_params = None
if kv_params is not None:
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
......@@ -1548,33 +1649,29 @@ class PrefillWorkerHandler(BaseWorkerHandler):
def __init__(
self,
runtime,
config: Config,
engine,
default_sampling_params,
model_max_len: int | None = None,
enable_multimodal: bool = False,
generate_endpoint=None,
config=None,
use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None,
enable_frontend_decoding: bool = False,
encode_worker_client: Client | None = None,
):
if encode_worker_client is not None:
raise NotImplementedError(
"'encode_worker_client' is provided which indicates remote "
"multimodal encode is configured, this is not currently supported."
)
super().__init__(
runtime,
config,
engine,
default_sampling_params,
model_max_len,
enable_multimodal,
generate_endpoint,
config,
use_vllm_tokenizer,
shutdown_event,
enable_frontend_decoding,
encode_worker_client,
)
async def generate(self, request, context):
......@@ -1590,7 +1687,10 @@ class PrefillWorkerHandler(BaseWorkerHandler):
async def _generate_token_mode(self, request, context, request_id):
"""Generate prefill using internal protocol format (token-in-token-out)."""
# Extract and decode multimodal data if present
multi_modal_data = await self._extract_multimodal_data(request)
multi_modal_data = await self._extract_multimodal_data(
request, request_id, context
)
embedding_params = self._build_embedding_params(multi_modal_data or {})
# Build prompt from request (handles both prompt_embeds and token_ids)
prompt, embedding_sequence_length, error = self._build_prompt_from_request(
......@@ -1670,10 +1770,8 @@ class PrefillWorkerHandler(BaseWorkerHandler):
output: Dict[str, Any] = {
"token_ids": list(token_ids),
"disaggregated_params": (
{"kv_transfer_params": res.kv_transfer_params}
if res.kv_transfer_params
else None
"disaggregated_params": self._build_disaggregated_params(
res.kv_transfer_params, embedding_params
),
"completion_usage": BaseWorkerHandler._build_completion_usage(
request_output=res,
......@@ -1693,3 +1791,38 @@ class PrefillWorkerHandler(BaseWorkerHandler):
)
yield output
def _build_disaggregated_params(self, kv_transfer_params, embedding_params=None):
disaggregated_params = {}
if kv_transfer_params is not None:
disaggregated_params["kv_transfer_params"] = kv_transfer_params
if embedding_params is not None:
disaggregated_params["embedding_params"] = embedding_params
return disaggregated_params if disaggregated_params else None
def _build_embedding_params(
    self, multi_modal_data: dict[str, Any]
) -> Dict[str, Any] | None:
    """Collect the mm embedding metadata forwarded to the decode worker.

    The decode worker typically needs no metadata for mm embeddings because
    the content has already been consumed by prefill. For Qwen models,
    however, vLLM's processor tries to expand image tokens in the prompt and
    needs this metadata to get through the processor.

    Args:
        multi_modal_data: Multimodal data of the request; only an ``"image"``
            entry that is a dict is inspected.

    Returns:
        A dict with ``"image_grid_thw"`` and/or ``"embeddings_shape"`` when
        the model is a Qwen VL model and the corresponding values are present,
        otherwise ``None``.
    """
    if not is_qwen_vl_model(self.config.model):
        return None
    image_entry = multi_modal_data.get("image")
    if not isinstance(image_entry, dict):
        return None

    params: Dict[str, Any] = {}

    grid_thw = image_entry.get("image_grid_thw")
    if grid_thw is not None:
        # Tensors are converted to nested lists; other values pass through.
        if isinstance(grid_thw, torch.Tensor):
            grid_thw = grid_thw.tolist()
        params["image_grid_thw"] = grid_thw

    embeds = image_entry.get("image_embeds")
    if embeds is not None:
        # Only the shape is forwarded, not the embedding content.
        params["embeddings_shape"] = list(embeds.shape)

    return params or None
......@@ -144,9 +144,7 @@ async def worker() -> None:
# there
install_signal_handlers(loop, runtime, shutdown_endpoints, shutdown_event)
# Route to appropriate initialization based on config flags
if WorkerFactory.handles(config):
# Create worker factory with setup functions
# Use WorkerFactory to initialize the appropriate worker based on config flags
factory = WorkerFactory(
setup_vllm_engine_fn=setup_vllm_engine,
setup_kv_event_publisher_fn=setup_kv_event_publisher,
......@@ -161,9 +159,6 @@ async def worker() -> None:
shutdown_endpoints,
snapshot_engine=snapshot_engine,
)
logger.debug("worker init completed")
else:
raise ValueError("Unsupported worker configuration")
logger.debug("Worker function completed, exiting...")
......
......@@ -4,7 +4,7 @@
import copy
import logging
import uuid
from typing import Any
from typing import Any, Optional
import torch
from vllm.inputs.data import TokensPrompt
......@@ -14,7 +14,6 @@ from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.common.multimodal.embedding_transfer import (
AbstractEmbeddingReceiver,
LocalEmbeddingReceiver,
NixlReadEmbeddingReceiver,
NixlWriteEmbeddingReceiver,
......@@ -48,8 +47,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler[dict, dict]):
runtime,
engine_client: AsyncLLM,
config: Config,
encode_worker_client: Client | None = None,
decode_worker_client: Client | None = None,
encode_worker_client: Optional[Client] = None,
decode_worker_client: Optional[Client] = None,
shutdown_event=None,
generate_endpoint=None,
):
......@@ -61,11 +60,11 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler[dict, dict]):
# Call BaseWorkerHandler.__init__ with proper parameters
super().__init__(
runtime,
config,
engine_client,
default_sampling_params,
enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
config=config,
shutdown_event=shutdown_event,
)
......@@ -76,28 +75,21 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler[dict, dict]):
# Initialize multimodal-specific components
logger.info("Multimodal PD Worker startup started.")
if "video" in self.config.model.lower():
self.EMBEDDINGS_DTYPE = torch.uint8
else:
self.EMBEDDINGS_DTYPE = torch.float16
# The embedding loader consists of two main components:
# 1) A remote encode worker client and a matching embedding receiver,
# which can request remote encoding and handle the transfer of embeddings
# from the encode worker to this prefill worker.
# 2) A local embedding cache manager, which stores previously fetched embeddings
# and is used to determine whether remote encoding is necessary for a given mm data.
self.encode_worker_client = encode_worker_client
self.encode_worker_client = encode_worker_client # type: ignore
if config.embedding_transfer_mode == EmbeddingTransferMode.LOCAL:
self.embedding_receiver: AbstractEmbeddingReceiver = (
LocalEmbeddingReceiver()
)
self.embedding_receiver = LocalEmbeddingReceiver() # type: ignore
elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_WRITE:
self.embedding_receiver = NixlWriteEmbeddingReceiver()
self.embedding_receiver = NixlWriteEmbeddingReceiver() # type: ignore
elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_READ:
# [gluo FIXME] can't use pre-registered tensor as NIXL requires descriptors
# to be at matching size, need to overwrite nixl connect library
self.embedding_receiver = NixlReadEmbeddingReceiver(max_items=0)
self.embedding_receiver = NixlReadEmbeddingReceiver(max_items=0) # type: ignore
else:
raise ValueError(
f"Invalid embedding transfer mode: {config.embedding_transfer_mode}"
......@@ -110,7 +102,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler[dict, dict]):
self.embedding_cache_manager = MultimodalEmbeddingCacheManager(
capacity_bytes
)
self.embedding_loader = MultiModalEmbeddingLoader(
self.embedding_loader: MultiModalEmbeddingLoader = MultiModalEmbeddingLoader(
encode_worker_client=self.encode_worker_client, # type: ignore
receiver=self.embedding_receiver,
embedding_cache_manager=self.embedding_cache_manager,
......@@ -173,7 +165,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler[dict, dict]):
image_urls,
request_id,
model=self.config.model,
embeddings_dtype=self.EMBEDDINGS_DTYPE,
context=context,
)
......
......@@ -40,11 +40,11 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler[vLLMMultimodalRequest, str
# Call BaseWorkerHandler.__init__ with proper parameters
super().__init__(
runtime,
config,
engine_client,
default_sampling_params,
enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
config=config,
shutdown_event=shutdown_event,
)
......
......@@ -306,7 +306,6 @@ class MultiModalEmbeddingLoader:
request_id: str,
*,
model: str,
embeddings_dtype: torch.dtype,
context=None,
) -> Dict[str, Any]:
"""Fetch embeddings and build engine-ready ``multi_modal_data``.
......@@ -316,7 +315,7 @@ class MultiModalEmbeddingLoader:
Returns a dict suitable for passing to ``TokensPrompt(multi_modal_data=...)``.
"""
if not self._encode_worker_client or not image_urls:
if self._encode_worker_client is None or not image_urls:
return {}
groups, pending = await _fetch_embeddings(
......@@ -337,7 +336,7 @@ class MultiModalEmbeddingLoader:
_accumulate_embeddings(
multi_modal_data,
model,
embeddings_dtype,
group.loaded_embedding.dtype,
group.loaded_embedding,
group.image_grid_thw,
)
......
......@@ -193,8 +193,8 @@ class TestLoadMultimodalData:
assert result is fake_mm_data
@pytest.mark.asyncio
async def test_passes_model_and_dtype(self):
"""Model name and embeddings dtype are forwarded."""
async def test_passes_model(self):
"""Model name is forwarded."""
mock_client = MagicMock()
handler = _make_handler(encode_worker_client=mock_client)
......@@ -207,9 +207,6 @@ class TestLoadMultimodalData:
await handler._load_multimodal_data(["http://img.png"], "req-1")
assert mock_load.call_args.kwargs["model"] == handler.config.model
assert (
mock_load.call_args.kwargs["embeddings_dtype"] == handler.EMBEDDINGS_DTYPE
)
class TestGenerateAgg:
......
......@@ -47,7 +47,6 @@ class TestMultimodalEmbeddingsLoader:
[url],
"req-1",
model=MODEL,
embeddings_dtype=DTYPE,
)
mock_fetch.assert_not_awaited()
......@@ -76,7 +75,6 @@ class TestMultimodalEmbeddingsLoader:
[url],
"req-1",
model=MODEL,
embeddings_dtype=DTYPE,
)
mock_fetch.assert_awaited_once()
......@@ -108,7 +106,6 @@ class TestMultimodalEmbeddingsLoader:
[url],
"req-1",
model=MODEL,
embeddings_dtype=DTYPE,
)
mock_fetch.assert_awaited_once()
......@@ -144,7 +141,6 @@ class TestMultimodalEmbeddingsLoader:
[url_cached, url_miss],
"req-1",
model=MODEL,
embeddings_dtype=DTYPE,
)
mock_fetch.assert_awaited_once()
......
......@@ -69,10 +69,6 @@ class TestHandles:
disaggregation_mode=DisaggregationMode.PREFILL,
route_to_encoder=route_to_encode,
)
# [gluo NOTE] due to current limitation, see 'WorkerFactory._validate_config()'.
if route_to_encode:
assert not WorkerFactory.handles(config)
else:
assert WorkerFactory.handles(config)
@pytest.mark.parametrize("route_to_encode", [True, False])
......
......@@ -64,7 +64,7 @@ class WorkerFactory:
WorkerFactory._validate_config(config)
return True
except (ValueError, NotImplementedError) as e:
logger.debug(
logger.error(
f"WorkerFactory cannot handle config: {e}, provided config: {WorkerFactory._config_str(config)}"
)
return False
......@@ -115,21 +115,6 @@ class WorkerFactory:
raise ValueError(
"Multimodal worker with DECODE disaggregation mode is not supported."
)
# [gluo FIXME]
# 'route_to_encoder' hints standalone encode worker is used
# 'legacy_multimodal_llm_worker == False' hints Dynamo runtime will orchestrate
# P/D disagg and base P/D worker class should be used.
# In such a case, we can't use factory for P/D disaggregation modes because
# the current Dynamo runtime orchestrator is not aware of the extra mm data
# passing between P and D, P/D classes can't consume it properly untill
# the protocol is updated.
elif (
config.route_to_encoder
and config.disaggregation_mode == DisaggregationMode.PREFILL
):
raise NotImplementedError(
"Dynamo orchestrated disaggregated prefill worker, combined with remote encode worker is not supported."
)
async def create(
self,
......@@ -140,6 +125,7 @@ class WorkerFactory:
snapshot_engine: Optional[EngineSetupResult] = None,
) -> None:
"""Create the appropriate multimodal worker based on config flags."""
WorkerFactory._validate_config(config)
# Standalone encode worker
if config.multimodal_encode_worker:
......@@ -463,12 +449,12 @@ class WorkerFactory:
handler = DecodeWorkerHandler(
runtime,
config,
engine_client,
default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer,
shutdown_event=shutdown_event,
enable_frontend_decoding=config.frontend_decoding,
......@@ -646,12 +632,12 @@ class WorkerFactory:
handler = PrefillWorkerHandler(
runtime,
config,
engine_client,
default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer,
shutdown_event=shutdown_event,
enable_frontend_decoding=config.frontend_decoding,
......
......@@ -75,20 +75,14 @@ python -m dynamo.frontend &
EXTRA_ARGS=""
# Embedding transfer:
# "local" = local file (safetensors),
# "nixl-write" = NIXL WRITE transfer
# "nixl-read" = NIXL READ transfer (default: "local")
export DYN_VLLM_EMBEDDING_TRANSFER_MODE=${DYN_VLLM_EMBEDDING_TRANSFER_MODE:-"local"}
# GPU assignments (override via environment variables)
# TODO: use build_gpu_mem_args to measure VRAM instead of hardcoded fractions
# In single-GPU mode both workers share the same GPU.
if [[ "$SINGLE_GPU" == "true" ]]; then
DYN_ENCODE_WORKER_GPU=${DYN_ENCODE_WORKER_GPU:-0}
DYN_PD_WORKER_GPU=${DYN_PD_WORKER_GPU:-0}
DYN_ENCODE_GPU_MEM=${DYN_ENCODE_GPU_MEM:-0.4}
DYN_PD_GPU_MEM=${DYN_PD_GPU_MEM:-0.4}
DYN_ENCODE_GPU_MEM=${DYN_ENCODE_GPU_MEM:-0.1}
DYN_PD_GPU_MEM=${DYN_PD_GPU_MEM:-0.7}
EXTRA_ARGS="--enforce-eager"
else
DYN_ENCODE_WORKER_GPU=${DYN_ENCODE_WORKER_GPU:-1}
......@@ -112,7 +106,6 @@ echo "Starting PD worker on GPU $DYN_PD_WORKER_GPU (GPU mem: $DYN_PD_GPU_MEM)...
CUDA_VISIBLE_DEVICES=$DYN_PD_WORKER_GPU \
python -m dynamo.vllm \
--route-to-encoder \
--multimodal-worker \
--enable-multimodal \
--enable-mm-embeds \
--model "$MODEL_NAME" \
......
......@@ -115,12 +115,12 @@ VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=$DYN_ENCODE_WORKER_GPU py
# Start prefill worker (also handles encode routing via --route-to-encoder)
echo "Starting prefill worker on GPU $DYN_PREFILL_WORKER_GPU (GPU mem: $DYN_PREFILL_GPU_MEM)..."
VLLM_NIXL_SIDE_CHANNEL_PORT=20098 \
CUDA_VISIBLE_DEVICES=$DYN_PREFILL_WORKER_GPU python -m dynamo.vllm --multimodal-worker --route-to-encoder --disaggregation-mode prefill --enable-multimodal --enable-mm-embeds --model $MODEL_NAME --gpu-memory-utilization $DYN_PREFILL_GPU_MEM $EXTRA_ARGS $PD_EXTRA_ARGS --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081"}' &
CUDA_VISIBLE_DEVICES=$DYN_PREFILL_WORKER_GPU python -m dynamo.vllm --route-to-encoder --disaggregation-mode prefill --enable-multimodal --enable-mm-embeds --model $MODEL_NAME --gpu-memory-utilization $DYN_PREFILL_GPU_MEM $EXTRA_ARGS $PD_EXTRA_ARGS --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081"}' &
# Start decode worker
echo "Starting decode worker on GPU $DYN_DECODE_WORKER_GPU (GPU mem: $DYN_DECODE_GPU_MEM)..."
VLLM_NIXL_SIDE_CHANNEL_PORT=20099 \
CUDA_VISIBLE_DEVICES=$DYN_DECODE_WORKER_GPU python -m dynamo.vllm --multimodal-decode-worker --enable-multimodal --enable-mm-embeds --model $MODEL_NAME --gpu-memory-utilization $DYN_DECODE_GPU_MEM $EXTRA_ARGS $PD_EXTRA_ARGS --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' &
CUDA_VISIBLE_DEVICES=$DYN_DECODE_WORKER_GPU python -m dynamo.vllm --disaggregation-mode decode --enable-multimodal --enable-mm-embeds --model $MODEL_NAME --gpu-memory-utilization $DYN_DECODE_GPU_MEM $EXTRA_ARGS $PD_EXTRA_ARGS --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' &
echo "=================================================="
echo "All components started. Waiting for initialization..."
......
......@@ -311,9 +311,6 @@ vllm_configs = {
],
model="Qwen/Qwen3-VL-2B-Instruct",
script_args=["--model", "Qwen/Qwen3-VL-2B-Instruct", "--single-gpu"],
env={
"DYN_VLLM_EMBEDDING_TRANSFER_MODE": "nixl-write",
},
request_payloads=[
chat_payload(
[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment