Unverified Commit a2a6917f authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: use embedding transfer classes for EPD (#6223)


Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
parent 1549c338
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
MODEL_NAME="Qwen/Qwen2.5-VL-7B-Instruct" MODEL_NAME="Qwen/Qwen3-VL-30B-A3B-Instruct-FP8"
CONCURRENCY=1 CONCURRENCY=1
OSL=150
# 500 * 333 pixels image # 500 * 333 pixels image
IMG_URL="http://images.cocodataset.org/test2017/000000000183.jpg" IMG_URL="http://images.cocodataset.org/test2017/000000000183.jpg"
...@@ -16,7 +17,12 @@ DUMMY_PROMPT="This is a prompt to describe the image content briefly." ...@@ -16,7 +17,12 @@ DUMMY_PROMPT="This is a prompt to describe the image content briefly."
for i in {1..1500}; do for i in {1..1500}; do
DUMMY_PROMPT+=" This is a prompt to describe the image content briefly." DUMMY_PROMPT+=" This is a prompt to describe the image content briefly."
done done
echo '{"texts": ["'"$DUMMY_PROMPT"'"], "images": ["'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'"]}' \ IMAGE_LIST='"'"$IMG_URL"'"'
for i in {2..30}; do
IMAGE_LIST+=',"'$IMG_URL'"'
done
echo '{"texts": ["'"$DUMMY_PROMPT"'"], "images": ['"$IMAGE_LIST"']}' \
> data_small.jsonl > data_small.jsonl
echo "This benchmark uses duplicate image urls, so any kind of caching can significantly affect the benchmark results, please make sure the caching setting is properly configured for your experiment." echo "This benchmark uses duplicate image urls, so any kind of caching can significantly affect the benchmark results, please make sure the caching setting is properly configured for your experiment."
...@@ -31,11 +37,16 @@ while [[ $# -gt 0 ]]; do ...@@ -31,11 +37,16 @@ while [[ $# -gt 0 ]]; do
CONCURRENCY=$2 CONCURRENCY=$2
shift 2 shift 2
;; ;;
--osl)
OSL=$2
shift 2
;;
-h|--help) -h|--help)
echo "Usage: $0 [OPTIONS]" echo "Usage: $0 [OPTIONS]"
echo "Options:" echo "Options:"
echo " --model <model_name> Specify the model to use (default: $MODEL_NAME)" echo " --model <model_name> Specify the model to use (default: $MODEL_NAME)"
echo " --concurrency <level> Specify the concurrency level to use (default: $CONCURRENCY)" echo " --concurrency <level> Specify the concurrency level to use (default: $CONCURRENCY)"
echo " --osl <level> Specify the OSL to use (default: $OSL)"
echo " -h, --help Show this help message" echo " -h, --help Show this help message"
exit 0 exit 0
;; ;;
...@@ -49,7 +60,8 @@ done ...@@ -49,7 +60,8 @@ done
aiperf profile -m $MODEL_NAME --endpoint-type chat \ aiperf profile -m $MODEL_NAME --endpoint-type chat \
--streaming --request-count 100 --warmup-request-count 5 \ --streaming --request-count 100 --warmup-request-count 5 \
--concurrency $CONCURRENCY --osl 1 \ --concurrency $CONCURRENCY --osl $OSL \
--input-file data_small.jsonl \ --input-file data_small.jsonl \
--custom-dataset-type single_turn --ui none \ --custom-dataset-type single_turn --ui none \
--extra-inputs 'ignore_eos:true' \
--no-server-metrics --no-server-metrics
...@@ -4,5 +4,19 @@ ...@@ -4,5 +4,19 @@
"""Multimodal utilities for Dynamo components.""" """Multimodal utilities for Dynamo components."""
from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache
from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver,
LocalEmbeddingSender,
NixlPersistentEmbeddingReceiver,
NixlPersistentEmbeddingSender,
TransferRequest,
)
__all__ = ["AsyncEncoderCache"] __all__ = [
"AsyncEncoderCache",
"NixlPersistentEmbeddingReceiver",
"NixlPersistentEmbeddingSender",
"TransferRequest",
"LocalEmbeddingReceiver",
"LocalEmbeddingSender",
]
...@@ -114,6 +114,29 @@ class LocalEmbeddingSender(AbstractEmbeddingSender): ...@@ -114,6 +114,29 @@ class LocalEmbeddingSender(AbstractEmbeddingSender):
self.sender_id = uuid.uuid4().hex self.sender_id = uuid.uuid4().hex
self.embedding_counter = 0 self.embedding_counter = 0
def save_embeddings_to_file(
self, embedding_key: str, embeddings: torch.Tensor
) -> str:
"""
Save the embeddings to a local file and return the file path.
Args:
embedding_key: A unique key for the embeddings.
embeddings: A torch.Tensor of the embeddings to save.
Returns:
The file path where the embeddings are saved.
"""
fd, tensor_path = tempfile.mkstemp(
prefix=f"encoder_cache.{embedding_key}.", suffix=".safetensors"
)
os.close(fd)
tensors = {"ec_cache": embeddings.cpu()}
safetensors_torch.save_file(
tensors,
tensor_path,
)
return tensor_path
async def send_embeddings( async def send_embeddings(
self, embeddings: torch.Tensor, stage_embeddings: bool = False self, embeddings: torch.Tensor, stage_embeddings: bool = False
) -> tuple[TransferRequest, asyncio.Future]: ) -> tuple[TransferRequest, asyncio.Future]:
...@@ -131,15 +154,10 @@ class LocalEmbeddingSender(AbstractEmbeddingSender): ...@@ -131,15 +154,10 @@ class LocalEmbeddingSender(AbstractEmbeddingSender):
# This could involve publishing to a message queue or making an API call # This could involve publishing to a message queue or making an API call
embedding_key = f"{self.sender_id}_{self.embedding_counter}" embedding_key = f"{self.sender_id}_{self.embedding_counter}"
self.embedding_counter += 1 self.embedding_counter += 1
tensor_path = f"/tmp/encoder_cache.{embedding_key}.safetensors" tensor_path = await asyncio.to_thread(
fd, tensor_path = tempfile.mkstemp( self.save_embeddings_to_file,
prefix=f"encoder_cache.{embedding_key}.", suffix=".safetensors" embedding_key,
) embeddings,
os.close(fd)
tensors = {"ec_cache": embeddings.cpu()}
safetensors_torch.save_file(
tensors,
tensor_path,
) )
fut = asyncio.get_event_loop().create_future() fut = asyncio.get_event_loop().create_future()
fut.set_result(None) fut.set_result(None)
...@@ -177,7 +195,7 @@ class LocalEmbeddingReceiver(AbstractEmbeddingReceiver): ...@@ -177,7 +195,7 @@ class LocalEmbeddingReceiver(AbstractEmbeddingReceiver):
Caller should invoke release_tensor(tensor_id) when the tensor is no longer needed to free up resources. Caller should invoke release_tensor(tensor_id) when the tensor is no longer needed to free up resources.
""" """
tensor_path = request.serialized_request tensor_path = request.serialized_request
tensors = safetensors_torch.load_file(tensor_path) tensors = await asyncio.to_thread(safetensors_torch.load_file, tensor_path)
embedding_tensor = tensors["ec_cache"] embedding_tensor = tensors["ec_cache"]
tensor_id = self.tensor_id_counter tensor_id = self.tensor_id_counter
self.tensor_id_counter += 1 self.tensor_id_counter += 1
...@@ -339,10 +357,17 @@ class NixlPersistentEmbeddingSender(AbstractEmbeddingSender): ...@@ -339,10 +357,17 @@ class NixlPersistentEmbeddingSender(AbstractEmbeddingSender):
embeddings: A torch.Tensor of the embeddings to send. embeddings: A torch.Tensor of the embeddings to send.
stage_embeddings: A boolean indicating whether the embeddings should be staged for the transfer, stage_embeddings: A boolean indicating whether the embeddings should be staged for the transfer,
if True, the embeddings may be used as transfer buffer and must not be released until the return future is completed. if True, the embeddings may be used as transfer buffer and must not be released until the return future is completed.
if False, the sender will copy the embeddings.
Returns: Returns:
A tuple containing the TransferRequest object and a future that can be awaited to indicate the send is completed. A tuple containing the TransferRequest object and a future that can be awaited to indicate the send is completed.
""" """
descriptor = nixl_connect.Descriptor(embeddings.cpu()) # If not staging embedding and embedding is on CPU, we explicitly copy
# the tensor as torch.Tensor.cpu() will return original tensor if it's already on CPU
if not stage_embeddings and not embeddings.is_cuda:
embeddings_cpu = embeddings.clone().detach()
else:
embeddings_cpu = embeddings.cpu()
descriptor = nixl_connect.Descriptor(embeddings_cpu)
readable_op = await self.connector.create_readable(descriptor) readable_op = await self.connector.create_readable(descriptor)
request = TransferRequest( request = TransferRequest(
...@@ -418,7 +443,7 @@ class NixlPersistentEmbeddingReceiver(AbstractEmbeddingReceiver): ...@@ -418,7 +443,7 @@ class NixlPersistentEmbeddingReceiver(AbstractEmbeddingReceiver):
) )
if self.warmedup_descriptors.empty(): if self.warmedup_descriptors.empty():
logger.warning( logger.debug(
"No warmed up descriptors available, creating a temporary one for transfer." "No warmed up descriptors available, creating a temporary one for transfer."
) )
encodings_tensor = torch.zeros(*embeddings_shape, dtype=embeddings_dtype) encodings_tensor = torch.zeros(*embeddings_shape, dtype=embeddings_dtype)
......
...@@ -25,9 +25,17 @@ logger = logging.getLogger(__name__) ...@@ -25,9 +25,17 @@ logger = logging.getLogger(__name__)
async def benchmark(sender, receiver, tensors=None): async def benchmark(sender, receiver, tensors=None):
if tensors is None: if tensors is None:
tensors = [torch.randn(256, 8 * 1024) for _ in range(30)] tensors = [torch.randn(256, 8 * 1024) for _ in range(30)]
# warmup
request, send_future = await sender.send_embeddings(tensors[0])
tensor_id, response = await receiver.receive_embeddings(request)
receiver.release_tensor(tensor_id)
await send_future
# benchmark
send_start = time.perf_counter() send_start = time.perf_counter()
sender_tasks = [ sender_tasks = [
asyncio.create_task(sender.send_embeddings(tensor)) for tensor in tensors asyncio.create_task(sender.send_embeddings(tensor, stage_embeddings=True))
for tensor in tensors
] ]
requests = await asyncio.gather(*sender_tasks) requests = await asyncio.gather(*sender_tasks)
send_end = time.perf_counter() send_end = time.perf_counter()
......
...@@ -4,17 +4,16 @@ ...@@ -4,17 +4,16 @@
import asyncio import asyncio
import logging import logging
import os import os
import tempfile
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import AsyncIterator from typing import AsyncIterator
import safetensors
import torch import torch
from transformers import AutoImageProcessor from transformers import AutoImageProcessor
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
import dynamo.nixl_connect as connect import dynamo.nixl_connect as connect
from dynamo.common.multimodal import LocalEmbeddingSender, NixlPersistentEmbeddingSender
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from ..multimodal_utils import ( from ..multimodal_utils import (
...@@ -31,7 +30,12 @@ logger = logging.getLogger(__name__) ...@@ -31,7 +30,12 @@ logger = logging.getLogger(__name__)
CACHE_SIZE_MAXIMUM = 8 CACHE_SIZE_MAXIMUM = 8
# Both embedding transmitter suffers from increasing latency as
# number of concurrent requests increases, NixlPersistentEmbedding transmitters
# scale worse than local. Need to investigate why.
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1)) TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
# [gluo NOTE] default off to benchmark standalone encoder
ENABLE_ENCODER_CACHE = int(os.getenv("ENABLE_ENCODER_CACHE", 1))
@dataclass @dataclass
...@@ -54,6 +58,7 @@ class EncodeWorkerHandler: ...@@ -54,6 +58,7 @@ class EncodeWorkerHandler:
self.model, trust_remote_code=True self.model, trust_remote_code=True
) )
self.vision_model = load_vision_model(self.model) self.vision_model = load_vision_model(self.model)
logger.debug(f"embedding hidden dim: {self.vision_model.out_hidden_size}")
self.min_workers = 1 self.min_workers = 1
# Get encoder components for the model # Get encoder components for the model
...@@ -64,14 +69,30 @@ class EncodeWorkerHandler: ...@@ -64,14 +69,30 @@ class EncodeWorkerHandler:
self._accumulated_time = 0.0 self._accumulated_time = 0.0
self._processed_requests = 0 self._processed_requests = 0
self.readables = [] self.readables = []
self.embedding_cache = EmbeddingCache() self.embedding_cache = EmbeddingCache() if ENABLE_ENCODER_CACHE else None
self.embedding_sender = (
LocalEmbeddingSender()
if TRANSFER_LOCAL
else NixlPersistentEmbeddingSender()
)
self.send_complete_queue = asyncio.Queue()
self.send_complete_checker_task = asyncio.create_task(
self.check_complete(self.send_complete_queue)
)
# Use system temp directory for encoder cache files async def check_complete(self, queue):
self._cache_dir = os.path.join(tempfile.gettempdir(), "encoder_cache") while True:
os.makedirs(self._cache_dir, exist_ok=True) transfer_future, embedding = await queue.get()
if transfer_future is None: # Sentinel value to stop the checker
queue.task_done()
break
await transfer_future
queue.task_done()
def cleanup(self): def cleanup(self):
pass self.send_complete_queue.put_nowait(
(None, None)
) # Send sentinel value to stop the checker
async def async_init(self, runtime: DistributedRuntime): async def async_init(self, runtime: DistributedRuntime):
"""Initialize the connector for RDMA transfers""" """Initialize the connector for RDMA transfers"""
...@@ -115,8 +136,10 @@ class EncodeWorkerHandler: ...@@ -115,8 +136,10 @@ class EncodeWorkerHandler:
image_url = request.multimodal_inputs[idx].multimodal_input.image_url image_url = request.multimodal_inputs[idx].multimodal_input.image_url
# see if we have local cache # see if we have local cache
embedding_key = self.embedding_cache.generate_hash_key(image_url) embedding_key = EmbeddingCache.generate_hash_key(image_url)
if self.embedding_cache.has_key(embedding_key): if self.embedding_cache is not None and self.embedding_cache.has_key(
embedding_key
):
(image_grid_thw, embeddings_cpu) = self.embedding_cache.get( (image_grid_thw, embeddings_cpu) = self.embedding_cache.get(
embedding_key embedding_key
) )
...@@ -129,13 +152,15 @@ class EncodeWorkerHandler: ...@@ -129,13 +152,15 @@ class EncodeWorkerHandler:
need_encode_indexes.append((idx, embedding_key)) need_encode_indexes.append((idx, embedding_key))
# Load and generate image tensors # Load and generate image tensors
image_futures = [] image_tasks = []
image_to_load = [] image_to_load = []
for idx, _ in need_encode_indexes: for idx, _ in need_encode_indexes:
url = request.multimodal_inputs[idx].multimodal_input.image_url url = request.multimodal_inputs[idx].multimodal_input.image_url
image_futures.append(self.image_loader.load_image(url)) image_tasks.append(
asyncio.create_task(self.image_loader.load_image(url))
)
image_to_load.append(url) image_to_load.append(url)
results = await asyncio.gather(*image_futures, return_exceptions=True) results = await asyncio.gather(*image_tasks, return_exceptions=True)
loaded_images = [] loaded_images = []
collective_exceptions = "" collective_exceptions = ""
for i, result in enumerate(results): for i, result in enumerate(results):
...@@ -153,8 +178,8 @@ class EncodeWorkerHandler: ...@@ -153,8 +178,8 @@ class EncodeWorkerHandler:
) )
if loaded_images: if loaded_images:
image_embeds = self.image_processor( image_embeds = await asyncio.to_thread(
images=loaded_images, return_tensors="pt" self.image_processor, images=loaded_images, return_tensors="pt"
) )
# Encode the image embeddings using model-specific encoder # Encode the image embeddings using model-specific encoder
...@@ -200,15 +225,35 @@ class EncodeWorkerHandler: ...@@ -200,15 +225,35 @@ class EncodeWorkerHandler:
splitted_embeddings[split_idx].unsqueeze(0), splitted_embeddings[split_idx].unsqueeze(0),
) )
# Cache the computed value for future use # Cache the computed value for future use
self.embedding_cache.set( if self.embedding_cache is not None:
embedding_lists[list_idx].key, self.embedding_cache.set(
( embedding_lists[list_idx].key,
embedding_lists[list_idx].image_grid_thw, (
embedding_lists[list_idx].embeddings_cpu, embedding_lists[list_idx].image_grid_thw,
), embedding_lists[list_idx].embeddings_cpu,
),
)
before_transfer_time = time.perf_counter()
# Prepare transfer
send_tasks = [
asyncio.create_task(
self.embedding_sender.send_embeddings(
embedding_item.embeddings_cpu, stage_embeddings=True
)
) )
for embedding_item in embedding_lists
]
transfer_requests = await asyncio.gather(*send_tasks)
for idx, embedding_item in enumerate(embedding_lists): after_transfer_time = time.perf_counter()
for idx, item in enumerate(zip(embedding_lists, transfer_requests)):
embedding_item, transfer_request = item
logger.debug(
f"{embedding_item.embeddings_cpu.shape} prepared for transfer."
)
# Update request for transfer metadata # Update request for transfer metadata
request.multimodal_inputs[idx].multimodal_input.image_url = None request.multimodal_inputs[idx].multimodal_input.image_url = None
request.multimodal_inputs[ request.multimodal_inputs[
...@@ -217,36 +262,21 @@ class EncodeWorkerHandler: ...@@ -217,36 +262,21 @@ class EncodeWorkerHandler:
request.multimodal_inputs[idx].embeddings_shape = tuple( request.multimodal_inputs[idx].embeddings_shape = tuple(
embedding_item.embeddings_cpu.shape embedding_item.embeddings_cpu.shape
) )
request.multimodal_inputs[idx].serialized_request = transfer_request[0]
# Prepare transfer # Keep a reference of the embedding_cpu and only drop reference when the transfer is done
if TRANSFER_LOCAL: self.send_complete_queue.put_nowait(
logger.debug( (transfer_request[1], embedding_item.embeddings_cpu)
f"ENCODER: saving local safetensors file with key {embedding_item.key}, {embedding_item.embeddings_cpu.numel()} * {embedding_item.embeddings_cpu.element_size()} bytes" )
)
tensors = {"ec_cache": embedding_item.embeddings_cpu}
cache_path = os.path.join(
self._cache_dir, f"{embedding_item.key}.safetensors"
)
safetensors.torch.save_file(tensors, cache_path)
# [gluo FIXME] need mechanism to clean up local files
request.multimodal_inputs[idx].serialized_request = cache_path
else:
descriptor = connect.Descriptor(embedding_item.embeddings_cpu)
assert (
self._connector is not None
), "Connector not initialized; call async_init() first"
self.readables.append(
await self._connector.create_readable(descriptor)
)
request.multimodal_inputs[idx].serialized_request = self.readables[
-1
].metadata()
logger.debug(f"Request: {request.model_dump_json()}") logger.debug(f"Request: {request.model_dump_json()}")
time_end = time.perf_counter() time_end = time.perf_counter()
self._accumulated_time += time_end - time_start self._accumulated_time += time_end - time_start
self._processed_requests += 1 self._processed_requests += 1
logger.debug(
f"received request {{ id: {request_id} }} at time {time_start:.4f}, processed in {time_end - time_start:.4f} seconds, break down: image loading and encoding time {(before_transfer_time - time_start):.4f} seconds, transfer preparation time {(after_transfer_time - before_transfer_time):.4f} seconds, after transfer time {(time_end - after_transfer_time):.4f} seconds."
)
logger.debug( logger.debug(
f"Encoded image(s) for request {{ id: {request_id} }} in {time_end - time_start:.4f} seconds. " f"Encoded image(s) for request {{ id: {request_id} }} in {time_end - time_start:.4f} seconds. "
f"Average encoding time: {self._accumulated_time / self._processed_requests:.4f} seconds over {self._processed_requests} requests." f"Average encoding time: {self._accumulated_time / self._processed_requests:.4f} seconds over {self._processed_requests} requests."
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio
import copy import copy
import logging import logging
import os
import uuid import uuid
from collections import defaultdict from collections import defaultdict
from typing import Any from typing import Any
...@@ -15,6 +17,10 @@ import dynamo.nixl_connect as connect ...@@ -15,6 +17,10 @@ import dynamo.nixl_connect as connect
from dynamo.common.memory.multimodal_embedding_cache_manager import ( from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager, MultimodalEmbeddingCacheManager,
) )
from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver,
NixlPersistentEmbeddingReceiver,
)
from dynamo.runtime import Client, Component, DistributedRuntime from dynamo.runtime import Client, Component, DistributedRuntime
from ..args import Config from ..args import Config
...@@ -35,6 +41,8 @@ from ..multimodal_utils.prefill_worker_utils import ( ...@@ -35,6 +41,8 @@ from ..multimodal_utils.prefill_worker_utils import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
class MultimodalPDWorkerHandler(BaseWorkerHandler): class MultimodalPDWorkerHandler(BaseWorkerHandler):
"""Prefill/Decode or Prefill-only worker for multimodal serving""" """Prefill/Decode or Prefill-only worker for multimodal serving"""
...@@ -93,6 +101,13 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -93,6 +101,13 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self._connector: connect.Connector | None = ( self._connector: connect.Connector | None = (
None # Will be initialized in async_init None # Will be initialized in async_init
) )
# [gluo FIXME] can't use pre-registered tensor as NIXL requires descriptors
# to be at matching size, need to overwrite nixl connect library
self.embedding_receiver = (
LocalEmbeddingReceiver()
if TRANSFER_LOCAL
else NixlPersistentEmbeddingReceiver(max_items=0)
)
logger.info("Multimodal PD Worker has been initialized") logger.info("Multimodal PD Worker has been initialized")
...@@ -168,7 +183,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -168,7 +183,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
async def _load_multimodal_data( async def _load_multimodal_data(
self, request: vLLMMultimodalRequest self, request: vLLMMultimodalRequest
) -> dict[str, Any]: ) -> tuple[dict[str, Any], list[int]]:
"""Load pre-computed embeddings into an engine-ready dict. """Load pre-computed embeddings into an engine-ready dict.
Each ``MultiModalGroup`` carries embeddings from encode workers, Each ``MultiModalGroup`` carries embeddings from encode workers,
...@@ -179,13 +194,21 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -179,13 +194,21 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
multimodal_inputs: list[MultiModalGroup] = request.multimodal_inputs or [] multimodal_inputs: list[MultiModalGroup] = request.multimodal_inputs or []
multi_modal_data: dict[str, Any] = defaultdict(list) multi_modal_data: dict[str, Any] = defaultdict(list)
for mi in multimodal_inputs: task_lists = [
embeddings = await load_embeddings( asyncio.create_task(
mi, load_embeddings(
self.EMBEDDINGS_DTYPE, mi,
self.EMBEDDINGS_DEVICE, self.EMBEDDINGS_DTYPE,
self._connector, self.EMBEDDINGS_DEVICE,
self.embedding_receiver,
)
) )
for mi in multimodal_inputs
]
receiver_tensor_ids: list[int] = []
for task, mi in zip(task_lists, multimodal_inputs):
tensor_id, embeddings = await task
receiver_tensor_ids.append(tensor_id)
accumulate_embeddings( accumulate_embeddings(
multi_modal_data, multi_modal_data,
self.config.model, self.config.model,
...@@ -194,7 +217,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -194,7 +217,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
mi.image_grid_thw, mi.image_grid_thw,
) )
return multi_modal_data return multi_modal_data, receiver_tensor_ids
# ── Request metadata finalization ──────────────────────────────── # ── Request metadata finalization ────────────────────────────────
...@@ -291,6 +314,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -291,6 +314,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self, self,
request: vLLMMultimodalRequest, request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any], multi_modal_data: dict[str, Any],
received_tensor_ids: list[int],
): ):
"""Run prefill and decode on this worker (aggregated mode).""" """Run prefill and decode on this worker (aggregated mode)."""
gen = self.engine_client.generate( gen = self.engine_client.generate(
...@@ -302,6 +326,9 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -302,6 +326,9 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request_id=request.request_id, request_id=request.request_id,
) )
for tensor_id in received_tensor_ids:
self.embedding_receiver.release_tensor(tensor_id)
num_output_tokens_so_far = 0 num_output_tokens_so_far = 0
async for response in gen: async for response in gen:
logger.debug(f"Response kv_transfer_params: {response.kv_transfer_params}") logger.debug(f"Response kv_transfer_params: {response.kv_transfer_params}")
...@@ -318,6 +345,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -318,6 +345,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self, self,
request: vLLMMultimodalRequest, request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any], multi_modal_data: dict[str, Any],
received_tensor_ids: list[int],
): ):
"""Prefill locally, then forward to a remote decode worker.""" """Prefill locally, then forward to a remote decode worker."""
# Prepare prefill-only request # Prepare prefill-only request
...@@ -338,6 +366,9 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -338,6 +366,9 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request_id=prefill_only_request.request_id, request_id=prefill_only_request.request_id,
) )
for tensor_id in received_tensor_ids:
self.embedding_receiver.release_tensor(tensor_id)
# Drain prefill generator (max_tokens=1, expect a single response) # Drain prefill generator (max_tokens=1, expect a single response)
async for prefill_response in gen: async for prefill_response in gen:
pass pass
...@@ -385,7 +416,9 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -385,7 +416,9 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request = await self._parse_request(request) request = await self._parse_request(request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.") logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
multi_modal_data = await self._load_multimodal_data(request) multi_modal_data, received_tensor_ids = await self._load_multimodal_data(
request
)
self._finalize_request_metadata(request, multi_modal_data) self._finalize_request_metadata(request, multi_modal_data)
logger.info( logger.info(
...@@ -394,8 +427,12 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -394,8 +427,12 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
logger.debug(f"{multi_modal_data}") logger.debug(f"{multi_modal_data}")
if self.enable_disagg and self.decode_worker_client: if self.enable_disagg and self.decode_worker_client:
async for chunk in self._generate_disagg(request, multi_modal_data): async for chunk in self._generate_disagg(
request, multi_modal_data, received_tensor_ids
):
yield chunk yield chunk
else: else:
async for chunk in self._generate_agg(request, multi_modal_data): async for chunk in self._generate_agg(
request, multi_modal_data, received_tensor_ids
):
yield chunk yield chunk
...@@ -9,7 +9,8 @@ class EmbeddingCache: ...@@ -9,7 +9,8 @@ class EmbeddingCache:
# Initialize an empty dictionary to store key-value pairs # Initialize an empty dictionary to store key-value pairs
self.cache = {} self.cache = {}
def generate_hash_key(self, *args): @classmethod
def generate_hash_key(cls, *args):
""" """
Generate a hashable key based on the provided arguments. Generate a hashable key based on the provided arguments.
......
...@@ -39,6 +39,7 @@ class SupportedModels: ...@@ -39,6 +39,7 @@ class SupportedModels:
QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct" QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
QWEN_2_5_VL_32B = "Qwen/Qwen2.5-VL-32B-Instruct" QWEN_2_5_VL_32B = "Qwen/Qwen2.5-VL-32B-Instruct"
QWEN_3_VL_30B_A3B_FP8 = "Qwen/Qwen3-VL-30B-A3B-Instruct-FP8" QWEN_3_VL_30B_A3B_FP8 = "Qwen/Qwen3-VL-30B-A3B-Instruct-FP8"
QWEN_3_VL_8B_FP8 = "Qwen/Qwen3-VL-8B-Instruct-FP8"
LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf" LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf"
...@@ -118,6 +119,7 @@ QWEN_VL_MODELS = [ ...@@ -118,6 +119,7 @@ QWEN_VL_MODELS = [
SupportedModels.QWEN_2_5_VL_7B, SupportedModels.QWEN_2_5_VL_7B,
SupportedModels.QWEN_2_5_VL_32B, SupportedModels.QWEN_2_5_VL_32B,
SupportedModels.QWEN_3_VL_30B_A3B_FP8, SupportedModels.QWEN_3_VL_30B_A3B_FP8,
SupportedModels.QWEN_3_VL_8B_FP8,
] ]
...@@ -147,49 +149,18 @@ def load_vision_model(model_id: str) -> torch.nn.Module: ...@@ -147,49 +149,18 @@ def load_vision_model(model_id: str) -> torch.nn.Module:
"VLLM_ENABLE_V1_MULTIPROCESSING": "0", "VLLM_ENABLE_V1_MULTIPROCESSING": "0",
} }
) )
# [NOTE] For vLLM pre-0.15.0, see https://github.com/vllm-project/vllm/pull/32605 for enhancement after 0.15.0
#
# Load only the vision model via vLLM on encoder workers to avoid loading the full LLM weights, significantly reducing memory usage.
# Uses native vLLM encoder only model loading added in https://github.com/vllm-project/vllm/pull/30242.
# Model needs the class method get_language_model_spec to be defined for this to work.
# TODO(gluo/dsocek): Remove this monkey patch once vLLM upstream adds
# get_language_model_spec to Qwen VL model classes.
# Monkey patch to vLLM's Qwen 2 VL and Qwen 2.5 VL classes to add get_language_model_spec
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from vllm.model_executor.models.qwen2_vl import Qwen2VLForConditionalGeneration
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration
@classmethod
def get_language_model_spec(cls):
return (Qwen2ForCausalLM, "language_model")
Qwen2_5_VLForConditionalGeneration.get_language_model_spec = (
get_language_model_spec
)
Qwen2VLForConditionalGeneration.get_language_model_spec = (
get_language_model_spec
)
@classmethod
def get_language_model_spec(cls):
return (Qwen3ForCausalLM, "language_model")
Qwen3VLForConditionalGeneration.get_language_model_spec = (
get_language_model_spec
)
# Load only the vision model via vLLM on encoder workers to avoid loading the full LLM weights, significantly reducing memory usage.
# Uses native vLLM encoder only model loading added in https://github.com/vllm-project/vllm/pull/32605.
# Load only the vision model via vLLM # Load only the vision model via vLLM
vllm_model = LLM( vllm_model = LLM(
model=model_id, model=model_id,
enforce_eager=True, enforce_eager=True,
gpu_memory_utilization=0.4, kv_cache_memory_bytes=1024
max_model_len=10, * 1024
convert="mm_encoder_only", * 8, # 8MB KV cache for vLLM to complete the init lifecycle, encoder-only doesn't require KV cache.
max_model_len=1,
mm_encoder_only=True,
enable_prefix_caching=False, enable_prefix_caching=False,
) )
return ( return (
......
...@@ -5,11 +5,10 @@ import logging ...@@ -5,11 +5,10 @@ import logging
import os import os
from typing import Any, Dict, List from typing import Any, Dict, List
import safetensors
import torch import torch
from vllm.sampling_params import SamplingParams as VllmSamplingParams from vllm.sampling_params import SamplingParams as VllmSamplingParams
import dynamo.nixl_connect as connect from dynamo.common.multimodal.embedding_transfer import AbstractEmbeddingReceiver
from dynamo.runtime import Client from dynamo.runtime import Client
from .model import construct_mm_data from .model import construct_mm_data
...@@ -25,15 +24,18 @@ logger = logging.getLogger(__name__) ...@@ -25,15 +24,18 @@ logger = logging.getLogger(__name__)
IMAGE_URL_KEY = "image_url" IMAGE_URL_KEY = "image_url"
VIDEO_URL_KEY = "video_url" VIDEO_URL_KEY = "video_url"
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1)) # Whether to split the multimodal items into smaller batches for encoding. This can help if multimodal items can be speed up
# by separately encodeded with multiple workers.
# Need to experiment with this setting to see if it brings benefits when concurrency > encoder count.
SPLIT_ENCODE = int(os.getenv("SPLIT_ENCODE", 1))
async def load_embeddings( async def load_embeddings(
mi: MultiModalGroup, mi: MultiModalGroup,
embeddings_dtype: torch.dtype, _embeddings_dtype: torch.dtype,
embeddings_device: str, _embeddings_device: str,
connector: connect.Connector | None, receiver: AbstractEmbeddingReceiver,
) -> torch.Tensor: ) -> tuple[int, torch.Tensor]:
"""Load pre-computed embedding tensor via local safetensors or NIXL RDMA. """Load pre-computed embedding tensor via local safetensors or NIXL RDMA.
Args: Args:
...@@ -41,30 +43,14 @@ async def load_embeddings( ...@@ -41,30 +43,14 @@ async def load_embeddings(
contains either a local file path or NIXL RDMA metadata. contains either a local file path or NIXL RDMA metadata.
embeddings_dtype: Torch dtype for the tensor (used for RDMA path). embeddings_dtype: Torch dtype for the tensor (used for RDMA path).
embeddings_device: Device string for the tensor (used for RDMA path). embeddings_device: Device string for the tensor (used for RDMA path).
connector: NIXL Connector for RDMA reads (required when TRANSFER_LOCAL=0). receiver: AbstractEmbeddingReceiver for tensor reads.
Returns: Returns:
The embedding tensor loaded into CPU memory. A tuple of (tensor_id, embeddings), where tensor_id is an integer identifier for the loaded tensor (used for later release),
and the embeddings tensor loaded into CPU memory.
""" """
if TRANSFER_LOCAL: tensor_id, embeddings = await receiver.receive_embeddings(mi.serialized_request)
logger.info("PD: Loading local safetensors file") return tensor_id, embeddings
return safetensors.torch.load_file(mi.serialized_request)["ec_cache"]
embeddings = torch.empty(
mi.embeddings_shape,
dtype=embeddings_dtype,
device=embeddings_device,
)
descriptor = connect.Descriptor(embeddings)
if descriptor is None:
raise RuntimeError(
"Descriptor is None in PD worker - cannot process embeddings"
)
read_op = await connector.begin_read(mi.serialized_request, descriptor)
await read_op.wait_for_completion()
return embeddings
def accumulate_embeddings( def accumulate_embeddings(
...@@ -142,7 +128,11 @@ async def fetch_embeddings_from_encode_workers( ...@@ -142,7 +128,11 @@ async def fetch_embeddings_from_encode_workers(
if encode_worker_count == 0: if encode_worker_count == 0:
raise RuntimeError("No encode workers available to process multimodal input") raise RuntimeError("No encode workers available to process multimodal input")
encode_batch_size = max(1, len(image_urls) // encode_worker_count) encode_batch_size = (
max(1, len(image_urls) // encode_worker_count)
if SPLIT_ENCODE
else len(image_urls)
)
encode_request = vLLMMultimodalRequest( encode_request = vLLMMultimodalRequest(
engine_prompt=PatchedTokensPrompt(prompt_token_ids=[]), engine_prompt=PatchedTokensPrompt(prompt_token_ids=[]),
......
...@@ -28,7 +28,7 @@ from vllm.outputs import CompletionOutput ...@@ -28,7 +28,7 @@ from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import RequestStateStats from vllm.v1.metrics.stats import RequestStateStats
import dynamo.nixl_connect as connect from dynamo.common.multimodal.embedding_transfer import TransferRequest
class Request(BaseModel): class Request(BaseModel):
...@@ -170,7 +170,7 @@ class MultiModalGroup(BaseModel): ...@@ -170,7 +170,7 @@ class MultiModalGroup(BaseModel):
embeddings_shape: Optional[ embeddings_shape: Optional[
Union[Tuple[int, int, int], Tuple[int, int, int, int]] Union[Tuple[int, int, int], Tuple[int, int, int, int]]
] = None ] = None
serialized_request: Optional[connect.RdmaMetadata | str] = None serialized_request: Optional[TransferRequest] = None
class vLLMMultimodalRequest(vLLMGenerateRequest): class vLLMMultimodalRequest(vLLMGenerateRequest):
......
...@@ -172,7 +172,7 @@ class TestGenerateAgg: ...@@ -172,7 +172,7 @@ class TestGenerateAgg:
handler.engine_client.generate = fake_generate handler.engine_client.generate = fake_generate
chunks = [] chunks = []
async for chunk in handler._generate_agg(request, {"image": []}): async for chunk in handler._generate_agg(request, {"image": []}, []):
chunks.append(chunk) chunks.append(chunk)
assert len(chunks) == 1 assert len(chunks) == 1
...@@ -220,7 +220,7 @@ class TestGenerateDisagg: ...@@ -220,7 +220,7 @@ class TestGenerateDisagg:
request = _make_vllm_request() request = _make_vllm_request()
chunks = [] chunks = []
async for chunk in handler._generate_disagg(request, {"image": []}): async for chunk in handler._generate_disagg(request, {"image": []}, []):
chunks.append(chunk) chunks.append(chunk)
assert len(chunks) == 1 assert len(chunks) == 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment