Unverified Commit a2a6917f authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: use embedding transfer classes for EPD (#6223)


Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
parent 1549c338
......@@ -2,8 +2,9 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
MODEL_NAME="Qwen/Qwen2.5-VL-7B-Instruct"
MODEL_NAME="Qwen/Qwen3-VL-30B-A3B-Instruct-FP8"
CONCURRENCY=1
OSL=150
# 500 * 333 pixels image
IMG_URL="http://images.cocodataset.org/test2017/000000000183.jpg"
......@@ -16,7 +17,12 @@ DUMMY_PROMPT="This is a prompt to describe the image content briefly."
for i in {1..1500}; do
DUMMY_PROMPT+=" This is a prompt to describe the image content briefly."
done
echo '{"texts": ["'"$DUMMY_PROMPT"'"], "images": ["'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'","'"$IMG_URL"'"]}' \
IMAGE_LIST='"'"$IMG_URL"'"'
for i in {2..30}; do
IMAGE_LIST+=',"'$IMG_URL'"'
done
echo '{"texts": ["'"$DUMMY_PROMPT"'"], "images": ['"$IMAGE_LIST"']}' \
> data_small.jsonl
echo "This benchmark uses duplicate image urls, so any kind of caching can significantly affect the benchmark results, please make sure the caching setting is properly configured for your experiment."
......@@ -31,11 +37,16 @@ while [[ $# -gt 0 ]]; do
CONCURRENCY=$2
shift 2
;;
--osl)
OSL=$2
shift 2
;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " --model <model_name> Specify the model to use (default: $MODEL_NAME)"
echo " --concurrency <level> Specify the concurrency level to use (default: $CONCURRENCY)"
echo " --osl <level> Specify the OSL to use (default: $OSL)"
echo " -h, --help Show this help message"
exit 0
;;
......@@ -49,7 +60,8 @@ done
aiperf profile -m $MODEL_NAME --endpoint-type chat \
--streaming --request-count 100 --warmup-request-count 5 \
--concurrency $CONCURRENCY --osl 1 \
--concurrency $CONCURRENCY --osl $OSL \
--input-file data_small.jsonl \
--custom-dataset-type single_turn --ui none \
--extra-inputs 'ignore_eos:true' \
--no-server-metrics
......@@ -4,5 +4,19 @@
"""Multimodal utilities for Dynamo components."""
from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache
from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver,
LocalEmbeddingSender,
NixlPersistentEmbeddingReceiver,
NixlPersistentEmbeddingSender,
TransferRequest,
)
__all__ = ["AsyncEncoderCache"]
__all__ = [
"AsyncEncoderCache",
"NixlPersistentEmbeddingReceiver",
"NixlPersistentEmbeddingSender",
"TransferRequest",
"LocalEmbeddingReceiver",
"LocalEmbeddingSender",
]
......@@ -114,6 +114,29 @@ class LocalEmbeddingSender(AbstractEmbeddingSender):
self.sender_id = uuid.uuid4().hex
self.embedding_counter = 0
def save_embeddings_to_file(
    self, embedding_key: str, embeddings: torch.Tensor
) -> str:
    """
    Save the embeddings to a local safetensors file and return the file path.

    The file is created with ``tempfile.mkstemp`` so its name is unique, and the
    tensor is stored under the ``"ec_cache"`` key after being moved to CPU.
    If moving or serializing the tensor fails, the freshly created file is
    removed before the exception is re-raised so no empty or partial file is
    left behind.

    Args:
        embedding_key: A unique key for the embeddings, embedded in the file name.
        embeddings: A torch.Tensor of the embeddings to save.

    Returns:
        The file path where the embeddings are saved.
    """
    fd, tensor_path = tempfile.mkstemp(
        prefix=f"encoder_cache.{embedding_key}.", suffix=".safetensors"
    )
    # mkstemp hands back an open descriptor; save_file() is given the path,
    # so the descriptor is closed right away to avoid leaking it.
    os.close(fd)
    try:
        safetensors_torch.save_file({"ec_cache": embeddings.cpu()}, tensor_path)
    except BaseException:
        # Best-effort removal of the orphaned temp file; the original error
        # is what the caller needs to see, so removal failures are ignored.
        try:
            os.remove(tensor_path)
        except OSError:
            pass
        raise
    return tensor_path
async def send_embeddings(
self, embeddings: torch.Tensor, stage_embeddings: bool = False
) -> tuple[TransferRequest, asyncio.Future]:
......@@ -131,15 +154,10 @@ class LocalEmbeddingSender(AbstractEmbeddingSender):
# This could involve publishing to a message queue or making an API call
embedding_key = f"{self.sender_id}_{self.embedding_counter}"
self.embedding_counter += 1
tensor_path = f"/tmp/encoder_cache.{embedding_key}.safetensors"
fd, tensor_path = tempfile.mkstemp(
prefix=f"encoder_cache.{embedding_key}.", suffix=".safetensors"
)
os.close(fd)
tensors = {"ec_cache": embeddings.cpu()}
safetensors_torch.save_file(
tensors,
tensor_path,
tensor_path = await asyncio.to_thread(
self.save_embeddings_to_file,
embedding_key,
embeddings,
)
fut = asyncio.get_event_loop().create_future()
fut.set_result(None)
......@@ -177,7 +195,7 @@ class LocalEmbeddingReceiver(AbstractEmbeddingReceiver):
Caller should invoke release_tensor(tensor_id) when the tensor is no longer needed to free up resources.
"""
tensor_path = request.serialized_request
tensors = safetensors_torch.load_file(tensor_path)
tensors = await asyncio.to_thread(safetensors_torch.load_file, tensor_path)
embedding_tensor = tensors["ec_cache"]
tensor_id = self.tensor_id_counter
self.tensor_id_counter += 1
......@@ -339,10 +357,17 @@ class NixlPersistentEmbeddingSender(AbstractEmbeddingSender):
embeddings: A torch.Tensor of the embeddings to send.
stage_embeddings: A boolean indicating whether the embeddings should be staged for the transfer,
if True, the embeddings may be used as transfer buffer and must not be released until the return future is completed.
if False, the sender will copy the embeddings.
Returns:
A tuple containing the TransferRequest object and a future that can be awaited to indicate the send is completed.
"""
descriptor = nixl_connect.Descriptor(embeddings.cpu())
# If the embeddings are not staged and are already on CPU, copy them explicitly:
# torch.Tensor.cpu() returns the original tensor when it is already on CPU.
if not stage_embeddings and not embeddings.is_cuda:
embeddings_cpu = embeddings.clone().detach()
else:
embeddings_cpu = embeddings.cpu()
descriptor = nixl_connect.Descriptor(embeddings_cpu)
readable_op = await self.connector.create_readable(descriptor)
request = TransferRequest(
......@@ -418,7 +443,7 @@ class NixlPersistentEmbeddingReceiver(AbstractEmbeddingReceiver):
)
if self.warmedup_descriptors.empty():
logger.warning(
logger.debug(
"No warmed up descriptors available, creating a temporary one for transfer."
)
encodings_tensor = torch.zeros(*embeddings_shape, dtype=embeddings_dtype)
......
......@@ -25,9 +25,17 @@ logger = logging.getLogger(__name__)
async def benchmark(sender, receiver, tensors=None):
if tensors is None:
tensors = [torch.randn(256, 8 * 1024) for _ in range(30)]
# warmup
request, send_future = await sender.send_embeddings(tensors[0])
tensor_id, response = await receiver.receive_embeddings(request)
receiver.release_tensor(tensor_id)
await send_future
# benchmark
send_start = time.perf_counter()
sender_tasks = [
asyncio.create_task(sender.send_embeddings(tensor)) for tensor in tensors
asyncio.create_task(sender.send_embeddings(tensor, stage_embeddings=True))
for tensor in tensors
]
requests = await asyncio.gather(*sender_tasks)
send_end = time.perf_counter()
......
......@@ -4,17 +4,16 @@
import asyncio
import logging
import os
import tempfile
import time
from dataclasses import dataclass
from typing import AsyncIterator
import safetensors
import torch
from transformers import AutoImageProcessor
from vllm.engine.arg_utils import AsyncEngineArgs
import dynamo.nixl_connect as connect
from dynamo.common.multimodal import LocalEmbeddingSender, NixlPersistentEmbeddingSender
from dynamo.runtime import DistributedRuntime
from ..multimodal_utils import (
......@@ -31,7 +30,12 @@ logger = logging.getLogger(__name__)
CACHE_SIZE_MAXIMUM = 8
# Both embedding transmitters suffer from increasing latency as the
# number of concurrent requests increases; the NixlPersistentEmbedding transmitters
# scale worse than the local ones. Need to investigate why.
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
# [gluo NOTE] set ENABLE_ENCODER_CACHE=0 to benchmark the standalone encoder (the default below is 1, i.e. cache enabled)
ENABLE_ENCODER_CACHE = int(os.getenv("ENABLE_ENCODER_CACHE", 1))
@dataclass
......@@ -54,6 +58,7 @@ class EncodeWorkerHandler:
self.model, trust_remote_code=True
)
self.vision_model = load_vision_model(self.model)
logger.debug(f"embedding hidden dim: {self.vision_model.out_hidden_size}")
self.min_workers = 1
# Get encoder components for the model
......@@ -64,14 +69,30 @@ class EncodeWorkerHandler:
self._accumulated_time = 0.0
self._processed_requests = 0
self.readables = []
self.embedding_cache = EmbeddingCache()
self.embedding_cache = EmbeddingCache() if ENABLE_ENCODER_CACHE else None
self.embedding_sender = (
LocalEmbeddingSender()
if TRANSFER_LOCAL
else NixlPersistentEmbeddingSender()
)
self.send_complete_queue = asyncio.Queue()
self.send_complete_checker_task = asyncio.create_task(
self.check_complete(self.send_complete_queue)
)
# Use system temp directory for encoder cache files
self._cache_dir = os.path.join(tempfile.gettempdir(), "encoder_cache")
os.makedirs(self._cache_dir, exist_ok=True)
async def check_complete(self, queue):
    """Drain the send-completion queue until a sentinel entry is seen.

    Each queue entry is a ``(transfer_future, embedding)`` pair. The future is
    awaited so the entry is only retired once the transfer has finished; the
    accompanying tensor is carried in the entry so it stays referenced until
    then. An entry whose future is ``None`` is the stop sentinel.

    A failing transfer is logged and skipped rather than allowed to terminate
    this loop; otherwise one bad transfer would stop every later entry from
    being drained. ``queue.task_done()`` is called exactly once per entry,
    including for failed transfers and for the sentinel.

    Args:
        queue: An ``asyncio.Queue`` of ``(transfer_future, embedding)`` pairs.
    """
    # Same logger object as the module-level one (both are keyed by __name__).
    log = logging.getLogger(__name__)
    while True:
        transfer_future, embedding = await queue.get()
        try:
            if transfer_future is None:  # Sentinel value to stop the checker
                break
            await transfer_future
        except Exception:
            log.exception(
                "Embedding transfer failed; continuing to drain the completion queue."
            )
        finally:
            queue.task_done()
            # Drop the references now instead of holding the last tensor
            # alive while blocked on the next queue.get().
            del transfer_future, embedding
def cleanup(self):
    """Ask the send-completion checker to stop by enqueuing its sentinel entry."""
    # A (None, None) pair is the sentinel value that stops the checker.
    stop_entry = (None, None)
    self.send_complete_queue.put_nowait(stop_entry)
async def async_init(self, runtime: DistributedRuntime):
"""Initialize the connector for RDMA transfers"""
......@@ -115,8 +136,10 @@ class EncodeWorkerHandler:
image_url = request.multimodal_inputs[idx].multimodal_input.image_url
# see if we have local cache
embedding_key = self.embedding_cache.generate_hash_key(image_url)
if self.embedding_cache.has_key(embedding_key):
embedding_key = EmbeddingCache.generate_hash_key(image_url)
if self.embedding_cache is not None and self.embedding_cache.has_key(
embedding_key
):
(image_grid_thw, embeddings_cpu) = self.embedding_cache.get(
embedding_key
)
......@@ -129,13 +152,15 @@ class EncodeWorkerHandler:
need_encode_indexes.append((idx, embedding_key))
# Load and generate image tensors
image_futures = []
image_tasks = []
image_to_load = []
for idx, _ in need_encode_indexes:
url = request.multimodal_inputs[idx].multimodal_input.image_url
image_futures.append(self.image_loader.load_image(url))
image_tasks.append(
asyncio.create_task(self.image_loader.load_image(url))
)
image_to_load.append(url)
results = await asyncio.gather(*image_futures, return_exceptions=True)
results = await asyncio.gather(*image_tasks, return_exceptions=True)
loaded_images = []
collective_exceptions = ""
for i, result in enumerate(results):
......@@ -153,8 +178,8 @@ class EncodeWorkerHandler:
)
if loaded_images:
image_embeds = self.image_processor(
images=loaded_images, return_tensors="pt"
image_embeds = await asyncio.to_thread(
self.image_processor, images=loaded_images, return_tensors="pt"
)
# Encode the image embeddings using model-specific encoder
......@@ -200,15 +225,35 @@ class EncodeWorkerHandler:
splitted_embeddings[split_idx].unsqueeze(0),
)
# Cache the computed value for future use
self.embedding_cache.set(
embedding_lists[list_idx].key,
(
embedding_lists[list_idx].image_grid_thw,
embedding_lists[list_idx].embeddings_cpu,
),
if self.embedding_cache is not None:
self.embedding_cache.set(
embedding_lists[list_idx].key,
(
embedding_lists[list_idx].image_grid_thw,
embedding_lists[list_idx].embeddings_cpu,
),
)
before_transfer_time = time.perf_counter()
# Prepare transfer
send_tasks = [
asyncio.create_task(
self.embedding_sender.send_embeddings(
embedding_item.embeddings_cpu, stage_embeddings=True
)
)
for embedding_item in embedding_lists
]
transfer_requests = await asyncio.gather(*send_tasks)
for idx, embedding_item in enumerate(embedding_lists):
after_transfer_time = time.perf_counter()
for idx, item in enumerate(zip(embedding_lists, transfer_requests)):
embedding_item, transfer_request = item
logger.debug(
f"{embedding_item.embeddings_cpu.shape} prepared for transfer."
)
# Update request for transfer metadata
request.multimodal_inputs[idx].multimodal_input.image_url = None
request.multimodal_inputs[
......@@ -217,36 +262,21 @@ class EncodeWorkerHandler:
request.multimodal_inputs[idx].embeddings_shape = tuple(
embedding_item.embeddings_cpu.shape
)
request.multimodal_inputs[idx].serialized_request = transfer_request[0]
# Prepare transfer
if TRANSFER_LOCAL:
logger.debug(
f"ENCODER: saving local safetensors file with key {embedding_item.key}, {embedding_item.embeddings_cpu.numel()} * {embedding_item.embeddings_cpu.element_size()} bytes"
)
tensors = {"ec_cache": embedding_item.embeddings_cpu}
cache_path = os.path.join(
self._cache_dir, f"{embedding_item.key}.safetensors"
)
safetensors.torch.save_file(tensors, cache_path)
# [gluo FIXME] need mechanism to clean up local files
request.multimodal_inputs[idx].serialized_request = cache_path
else:
descriptor = connect.Descriptor(embedding_item.embeddings_cpu)
assert (
self._connector is not None
), "Connector not initialized; call async_init() first"
self.readables.append(
await self._connector.create_readable(descriptor)
)
request.multimodal_inputs[idx].serialized_request = self.readables[
-1
].metadata()
# Keep a reference of the embedding_cpu and only drop reference when the transfer is done
self.send_complete_queue.put_nowait(
(transfer_request[1], embedding_item.embeddings_cpu)
)
logger.debug(f"Request: {request.model_dump_json()}")
time_end = time.perf_counter()
self._accumulated_time += time_end - time_start
self._processed_requests += 1
logger.debug(
f"received request {{ id: {request_id} }} at time {time_start:.4f}, processed in {time_end - time_start:.4f} seconds, break down: image loading and encoding time {(before_transfer_time - time_start):.4f} seconds, transfer preparation time {(after_transfer_time - before_transfer_time):.4f} seconds, after transfer time {(time_end - after_transfer_time):.4f} seconds."
)
logger.debug(
f"Encoded image(s) for request {{ id: {request_id} }} in {time_end - time_start:.4f} seconds. "
f"Average encoding time: {self._accumulated_time / self._processed_requests:.4f} seconds over {self._processed_requests} requests."
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import copy
import logging
import os
import uuid
from collections import defaultdict
from typing import Any
......@@ -15,6 +17,10 @@ import dynamo.nixl_connect as connect
from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver,
NixlPersistentEmbeddingReceiver,
)
from dynamo.runtime import Client, Component, DistributedRuntime
from ..args import Config
......@@ -35,6 +41,8 @@ from ..multimodal_utils.prefill_worker_utils import (
logger = logging.getLogger(__name__)
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
class MultimodalPDWorkerHandler(BaseWorkerHandler):
"""Prefill/Decode or Prefill-only worker for multimodal serving"""
......@@ -93,6 +101,13 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self._connector: connect.Connector | None = (
None # Will be initialized in async_init
)
# [gluo FIXME] can't use pre-registered tensor as NIXL requires descriptors
# to be at matching size, need to overwrite nixl connect library
self.embedding_receiver = (
LocalEmbeddingReceiver()
if TRANSFER_LOCAL
else NixlPersistentEmbeddingReceiver(max_items=0)
)
logger.info("Multimodal PD Worker has been initialized")
......@@ -168,7 +183,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
async def _load_multimodal_data(
self, request: vLLMMultimodalRequest
) -> dict[str, Any]:
) -> tuple[dict[str, Any], list[int]]:
"""Load pre-computed embeddings into an engine-ready dict.
Each ``MultiModalGroup`` carries embeddings from encode workers,
......@@ -179,13 +194,21 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
multimodal_inputs: list[MultiModalGroup] = request.multimodal_inputs or []
multi_modal_data: dict[str, Any] = defaultdict(list)
for mi in multimodal_inputs:
embeddings = await load_embeddings(
mi,
self.EMBEDDINGS_DTYPE,
self.EMBEDDINGS_DEVICE,
self._connector,
task_lists = [
asyncio.create_task(
load_embeddings(
mi,
self.EMBEDDINGS_DTYPE,
self.EMBEDDINGS_DEVICE,
self.embedding_receiver,
)
)
for mi in multimodal_inputs
]
receiver_tensor_ids: list[int] = []
for task, mi in zip(task_lists, multimodal_inputs):
tensor_id, embeddings = await task
receiver_tensor_ids.append(tensor_id)
accumulate_embeddings(
multi_modal_data,
self.config.model,
......@@ -194,7 +217,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
mi.image_grid_thw,
)
return multi_modal_data
return multi_modal_data, receiver_tensor_ids
# ── Request metadata finalization ────────────────────────────────
......@@ -291,6 +314,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self,
request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any],
received_tensor_ids: list[int],
):
"""Run prefill and decode on this worker (aggregated mode)."""
gen = self.engine_client.generate(
......@@ -302,6 +326,9 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request_id=request.request_id,
)
for tensor_id in received_tensor_ids:
self.embedding_receiver.release_tensor(tensor_id)
num_output_tokens_so_far = 0
async for response in gen:
logger.debug(f"Response kv_transfer_params: {response.kv_transfer_params}")
......@@ -318,6 +345,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self,
request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any],
received_tensor_ids: list[int],
):
"""Prefill locally, then forward to a remote decode worker."""
# Prepare prefill-only request
......@@ -338,6 +366,9 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request_id=prefill_only_request.request_id,
)
for tensor_id in received_tensor_ids:
self.embedding_receiver.release_tensor(tensor_id)
# Drain prefill generator (max_tokens=1, expect a single response)
async for prefill_response in gen:
pass
......@@ -385,7 +416,9 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request = await self._parse_request(request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
multi_modal_data = await self._load_multimodal_data(request)
multi_modal_data, received_tensor_ids = await self._load_multimodal_data(
request
)
self._finalize_request_metadata(request, multi_modal_data)
logger.info(
......@@ -394,8 +427,12 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
logger.debug(f"{multi_modal_data}")
if self.enable_disagg and self.decode_worker_client:
async for chunk in self._generate_disagg(request, multi_modal_data):
async for chunk in self._generate_disagg(
request, multi_modal_data, received_tensor_ids
):
yield chunk
else:
async for chunk in self._generate_agg(request, multi_modal_data):
async for chunk in self._generate_agg(
request, multi_modal_data, received_tensor_ids
):
yield chunk
......@@ -9,7 +9,8 @@ class EmbeddingCache:
# Initialize an empty dictionary to store key-value pairs
self.cache = {}
def generate_hash_key(self, *args):
@classmethod
def generate_hash_key(cls, *args):
"""
Generate a hashable key based on the provided arguments.
......
......@@ -39,6 +39,7 @@ class SupportedModels:
QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
QWEN_2_5_VL_32B = "Qwen/Qwen2.5-VL-32B-Instruct"
QWEN_3_VL_30B_A3B_FP8 = "Qwen/Qwen3-VL-30B-A3B-Instruct-FP8"
QWEN_3_VL_8B_FP8 = "Qwen/Qwen3-VL-8B-Instruct-FP8"
LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf"
......@@ -118,6 +119,7 @@ QWEN_VL_MODELS = [
SupportedModels.QWEN_2_5_VL_7B,
SupportedModels.QWEN_2_5_VL_32B,
SupportedModels.QWEN_3_VL_30B_A3B_FP8,
SupportedModels.QWEN_3_VL_8B_FP8,
]
......@@ -147,49 +149,18 @@ def load_vision_model(model_id: str) -> torch.nn.Module:
"VLLM_ENABLE_V1_MULTIPROCESSING": "0",
}
)
# [NOTE] For vLLM pre-0.15.0, see https://github.com/vllm-project/vllm/pull/32605 for enhancement after 0.15.0
#
# Load only the vision model via vLLM on encoder workers to avoid loading the full LLM weights, significantly reducing memory usage.
# Uses native vLLM encoder only model loading added in https://github.com/vllm-project/vllm/pull/30242.
# Model needs the class method get_language_model_spec to be defined for this to work.
# TODO(gluo/dsocek): Remove this monkey patch once vLLM upstream adds
# get_language_model_spec to Qwen VL model classes.
# Monkey patch to vLLM's Qwen 2 VL and Qwen 2.5 VL classes to add get_language_model_spec
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from vllm.model_executor.models.qwen2_vl import Qwen2VLForConditionalGeneration
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration
@classmethod
def get_language_model_spec(cls):
return (Qwen2ForCausalLM, "language_model")
Qwen2_5_VLForConditionalGeneration.get_language_model_spec = (
get_language_model_spec
)
Qwen2VLForConditionalGeneration.get_language_model_spec = (
get_language_model_spec
)
@classmethod
def get_language_model_spec(cls):
return (Qwen3ForCausalLM, "language_model")
Qwen3VLForConditionalGeneration.get_language_model_spec = (
get_language_model_spec
)
# Load only the vision model via vLLM on encoder workers to avoid loading the full LLM weights, significantly reducing memory usage.
# Uses native vLLM encoder only model loading added in https://github.com/vllm-project/vllm/pull/32605.
# Load only the vision model via vLLM
vllm_model = LLM(
model=model_id,
enforce_eager=True,
gpu_memory_utilization=0.4,
max_model_len=10,
convert="mm_encoder_only",
kv_cache_memory_bytes=1024
* 1024
* 8, # 8MB KV cache for vLLM to complete the init lifecycle, encoder-only doesn't require KV cache.
max_model_len=1,
mm_encoder_only=True,
enable_prefix_caching=False,
)
return (
......
......@@ -5,11 +5,10 @@ import logging
import os
from typing import Any, Dict, List
import safetensors
import torch
from vllm.sampling_params import SamplingParams as VllmSamplingParams
import dynamo.nixl_connect as connect
from dynamo.common.multimodal.embedding_transfer import AbstractEmbeddingReceiver
from dynamo.runtime import Client
from .model import construct_mm_data
......@@ -25,15 +24,18 @@ logger = logging.getLogger(__name__)
IMAGE_URL_KEY = "image_url"
VIDEO_URL_KEY = "video_url"
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
# Whether to split the multimodal items into smaller batches for encoding. This can help if the multimodal items can be sped up
# by being encoded separately across multiple workers.
# Need to experiment with this setting to see if it brings benefits when concurrency > encoder count.
SPLIT_ENCODE = int(os.getenv("SPLIT_ENCODE", 1))
async def load_embeddings(
mi: MultiModalGroup,
embeddings_dtype: torch.dtype,
embeddings_device: str,
connector: connect.Connector | None,
) -> torch.Tensor:
_embeddings_dtype: torch.dtype,
_embeddings_device: str,
receiver: AbstractEmbeddingReceiver,
) -> tuple[int, torch.Tensor]:
"""Load pre-computed embedding tensor via local safetensors or NIXL RDMA.
Args:
......@@ -41,30 +43,14 @@ async def load_embeddings(
contains either a local file path or NIXL RDMA metadata.
embeddings_dtype: Torch dtype for the tensor (used for RDMA path).
embeddings_device: Device string for the tensor (used for RDMA path).
connector: NIXL Connector for RDMA reads (required when TRANSFER_LOCAL=0).
receiver: AbstractEmbeddingReceiver for tensor reads.
Returns:
The embedding tensor loaded into CPU memory.
A tuple of (tensor_id, embeddings), where tensor_id is an integer identifier for the loaded tensor (used for later release),
and the embeddings tensor loaded into CPU memory.
"""
if TRANSFER_LOCAL:
logger.info("PD: Loading local safetensors file")
return safetensors.torch.load_file(mi.serialized_request)["ec_cache"]
embeddings = torch.empty(
mi.embeddings_shape,
dtype=embeddings_dtype,
device=embeddings_device,
)
descriptor = connect.Descriptor(embeddings)
if descriptor is None:
raise RuntimeError(
"Descriptor is None in PD worker - cannot process embeddings"
)
read_op = await connector.begin_read(mi.serialized_request, descriptor)
await read_op.wait_for_completion()
return embeddings
tensor_id, embeddings = await receiver.receive_embeddings(mi.serialized_request)
return tensor_id, embeddings
def accumulate_embeddings(
......@@ -142,7 +128,11 @@ async def fetch_embeddings_from_encode_workers(
if encode_worker_count == 0:
raise RuntimeError("No encode workers available to process multimodal input")
encode_batch_size = max(1, len(image_urls) // encode_worker_count)
encode_batch_size = (
max(1, len(image_urls) // encode_worker_count)
if SPLIT_ENCODE
else len(image_urls)
)
encode_request = vLLMMultimodalRequest(
engine_prompt=PatchedTokensPrompt(prompt_token_ids=[]),
......
......@@ -28,7 +28,7 @@ from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import RequestStateStats
import dynamo.nixl_connect as connect
from dynamo.common.multimodal.embedding_transfer import TransferRequest
class Request(BaseModel):
......@@ -170,7 +170,7 @@ class MultiModalGroup(BaseModel):
embeddings_shape: Optional[
Union[Tuple[int, int, int], Tuple[int, int, int, int]]
] = None
serialized_request: Optional[connect.RdmaMetadata | str] = None
serialized_request: Optional[TransferRequest] = None
class vLLMMultimodalRequest(vLLMGenerateRequest):
......
......@@ -172,7 +172,7 @@ class TestGenerateAgg:
handler.engine_client.generate = fake_generate
chunks = []
async for chunk in handler._generate_agg(request, {"image": []}):
async for chunk in handler._generate_agg(request, {"image": []}, []):
chunks.append(chunk)
assert len(chunks) == 1
......@@ -220,7 +220,7 @@ class TestGenerateDisagg:
request = _make_vllm_request()
chunks = []
async for chunk in handler._generate_disagg(request, {"image": []}):
async for chunk in handler._generate_disagg(request, {"image": []}, []):
chunks.append(chunk)
assert len(chunks) == 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment