Unverified Commit 2cab0f7f authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

fix(perf): add embedding transfer implementation with NIXL WRITE initiation (#6651)


Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
parent da98f6a0
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from pydantic import BaseModel
from dynamo.common.constants import EmbeddingTransferMode
from dynamo.common.multimodal.embedding_transfer import TransferRequest
class TransferConfig(BaseModel):
    """Benchmark settings applied to the sender and receiver workers."""

    # When True the sender uses its CUDA tensor as the payload, otherwise the
    # host tensor.
    use_gpu: bool = False
    # How many tensors the sender stages for each benchmark request.
    tensor_count_per_request: int = 30
    # Transfer implementation to exercise:
    #   EmbeddingTransferMode.LOCAL      - local file implementation
    #   EmbeddingTransferMode.NIXL_WRITE - NIXL writer as initiator
    #                                      (direct NIXL API calls)
    #   EmbeddingTransferMode.NIXL_READ  - NIXL reader as initiator
    #                                      (nixl_connect)
    transfer_type: EmbeddingTransferMode = EmbeddingTransferMode.LOCAL
class BatchTransferRequest(BaseModel):
    """A group of transfer requests exchanged as a single message."""

    # One TransferRequest per tensor staged by the sender.
    requests: list[TransferRequest]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import uvloop
from protocol import BatchTransferRequest, EmbeddingTransferMode, TransferConfig
from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver,
NixlReadEmbeddingReceiver,
NixlWriteEmbeddingReceiver,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
# Module-level logger for the receiver worker; Dynamo logging is configured
# as a side effect of importing this module.
logger = logging.getLogger(__name__)
configure_dynamo_logging()
class Receiver:
    """Receiver side of the embedding-transfer benchmark.

    One receiver object is created per transfer mode; the active one is picked
    from the most recently applied ``TransferConfig``.
    """

    def __init__(self, runtime: DistributedRuntime):
        self.runtime = runtime
        self.local_receiver = LocalEmbeddingReceiver()
        self.write_receiver = NixlWriteEmbeddingReceiver(2 * 8 * 1024 * 256 * 1024 * 3)
        self.read_receiver = NixlReadEmbeddingReceiver(
            embedding_hidden_size=8 * 1024, max_item_mm_token=1024
        )
        self.config = TransferConfig()

    def get_run_config(self):
        """Return the receiver matching ``self.config.transfer_type``.

        Raises:
            ValueError: If the configured transfer type is not recognised.
        """
        # Select the variant of sender/receiver based on config
        if self.config.transfer_type == EmbeddingTransferMode.LOCAL:
            receiver = self.local_receiver
        elif self.config.transfer_type == EmbeddingTransferMode.NIXL_WRITE:
            receiver = self.write_receiver
        elif self.config.transfer_type == EmbeddingTransferMode.NIXL_READ:
            receiver = self.read_receiver
        else:
            raise ValueError(f"Invalid transfer type: {self.config.transfer_type}")
        # other fields in self.config are sender-side config, receiver only
        # relies on BatchTransferRequest for completing the transfer.
        return receiver

    async def async_init(self):
        """Create the client used to pull transfer requests from the sender."""
        self.sender_write_endpoint = self.runtime.endpoint(
            "embedding_transfer.sender.write"
        )
        self.send_client = await self.sender_write_endpoint.client()
        # await self.send_client.wait_for_instances()

    async def batch_receive(self, batch_transfer_request: BatchTransferRequest):
        """Receive every tensor of the batch concurrently, then release them.

        All transfers are awaited even if some fail, so that every tensor that
        did arrive is released. The first failure (in request order) is
        re-raised afterwards.
        """
        receiver = self.get_run_config()
        tasks = [
            asyncio.create_task(receiver.receive_embeddings(tr))
            for tr in batch_transfer_request.requests
        ]
        responses = await asyncio.gather(*tasks, return_exceptions=True)
        first_error = None
        for result in responses:
            # BaseException rather than Exception: a cancelled child task is
            # reported as asyncio.CancelledError, which does not derive from
            # Exception. Unpacking it as a result tuple would raise TypeError
            # and skip releasing the remaining tensors.
            if isinstance(result, BaseException):
                if first_error is None:
                    first_error = result
                continue
            tensor_id, _ = result
            receiver.release_tensor(tensor_id)
        if first_error is not None:
            raise first_error

    async def generate(self, request):
        """Receiver-first workflow: ask the sender for a batch, then receive it."""
        stream = await self.send_client.round_robin("send_request")
        async for response in stream:
            await self.batch_receive(
                BatchTransferRequest.model_validate_json(response.data())
            )
        yield "done"

    async def read(self, request):
        """Sender-first workflow: receive the batch described by ``request``."""
        await self.batch_receive(BatchTransferRequest.model_validate_json(request))
        yield "done"

    async def update_config(self, request):
        """Replace the active config with the JSON-encoded ``TransferConfig``."""
        request = TransferConfig.model_validate_json(request)
        self.config = request
        yield "config updated"
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
    """Create the receiver service and serve its three endpoints concurrently."""
    namespace_name = "embedding_transfer"
    component_name = "receiver"

    receiver = Receiver(runtime)
    await receiver.async_init()

    logger.info(f"Created service {namespace_name}/{component_name}")
    logger.info(f"Serving endpoint {namespace_name}.{component_name}.generate")
    logger.info(f"Serving endpoint {namespace_name}.{component_name}.read")
    logger.info(f"Serving endpoint {namespace_name}.{component_name}.update_config")

    generate_endpoint = runtime.endpoint(f"{namespace_name}.{component_name}.generate")
    read_endpoint = runtime.endpoint(f"{namespace_name}.{component_name}.read")
    update_config_endpoint = runtime.endpoint(
        f"{namespace_name}.{component_name}.update_config"
    )

    # Each endpoint is bound to the matching handler on the shared instance.
    serving = (
        generate_endpoint.serve_endpoint(receiver.generate),
        read_endpoint.serve_endpoint(receiver.read),
        update_config_endpoint.serve_endpoint(receiver.update_config),
    )
    await asyncio.gather(*serving)
# Script entry point for the receiver worker: install uvloop, then run it.
if __name__ == "__main__":
    uvloop.install()
    asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import torch
import uvloop
from protocol import BatchTransferRequest, EmbeddingTransferMode, TransferConfig
from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingSender,
NixlReadEmbeddingSender,
NixlWriteEmbeddingSender,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
# Module-level logger for the sender worker; Dynamo logging is configured
# as a side effect of importing this module.
logger = logging.getLogger(__name__)
configure_dynamo_logging()
class Sender:
    """Sender side of the embedding-transfer benchmark.

    One sender object is created per transfer mode; the active one, the source
    tensor and the batch size are picked from the most recently applied
    ``TransferConfig``.
    """

    def __init__(self, runtime: DistributedRuntime):
        self.runtime = runtime
        self.local_sender = LocalEmbeddingSender()
        self.read_sender = NixlReadEmbeddingSender()
        self.write_sender = NixlWriteEmbeddingSender()
        # Fixed random payloads that mimic encoder output. The GPU copy only
        # exists when CUDA is available.
        self.cpu_tensor = torch.randn([256, 8 * 1024], dtype=torch.float16)
        self.gpu_tensor = (
            torch.randn([256, 8 * 1024], dtype=torch.float16, device="cuda")
            if torch.cuda.is_available()
            else None
        )
        self.config = TransferConfig()

    def get_run_config(self):
        """Return ``(sender, tensor, tensor_count)`` for the active config.

        Raises:
            RuntimeError: If GPU mode is requested but no CUDA tensor exists.
            ValueError: If the configured transfer type is not recognised.
        """
        if self.config.use_gpu and self.gpu_tensor is None:
            raise RuntimeError("GPU mode requested but CUDA is not available.")
        # Select the variant of sender/receiver based on config
        if self.config.transfer_type == EmbeddingTransferMode.LOCAL:
            sender = self.local_sender
        elif self.config.transfer_type == EmbeddingTransferMode.NIXL_WRITE:
            sender = self.write_sender
        elif self.config.transfer_type == EmbeddingTransferMode.NIXL_READ:
            sender = self.read_sender
        else:
            raise ValueError(f"Invalid transfer type: {self.config.transfer_type}")
        tensor = self.gpu_tensor if self.config.use_gpu else self.cpu_tensor
        tensor_count = self.config.tensor_count_per_request
        return sender, tensor, tensor_count

    async def async_init(self):
        """Create the client used to push transfer requests to the receiver."""
        self.receiver_read_endpoint = self.runtime.endpoint(
            "embedding_transfer.receiver.read"
        )
        self.read_client = await self.receiver_read_endpoint.client()
        # await self.read_client.wait_for_instances()

    async def _stage_batch(self):
        """Stage ``tensor_count`` sends; return the batch and its send futures."""
        sender, tensor, tensor_count = self.get_run_config()
        batch = BatchTransferRequest(requests=[])
        futures = []
        for _ in range(tensor_count):
            transfer_request, send_future = await sender.send_embeddings(
                tensor, stage_embeddings=True
            )
            batch.requests.append(transfer_request)
            futures.append(send_future)
        return batch, futures

    async def generate(self, request: str):
        """Sender-first workflow: stage a batch and push it to the receiver."""
        batch, futures = await self._stage_batch()
        stream = await self.read_client.round_robin(batch.model_dump_json())
        # Drain the receiver's reply stream before waiting on the sends.
        async for _ in stream:
            pass
        await asyncio.gather(*futures)
        yield "done"

    async def write(self, request: str):
        """Receiver-first workflow: stage a batch and hand it to the caller."""
        batch, futures = await self._stage_batch()
        yield batch.model_dump_json()
        # The sends are awaited only after the batch description was yielded.
        await asyncio.gather(*futures)

    async def update_config(self, request: str):
        """Replace the active config with the JSON-encoded ``TransferConfig``."""
        request = TransferConfig.model_validate_json(request)
        self.config = request
        yield "config updated"
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
    """Create the sender service and serve its three endpoints concurrently."""
    namespace_name = "embedding_transfer"
    component_name = "sender"

    sender = Sender(runtime)
    await sender.async_init()

    logger.info(f"Created service {namespace_name}/{component_name}")
    logger.info(f"Serving endpoint {namespace_name}.{component_name}.generate")
    logger.info(f"Serving endpoint {namespace_name}.{component_name}.write")
    logger.info(f"Serving endpoint {namespace_name}.{component_name}.update_config")

    generate_endpoint = runtime.endpoint(f"{namespace_name}.{component_name}.generate")
    write_endpoint = runtime.endpoint(f"{namespace_name}.{component_name}.write")
    update_config_endpoint = runtime.endpoint(
        f"{namespace_name}.{component_name}.update_config"
    )

    # Each endpoint is bound to the matching handler on the shared instance.
    serving = (
        generate_endpoint.serve_endpoint(sender.generate),
        write_endpoint.serve_endpoint(sender.write),
        update_config_endpoint.serve_endpoint(sender.update_config),
    )
    await asyncio.gather(*serving)
# Script entry point for the sender worker: install uvloop, then run it.
if __name__ == "__main__":
    uvloop.install()
    asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import time
import uvloop
from protocol import EmbeddingTransferMode, TransferConfig
from dynamo.runtime import DistributedRuntime, dynamo_worker
# Number of timed requests issued per benchmark configuration (the warm-up
# request is sent in addition to these).
NUM_REQUESTS = 100
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
    """Time every combination of transfer mode, workflow and source device.

    For each combination the config is pushed to both services, one warm-up
    request is sent, and then ``NUM_REQUESTS`` requests are timed. Errors in a
    run are printed and the benchmark moves on to the next combination.
    """
    # Get endpoint (sender -> receiver)
    sender_endpoint = runtime.endpoint("embedding_transfer.sender.generate")
    receiver_endpoint = runtime.endpoint("embedding_transfer.receiver.generate")
    sender_update_config_endpoint = runtime.endpoint(
        "embedding_transfer.sender.update_config"
    )
    receiver_update_config_endpoint = runtime.endpoint(
        "embedding_transfer.receiver.update_config"
    )

    # Create client and wait for service to be ready
    sender_client = await sender_endpoint.client()
    await sender_client.wait_for_instances()
    receiver_client = await receiver_endpoint.client()
    await receiver_client.wait_for_instances()
    sender_update_config_client = await sender_update_config_endpoint.client()
    await sender_update_config_client.wait_for_instances()
    receiver_update_config_client = await receiver_update_config_endpoint.client()
    await receiver_update_config_client.wait_for_instances()

    async def drain(stream):
        # Consume a response stream, discarding its items.
        async for _ in stream:
            pass

    payload = "world,sun,moon,star"
    transfer_types = (
        EmbeddingTransferMode.LOCAL,
        EmbeddingTransferMode.NIXL_WRITE,
        EmbeddingTransferMode.NIXL_READ,
    )
    workflows = (
        ("receiver-first", receiver_client),
        ("sender-first", sender_client),
    )

    # NOTE From CPU is not the same as E/PD, E/PD originates from GPU and has
    # GPU to CPU copy
    for transfer_type in transfer_types:
        for workflow_string, client in workflows:
            for use_gpu in (False, True):
                # Update sender/receiver config before each run
                config = TransferConfig(
                    use_gpu=use_gpu,
                    tensor_count_per_request=30,
                    transfer_type=transfer_type,
                )
                serialized_config = config.model_dump_json()
                await drain(
                    await sender_update_config_client.round_robin(serialized_config)
                )
                await drain(
                    await receiver_update_config_client.round_robin(serialized_config)
                )

                if use_gpu and transfer_type == EmbeddingTransferMode.NIXL_READ:
                    print(
                        f"Skipping: use_gpu={use_gpu} with transfer type: {transfer_type}"
                    )
                    print(
                        "Reason: nixl_connect errors out on GPU tensor, i.e. NIXL_ERR_NOT_ALLOWED"
                    )
                    continue

                num_requests = NUM_REQUESTS
                try:
                    print(
                        f"Workflow: {workflow_string}, From GPU: {use_gpu}, Transfer Type: {transfer_type}"
                    )
                    # warm up
                    await drain(await client.round_robin(payload))

                    start_time = time.perf_counter()
                    # Issue every request first, then drain the streams in order.
                    streams = []
                    for _ in range(num_requests):
                        streams.append(await client.round_robin(payload))
                    for stream in streams:
                        await drain(stream)
                    end_time = time.perf_counter()
                    print(f"Time taken: {end_time - start_time:.2f} seconds")
                except Exception as e:
                    # Log the exception with context
                    print(f"Error in worker: {type(e).__name__}: {e}")
# Script entry point for the benchmark driver: install uvloop, then run it.
if __name__ == "__main__":
    uvloop.install()
    asyncio.run(worker())
......@@ -12,3 +12,11 @@ class DisaggregationMode(Enum):
AGGREGATED = "agg"
PREFILL = "prefill"
DECODE = "decode"
class EmbeddingTransferMode(Enum):
    """Embedding transfer mode for LLM workers."""

    # Each member's value is its string form.
    LOCAL = "local"
    NIXL_WRITE = "nixl-write"
    NIXL_READ = "nixl-read"
......@@ -7,8 +7,10 @@ from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache
from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver,
LocalEmbeddingSender,
NixlPersistentEmbeddingReceiver,
NixlPersistentEmbeddingSender,
NixlReadEmbeddingReceiver,
NixlReadEmbeddingSender,
NixlWriteEmbeddingReceiver,
NixlWriteEmbeddingSender,
TransferRequest,
)
from dynamo.common.multimodal.image_loader import ImageLoader
......@@ -16,8 +18,10 @@ from dynamo.common.multimodal.image_loader import ImageLoader
__all__ = [
"AsyncEncoderCache",
"ImageLoader",
"NixlPersistentEmbeddingReceiver",
"NixlPersistentEmbeddingSender",
"NixlReadEmbeddingReceiver",
"NixlReadEmbeddingSender",
"NixlWriteEmbeddingSender",
"NixlWriteEmbeddingReceiver",
"TransferRequest",
"LocalEmbeddingReceiver",
"LocalEmbeddingSender",
......
......@@ -2,16 +2,20 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import base64
import logging
import math
import os
import tempfile
import time
import uuid
from abc import ABC, abstractmethod
from queue import Queue
from typing import Any, List
from typing import Any, List, Optional
import msgpack
import torch
from nixl._api import nixl_agent, nixl_agent_config
from pydantic import BaseModel
from safetensors import torch as safetensors_torch
......@@ -215,18 +219,346 @@ class LocalEmbeddingReceiver(AbstractEmbeddingReceiver):
del self.received_tensors[tensor_id]
class NixlEmbeddingSender(AbstractEmbeddingSender):
class MonolithicCounter:
    """
    A simple counter implementation for generating unique IDs.

    IDs start at 0 and increase by one on every call to ``get_next_id``.
    """

    def __init__(self):
        # Next ID to hand out.
        self.counter = 0

    def get_next_id(self) -> int:
        """Return the current ID and advance the counter by one."""
        current_id = self.counter
        self.counter += 1
        return current_id
class RingBuffer:
    """
    A ring buffer implementation for managing memory allocation.

    Uses a circular buffer pattern to efficiently reuse memory without wrapped-around allocations.
    When insufficient space remains at the end, allocation restarts from the beginning.
    Buffers are handed out as contiguous slices of one backing int8 tensor and
    are reclaimed in address order once released.
    """

    BufferId = int

    def __init__(self, buffer_size):
        """Allocate the backing tensor of ``buffer_size`` bytes (int8 elements)."""
        self.buffer_tensor = torch.zeros(buffer_size, dtype=torch.int8)
        # Index tracking for the ring buffer, when
        # free_start_idx < allocated_start_idx, the allocation has been wrapped around,
        # so the allocation request should be rejected if the requested size is larger
        # than the remaining space before allocated_start_idx.
        self.free_start_idx = 0
        self.allocated_start_idx = 0
        self.buffer_size = buffer_size
        self.end_idx = buffer_size
        self.wrapped_around = False
        # Track allocated buffers and their release state,
        # keeping released ranges in 'freed_list' (start -> end) so that
        # allocated_start_idx only ever advances monotonically.
        self.freed_list = {}
        self.allocated_buffer_id_to_range = {}
        # For generating buffer IDs
        self.id_counter = MonolithicCounter()

    def __repr__(self):
        return f"RingBuffer(size={self.buffer_size}, free_start_idx={self.free_start_idx}, allocated_start_idx={self.allocated_start_idx}, wrapped_around={self.wrapped_around}, freed_list={self.freed_list}, allocated_buffers={self.allocated_buffer_id_to_range})"

    def _flush_freed_list(self):
        """Advance ``allocated_start_idx`` over every contiguous freed range."""
        allocated_end = self.freed_list.pop(self.allocated_start_idx, None)
        while allocated_end is not None:
            self.allocated_start_idx = allocated_end
            if self.allocated_start_idx == self.end_idx:
                # Reached the end of the buffer: the live region no longer wraps.
                self.allocated_start_idx = 0
                self.wrapped_around = False
            allocated_end = self.freed_list.pop(self.allocated_start_idx, None)
        # No allocated buffer, reset indices. Important as the ring buffer doesn't
        # support non-contiguous allocation, this makes sure the next allocation can
        # use the full buffer.
        if not self.allocated_buffer_id_to_range:
            self.free_start_idx = 0
            self.allocated_start_idx = 0
            self.wrapped_around = False

    def get_buffer(self, size):
        """
        Get a buffer of given size in the form of 1D tensor with dtype int8,
        the buffer is owned by the RingBuffer instance.
        The returned ID will be used for releasing the buffer after use, as
        an indicator that the buffer can be reused for future allocation.

        Args:
            size: The size of the buffer to allocate.

        Returns:
            A tuple containing the buffer ID and the allocated tensor, or
            ``(None, None)`` if allocation fails.
        """
        # [gluo TODO] raise exception as there is no way to satisfy the request.
        # Can not allocate for sure
        if size > self.buffer_size:
            return None, None
        # Sanity clean up freed list
        self._flush_freed_list()
        # If the allocation will go over end boundary, simply try allocate from the start
        if self.free_start_idx + size > self.end_idx:
            # Already wrapped: the only free space is
            # [free_start_idx, allocated_start_idx), which is smaller than the
            # tail that was just found too small. Wrapping a second time would
            # hand out memory overlapping live buffers at the start and mark a
            # still-live tail as freed.
            if self.wrapped_around:
                return None, None
            # Not enough space even after wrap around, reject the allocation early
            # so we don't mark the remaining space "used"
            if self.allocated_start_idx < size:
                return None, None
            # add artificial entry to freed_list to treat the remaining space to be
            # allocated and released. An empty tail needs no entry; its key would
            # never be consumed.
            if self.free_start_idx < self.end_idx:
                self.freed_list[self.free_start_idx] = self.end_idx
            self.free_start_idx = 0
            self.wrapped_around = True

        start_idx = self.free_start_idx
        end_idx = start_idx + size
        # Check availability of the buffer, if the allocation overlaps with allocated buffer,
        # return None for the caller to retry later after some buffers are released.
        if self.wrapped_around and end_idx > self.allocated_start_idx:
            return None, None

        # book-keep allocations
        buffer_id = self.id_counter.get_next_id()
        self.allocated_buffer_id_to_range[buffer_id] = (start_idx, end_idx)
        self.free_start_idx = end_idx
        return buffer_id, self.buffer_tensor[start_idx:end_idx]

    def release_buffer(self, buffer_id):
        """Mark ``buffer_id`` as reusable; unknown or repeated IDs are ignored."""
        start_end = self.allocated_buffer_id_to_range.pop(buffer_id, None)
        if start_end is not None:
            self.freed_list[start_end[0]] = start_end[1]
            self._flush_freed_list()
class NixlTransferRequest(BaseModel):
    """
    NIXL-specific fields of an embedding transfer, serialized to JSON and
    carried inside ``TransferRequest.serialized_request``.
    """

    # Name of the sender's NIXL agent.
    sender_agent_id: str
    # Metadata of the agent named above. May be None when the sender
    # determines that the receiver is already connected to it.
    agent_metadata: Optional[str]
    # ID of the tensor to be written, as assigned by the sender.
    tensor_id: int
    # Size of the tensor.
    # NOTE(review): the write sender fills this from ``embeddings.nbytes``,
    # i.e. bytes - confirm for other producers.
    tensor_size: int
class NixlWriteEmbeddingSender(AbstractEmbeddingSender):
"""NIXL WRITE-based implementation of the embedding sender interface.
Designed for scenarios where the sender transmits dynamically allocated
tensors. Because the allocation of these tensors is external to the sender,
NIXL memory registration is performed on each send request. The receiver
manages a pre-allocated buffer, so its NIXL metadata stays consistent once
initialized. In such scenarios, letting the sender initiate the WRITE operations
requires minimal metadata exchange.
Protocol:
1. Record the receiver NIXL metadata, this is done:
* Implicitly through the first transfer request as fallback if the metadata
hasn't been recorded.
* [REMOVED] Explicitly through add_agent() API before calling send_embeddings().
The receiver provides get_agent_metadata() API to return its NIXL metadata.
This complicates the implementation and adds extra responsibility on the caller side;
we will revisit the necessity if the metadata exchange overhead is significant.
2. The sender prepares the embeddings and produces a TransferRequest
containing sender contact and tensor metadata (shape, dtype, size, etc).
3. The receiver responds with (optional) receiver contact, target tensor
metadata (buffer address, device, etc) and done signal through NIXL notification.
4. The sender performs a NIXL WRITE to push the data into the
receiver's buffer.
"""
def __init__(self):
    """Set up the NIXL agent, the bookkeeping tables and the background task.

    Must be called while an event loop is running, because the state-update
    task is created here.
    """
    # Bookkeeping tables.
    self.remote_agents = {}
    # tensor_id -> (tensor, transfer descriptor, completion future)
    self.transfer_tracker = {}
    # The same tensor may be sent several times concurrently. Registrations
    # are therefore keyed by the tensor buffer and reference-counted, so a
    # buffer is registered once and only deregistered after its last
    # in-flight transfer has completed.
    self.registered_descs = {}
    self.id_counter = MonolithicCounter()
    # seconds, can be tuned based on expected transfer time and network condition
    self.transfer_timeout = 60

    # NIXL agent setup
    self.sender_id = f"sender_{str(uuid.uuid4())}"
    agent_config = nixl_agent_config(num_threads=8, capture_telemetry=True)
    self.nixl_agent = nixl_agent(self.sender_id, agent_config)
    self.agent_metadata = self.nixl_agent.get_agent_metadata()
    self.agent_metadata_b64 = base64.b64encode(self.agent_metadata).decode("utf-8")

    # Background transfer task. The queue hints whether the sender is
    # expecting future transfers.
    self.transfer_queue = asyncio.Queue()
    self._state_update_task = asyncio.create_task(self._state_update())
def __del__(self):
    """Cancel the background state-update task when the sender is collected."""
    # __del__ also runs for instances whose __init__ raised before the task
    # was created, so the attribute may be missing.
    task = getattr(self, "_state_update_task", None)
    if task is not None:
        task.cancel()
async def _state_update(self):
"""Long-running async task that drives the sender-side transfer state.

Each pass polls NIXL notifications for receiver handshakes, starts a WRITE
for every handshake, then polls the transfers started earlier and completes
the ones that reached "DONE" or "ERR". An exception raised during a pass is
logged and followed by a one-second pause.

NOTE(review): indentation is not visible in this view, so which loop the
trailing ``continue``/``break`` statements belong to is inferred - confirm
against the repository before relying on it.
"""
# tensor_id -> [transfer handle, start timestamp]
inflight_transfers = {}
# Item taken from transfer_queue; None means the task idles until
# send_embeddings() queues a new indicator.
scheduled_transfer_task = None
while True:
try:
# If there is no scheduled transfer task, blocking wait for
# a new transfer request because no state needs to be updated.
if scheduled_transfer_task is None:
scheduled_transfer_task = await self.transfer_queue.get()
# check if write is requested, initiate the write
write_requests = self._get_receiver_handshakes()
for (
remote_agent_id,
remote_agent_metadata,
tensor_id,
(target_buffer, target_byte_size, target_device_id, target_mem_str),
write_done_id,
) in write_requests:
# Just in time add remote agent if not added
if remote_agent_id not in self.remote_agents:
if len(remote_agent_metadata) == 0:
logger.error(
f"Received transfer notification from unknown agent {remote_agent_id} without metadata, cannot add remote agent for transfer"
)
# Can't proceed with the transfer without receiver metadata,
# mark the transfer as completed to unblock the sender.
self._complete_transfer(tensor_id)
continue
self.remote_agents[
remote_agent_id
] = self.nixl_agent.add_remote_agent(remote_agent_metadata)
# initiate NIXL WRITE transfer
# NOTE(review): an unknown tensor_id raises KeyError here; presumably it
# is handled by the except clause below, which would also drop the rest
# of this batch of handshakes - confirm.
source_tensor, source_desc, _ = self.transfer_tracker[tensor_id]
target_desc = self.nixl_agent.get_xfer_descs(
[
(target_buffer, target_byte_size, target_device_id),
],
mem_type=target_mem_str,
)
# The done signal is the receiver-chosen id encoded as bytes.
done_signal = str(write_done_id).encode()
xfer_handle = self.nixl_agent.initialize_xfer(
"WRITE",
source_desc,
target_desc,
remote_agent_id,
done_signal,
)
self.nixl_agent.transfer(xfer_handle, done_signal)
# Stored as a list so the timestamp can be reset in place below.
inflight_transfers[tensor_id] = [
xfer_handle,
time.perf_counter(),
]
# check inflight transfer state, if completed, get another task to match
# remaining transfers count
# use list() to create a copy of the dict items since the dict will be modified in the loop
now_time = time.perf_counter()
for tensor_id, (
xfer_handle,
start_time,
) in list(inflight_transfers.items()):
state = self.nixl_agent.check_xfer_state(xfer_handle)
if state == "ERR":
logger.error(f"Transfer failed for tensor_id {tensor_id}")
elif state == "DONE":
logger.debug(
f"Send completed for tensor_id {tensor_id}, total wait time: {now_time - start_time:.2f} seconds"
)
else:
# still in-flight, check again later
if now_time - start_time > self.transfer_timeout:
logger.warning(
f"Transfer for tensor_id {tensor_id} has been in-flight for more than {self.transfer_timeout} seconds, reseting its timer"
)
inflight_transfers[tensor_id][1] = now_time
continue
# NOTE future is set with result None in "ERR" and "DONE", so the sender will not
# be able to distinguish failure with success, we can consider
# adding more explicit failure signal in the future if needed.
self._complete_transfer(tensor_id)
inflight_transfers.pop(tensor_id)
# One queue item is consumed per finished transfer.
try:
scheduled_transfer_task = self.transfer_queue.get_nowait()
except asyncio.QueueEmpty:
if inflight_transfers:
logger.error(
f"Unexpected no scheduled transfer request, while there are still {len(inflight_transfers)} inflight transfers"
)
# Continue the loop to check the state of remaining inflight transfers
continue
logger.debug("No pending transfer task in the queue.")
# Nothing queued and nothing in flight: go back to the blocking wait.
scheduled_transfer_task = None
break
# short pause to yield control and allow cancellation
await asyncio.sleep(0.001)
except Exception as e:
logger.error(f"Error in state update loop: {e}")
await asyncio.sleep(1) # Backoff on error to prevent tight error loop
def _get_receiver_handshakes(self):
    """Fetch new NIXL notifications and decode them into write requests.

    Returns:
        A list of tuples ``(remote_agent_id, remote_agent_metadata, tensor_id,
        (target_buffer, target_byte_size, target_device_id, target_mem_str),
        write_done_id)``, one per notification.
    """
    handshakes = []
    for agent_id, messages in self.nixl_agent.get_new_notifs().items():
        for message in messages:
            # Each notification is a msgpack-encoded 4-item sequence.
            tensor_id, target, write_done_id, agent_metadata = msgpack.unpackb(
                message
            )
            # (note byte size can be retrieved from source tensor)
            target_buffer, target_byte_size, target_device_id, target_mem_str = target
            handshakes.append(
                (
                    # receiver contact
                    agent_id,
                    agent_metadata,
                    # source tensor
                    tensor_id,
                    # target tensor
                    (
                        target_buffer,
                        target_byte_size,
                        target_device_id,
                        target_mem_str,
                    ),
                    # done signal
                    write_done_id,
                )
            )
    return handshakes
def _complete_transfer(self, tensor_id):
    """Drop the tracking entry of ``tensor_id`` and resolve its future.

    Does nothing when the tensor is not tracked. The tensor's memory
    registration is reference-counted and only released once no other
    tracked transfer uses the same buffer.
    """
    tracked = self.transfer_tracker.pop(tensor_id, None)
    if tracked is None:
        return

    # Clean up registered memory after transfer completion
    embeddings, _, fut = tracked
    desc_key = (embeddings.data_ptr(), embeddings.get_device())
    registration = self.registered_descs[desc_key]
    registration[1] -= 1
    if registration[1] == 0:
        self.nixl_agent.deregister_memory(registration[0])
        del self.registered_descs[desc_key]

    # Future can be 'done' if the embeddings is not external
    # (send_embeddings with stage_embeddings=False)
    if not fut.done():
        fut.set_result(None)
async def send_embeddings(
self, embeddings: torch.Tensor, stage_embeddings: bool = False
self,
embeddings: torch.Tensor,
stage_embeddings: bool = False,
) -> tuple[TransferRequest, asyncio.Future]:
"""
Send precomputed embeddings.
......@@ -238,66 +570,186 @@ class NixlEmbeddingSender(AbstractEmbeddingSender):
Returns:
A tuple containing the TransferRequest object and a future that can be awaited to indicate the send is completed.
"""
tensor_id = self.id_counter.get_next_id()
fut = asyncio.get_event_loop().create_future()
if not stage_embeddings:
embeddings = embeddings.clone().detach()
fut.set_result(None)
# In case the same embedding tensor is sent multiple times,
# we want to avoid potential issues with duplicated NIXL memory registration.
desc_key = (embeddings.data_ptr(), embeddings.get_device())
if desc_key not in self.registered_descs:
registered_desc = self.nixl_agent.register_memory(embeddings)
self.registered_descs[desc_key] = [registered_desc, 1]
else:
self.registered_descs[desc_key][1] += 1
descriptor = nixl_connect.Descriptor(embeddings.cpu())
readable_op = await self.connector.create_readable(descriptor)
desc = self.nixl_agent.get_xfer_descs(embeddings)
# use tracker to also extend lifecycle of transfer-related objects
self.transfer_tracker[tensor_id] = (embeddings, desc, fut)
self.transfer_queue.put_nowait("task_indicator")
request = TransferRequest(
embeddings_shape=list(embeddings.shape),
embedding_dtype_str=torch_dtype_to_string(embeddings.dtype),
serialized_request=readable_op.metadata().model_dump(),
serialized_request=NixlTransferRequest(
sender_agent_id=self.sender_id,
agent_metadata=self.agent_metadata_b64,
tensor_id=tensor_id,
tensor_size=embeddings.nbytes,
).model_dump_json(),
)
return request, readable_op.wait_for_completion()
return request, fut
class NixlEmbeddingReceiver(AbstractEmbeddingReceiver):
class NixlWriteEmbeddingReceiver(AbstractEmbeddingReceiver):
"""
The EmbeddingReceiver implementation of current usage of NIXL connect library,
which creates a new NIXL connection for each send operation. Only implemented here
for reference and should not be used due to overhead discovered in practice.
Counter part of 'NixlWriteEmbeddingSender', see 'NixlWriteEmbeddingSender' for details.
The receiver manages a ring buffer for sender to write the embeddings into, and respond
to the sender's transfer request with the buffer information for the WRITE transfer.
"""
def __init__(self):
super().__init__()
self.connector = nixl_connect.Connector()
self.tensor_id_counter = 0
def __init__(self, buffer_size=2 * 8 * 1024 * 1024 * 256 * 2):
    """Allocate the receive ring buffer and register it with a new NIXL agent.

    Args:
        buffer_size: Size of the ring buffer in bytes. The default is the
            product of
            2 (typical dtype size float16),
            8 * 1024 (typical embedding hidden size for Qwen-VL),
            256 * 1024 (1024 count of 256 mm token item) and
            2 (extra copies) = 8 GB memory.
    """
    # Ring buffer without wrapped-around allocation, i.e. it allocates from
    # the start if the remaining space at the end is not enough.
    ring = RingBuffer(buffer_size)
    self.ring_buffer = ring
    self.transfer_tensor = ring.buffer_tensor

    # NIXL agent setup
    self.receiver_id = f"receiver_{str(uuid.uuid4())}"
    agent_config = nixl_agent_config(num_threads=8, capture_telemetry=True)
    self.nixl_agent = nixl_agent(self.receiver_id, agent_config)
    self.remote_agents = {}
    # The whole backing tensor is registered once, up front.
    self.reg_descs = self.nixl_agent.register_memory(self.transfer_tensor)
    self.agent_metadata = self.nixl_agent.get_agent_metadata()

    # IDs handed back to callers, and their mapping to ring-buffer IDs.
    self.id_counter = MonolithicCounter()
    self.to_buffer_id = {}
async def receive_embeddings(
    self, request: TransferRequest, receive_timeout=60
) -> tuple[int, torch.Tensor]:
    """
    Receive precomputed embeddings for a given request ID.

    Args:
        request: The TransferRequest object containing information to receive embeddings for.
        receive_timeout: Maximum time to wait for the transfer to complete before raising a TimeoutError.
            The timeout will be applied separately for waiting for available buffer and waiting for transfer completion.

    Returns:
        A tuple containing the tensor ID and the received embeddings as a torch.Tensor.
        Caller should invoke release_tensor(tensor_id) when the tensor is no longer needed to free up resources.

    Raises:
        ValueError: If the sender is unknown and the request carries no agent metadata.
        TimeoutError: If no buffer becomes available, or the transfer does not
            complete, within ``receive_timeout`` seconds.
    """
    nixl_request = NixlTransferRequest.model_validate_json(
        request.serialized_request
    )
    sender_agent_id = nixl_request.sender_agent_id
    if sender_agent_id not in self.remote_agents:
        if nixl_request.agent_metadata is None:
            raise ValueError(
                f"Missing agent metadata for new sender {nixl_request.sender_agent_id}"
            )
        self.remote_agents[sender_agent_id] = self.nixl_agent.add_remote_agent(
            base64.b64decode(nixl_request.agent_metadata)
        )

    # Allocate tensor to be written into.
    start_time = time.perf_counter()
    while True:
        buffer_id, transfer_tensor = self.ring_buffer.get_buffer(
            nixl_request.tensor_size
        )
        if transfer_tensor is not None:
            break
        # No available buffer, wait for a short period and retry.
        # The receiver side should have concurrent work on other
        # allocated buffer and release them in a timely manner,
        # so the wait time should not be long.
        #
        # NOTE This approach can result in deadlock due to
        # the current usage of the receiver:
        # The case of concurrent requests may request 2 buffer in order,
        # if all request get the first buffer and exhaust the ring buffer,
        # then no request can get the second buffer and proceed.
        # On raising the timeout error from this function, the caller must
        # release all previously allocated tensor of the request to unblock
        # other requests, and retry the request after some delay to avoid
        # repeated deadlock.
        # [gluo WIP] provide an API for batch allocation so some requests can
        # proceed.
        if time.perf_counter() - start_time > receive_timeout:
            raise TimeoutError("Timeout while waiting for available buffer.")
        await asyncio.sleep(0.005)

    # From here on the ring-buffer slot is owned by this call. Any failure,
    # including task cancellation, must hand it back or the slot is lost
    # and the ring buffer eventually stalls.
    try:
        # view as tensor matching the source tensor..
        embeddings_shape = request.embeddings_shape
        embeddings_dtype = torch_dtype_from_string(request.embedding_dtype_str)
        embedding_tensor = transfer_tensor.view(dtype=embeddings_dtype).view(
            embeddings_shape
        )

        # Request for transfer
        tensor_id = self.id_counter.get_next_id()
        notif_msg = msgpack.packb(
            (
                nixl_request.tensor_id,
                (
                    transfer_tensor.data_ptr(),
                    nixl_request.tensor_size,
                    # torch returns -1 for CPU device, so normalize it here
                    max(transfer_tensor.get_device(), 0),
                    "cuda" if str(transfer_tensor.device).startswith("cuda") else "cpu",
                ),
                tensor_id,
                # side channel handshake fallback for receiver API consistency,
                # this will increase message size for the first few transfers before handshake
                self.agent_metadata if nixl_request.agent_metadata else b"",
            )
        )
        self.nixl_agent.send_notif(sender_agent_id, notif_msg=notif_msg)

        # await for write notification
        start_time = time.perf_counter()
        done_signal = str(tensor_id).encode()
        while True:
            # parse notifications to find done signal, we can't use 'check_remote_xfer_done' API
            # because it matches the requested string pattern as a substring of the notifications
            # instead of an exact match, which is not what we want, i.e. for two done signals "1"
            # and "11", 'check_remote_xfer_done("1")' will return True for both signals and "11"
            # will be cleared as a result, making the subsequent 'check_remote_xfer_done("1")'
            # return False.
            found = False
            notifs = self.nixl_agent.update_notifs()
            if sender_agent_id in notifs:
                for notif in notifs[sender_agent_id]:
                    if notif == done_signal:
                        self.nixl_agent.notifs[sender_agent_id].remove(notif)
                        found = True
                        break
            # Leave as soon as the signal is seen, so a transfer that has
            # completed is never reported as a timeout.
            if found:
                break
            await asyncio.sleep(0.001)
            if (time.perf_counter() - start_time) > receive_timeout:
                raise TimeoutError(
                    f"Timeout while waiting for transfer completion for tensor_id {tensor_id} for more than {receive_timeout} seconds"
                )
    except BaseException:
        # NOTE(review): after a timeout the sender may presumably still write
        # into this range later - confirm whether reuse must be delayed.
        self.ring_buffer.release_buffer(buffer_id)
        raise

    logger.debug(
        f"Transfer completed for tensor_id {tensor_id}, total wait time: {time.perf_counter() - start_time:.2f} seconds"
    )
    self.to_buffer_id[tensor_id] = buffer_id
    return tensor_id, embedding_tensor
def release_tensor(self, tensor_id: int):
"""
......@@ -306,8 +758,8 @@ class NixlEmbeddingReceiver(AbstractEmbeddingReceiver):
Args:
tensor_id: The ID of the tensor to release.
"""
# receiver doesn't hold the embedding
pass
buffer_id = self.to_buffer_id.pop(tensor_id)
self.ring_buffer.release_buffer(buffer_id)
class PersistentConnector(nixl_connect.Connector):
......@@ -337,11 +789,15 @@ def remote_release_overwrite(self) -> None:
nixl_connect.Remote._release = remote_release_overwrite
class NixlPersistentEmbeddingSender(AbstractEmbeddingSender):
class NixlReadEmbeddingSender(AbstractEmbeddingSender):
"""
Initial implementation of another usage of NIXL connect library that persists
Initial implementation of NIXL READ based transfer. This implementation uses
a monkey-patched version of 'nixl_connect' wrapper library to persist
connection (agent registration) and descriptors across multiple send operations
to avoid the overhead of repeated connection setup and teardown.
NOTE This implementation or the use of 'nixl_connect' needs to be revisited as
the benchmarking result is unexpectedly slow. Keeping it now for completeness,
i.e. provide NIXL WRITE based and READ based transfer classes.
"""
def __init__(self):
......@@ -376,8 +832,9 @@ class NixlPersistentEmbeddingSender(AbstractEmbeddingSender):
return request, readable_op.wait_for_completion()
class NixlPersistentEmbeddingReceiver(AbstractEmbeddingReceiver):
class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver):
"""
Counter part of 'NixlReadEmbeddingSender', see 'NixlReadEmbeddingSender' for details.
Initial implementation of another usage of NIXL connect library that persists
connection (agent registration) and descriptors (memory registration) across multiple send operations
to avoid the overhead of repeated connection setup and teardown.
......@@ -389,7 +846,7 @@ class NixlPersistentEmbeddingReceiver(AbstractEmbeddingReceiver):
"""
def __init__(
self, embedding_hidden_size=8 * 1024, max_item_mm_token=1024, max_items=50
self, embedding_hidden_size=8 * 1024, max_item_mm_token=1024, max_items=1024
):
super().__init__()
self.connector = PersistentConnector()
......@@ -440,6 +897,7 @@ class NixlPersistentEmbeddingReceiver(AbstractEmbeddingReceiver):
request.serialized_request
)
original_descriptor_size = None
if self.warmedup_descriptors.empty():
logger.debug(
"No warmed up descriptors available, creating a temporary one for transfer."
......@@ -450,7 +908,9 @@ class NixlPersistentEmbeddingReceiver(AbstractEmbeddingReceiver):
else:
descriptor = self.warmedup_descriptors.get()
# Slide view of pre-allocated tensor
original_descriptor_size = descriptor._data_size
tensor_size_bytes = embeddings_dtype.itemsize * math.prod(embeddings_shape)
descriptor._data_size = tensor_size_bytes
encodings_tensor = (
descriptor._data_ref[:tensor_size_bytes]
.view(dtype=embeddings_dtype)
......@@ -465,6 +925,8 @@ class NixlPersistentEmbeddingReceiver(AbstractEmbeddingReceiver):
logging.debug(
f"Successfully read embeddings via NIXL: {encodings_tensor.shape}"
)
if original_descriptor_size is not None:
descriptor._data_size = original_descriptor_size
tensor_id = self.tensor_id_counter
self.tensor_id_counter += 1
self.inuse_descriptors[tensor_id] = (descriptor, dynamic_descriptor)
......
......@@ -6,6 +6,7 @@
import asyncio
import logging
import time
from random import randint
import pytest
import torch
......@@ -13,18 +14,25 @@ import torch
from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver,
LocalEmbeddingSender,
NixlEmbeddingReceiver,
NixlEmbeddingSender,
NixlPersistentEmbeddingReceiver,
NixlPersistentEmbeddingSender,
NixlReadEmbeddingReceiver,
NixlReadEmbeddingSender,
NixlWriteEmbeddingReceiver,
NixlWriteEmbeddingSender,
RingBuffer,
)
logger = logging.getLogger(__name__)
EMBEDDING_SIZE = 8 * 1024
async def benchmark(sender, receiver, tensors=None):
async def benchmark(sender, receiver, tensors=None, from_cuda=False):
if tensors is None:
tensors = [torch.randn(256, 8 * 1024) for _ in range(30)]
tensors = [
torch.randn(256, EMBEDDING_SIZE, device="cuda" if from_cuda else "cpu")
for _ in range(30)
]
# warmup
request, send_future = await sender.send_embeddings(tensors[0])
tensor_id, response = await receiver.receive_embeddings(request)
......@@ -45,6 +53,7 @@ async def benchmark(sender, receiver, tensors=None):
asyncio.create_task(receiver.receive_embeddings(request[0]))
for request in requests
]
responses = await asyncio.gather(*receive_tasks)
receive_end = time.perf_counter()
logger.info(
......@@ -52,7 +61,7 @@ async def benchmark(sender, receiver, tensors=None):
)
for tensor, request, response in zip(tensors, requests, responses):
tensor_id, received_tensor = response
assert torch.equal(received_tensor, tensor)
assert torch.equal(received_tensor, tensor.cpu())
receiver.release_tensor(tensor_id)
await request[1]
......@@ -86,32 +95,226 @@ class TestLocalEmbeddingTransfer:
receiver = LocalEmbeddingReceiver()
await benchmark(sender, receiver)
@pytest.mark.asyncio
@pytest.mark.gpu_1
async def test_gpu_benchmark(self):
    """Run the shared benchmark over the local transfer path with CUDA source tensors."""
    # The sender is constructed before the receiver, as positional
    # arguments are evaluated left to right.
    await benchmark(
        LocalEmbeddingSender(),
        LocalEmbeddingReceiver(),
        from_cuda=True,
    )
@pytest.mark.xfail(run=False, reason="slow")
@pytest.mark.asyncio
@pytest.mark.gpu_0 # Echo tensor worker is CPU-only (no GPU required)
class TestNixlEmbeddingTransfer:
@pytest.mark.gpu_1 # NIXL init requires proper CUDA environment
class TestNixlWriteEmbeddingTransfer:
async def test_correctness(self):
sender = NixlEmbeddingSender()
receiver = NixlEmbeddingReceiver()
sender = NixlWriteEmbeddingSender()
receiver = NixlWriteEmbeddingReceiver()
await correctness(sender, receiver)
async def test_benchmark(self):
sender = NixlEmbeddingSender()
receiver = NixlEmbeddingReceiver()
sender = NixlWriteEmbeddingSender()
receiver = NixlWriteEmbeddingReceiver()
await benchmark(sender, receiver)
async def test_gpu_benchmark(self):
sender = NixlWriteEmbeddingSender()
receiver = NixlWriteEmbeddingReceiver()
await benchmark(sender, receiver, from_cuda=True)
@pytest.mark.asyncio
@pytest.mark.gpu_0 # Echo tensor worker is CPU-only (no GPU required)
class TestNixlPersistentEmbeddingTransfer:
@pytest.mark.gpu_1 # NIXL init requires proper CUDA environment
class TestNixlReadEmbeddingTransfer:
async def test_correctness(self):
sender = NixlPersistentEmbeddingSender()
receiver = NixlPersistentEmbeddingReceiver()
sender = NixlReadEmbeddingSender()
receiver = NixlReadEmbeddingReceiver()
await correctness(sender, receiver)
async def test_benchmark(self):
sender = NixlPersistentEmbeddingSender()
receiver = NixlPersistentEmbeddingReceiver()
sender = NixlReadEmbeddingSender()
receiver = NixlReadEmbeddingReceiver(embedding_hidden_size=EMBEDDING_SIZE)
await benchmark(sender, receiver)
async def test_gpu_benchmark(self):
sender = NixlReadEmbeddingSender()
receiver = NixlReadEmbeddingReceiver(embedding_hidden_size=EMBEDDING_SIZE)
await benchmark(sender, receiver, from_cuda=True)
@pytest.mark.gpu_0  # Echo tensor worker is CPU-only (no GPU required)
class TestRingBuffer:
    """Tests for ``RingBuffer`` allocation, release and wrap-around.

    The layout diagrams in the comments describe the behavior these tests
    expect from ``RingBuffer``; the implementation itself is defined elsewhere.
    """

    def test_simple(self):
        """Allocate-then-release succeeds for sizes up to the full capacity."""
        buffer_size = 128
        ring_buffer = RingBuffer(buffer_size)
        # Fill buffer for debugging
        for idx in range(buffer_size):
            ring_buffer.buffer_tensor[idx] = idx

        for byte_size in [32, 64, 128]:
            buffer_id, tensor = ring_buffer.get_buffer(byte_size)
            assert buffer_id is not None, f"Failed to get buffer for size {byte_size}"
            assert tensor is not None, f"Failed to get tensor for size {byte_size}"
            assert (
                tensor.nbytes == byte_size
            ), f"Expected buffer of size {byte_size}, got {tensor.nbytes}"
            ring_buffer.release_buffer(buffer_id)

        # Test allocation that exceeds buffer size
        buffer_id, tensor = ring_buffer.get_buffer(buffer_size + 1)
        assert (
            buffer_id is None
        ), "Expected None when requesting buffer larger than capacity"
        assert (
            tensor is None
        ), "Expected None when requesting buffer larger than capacity"

    def test_release(self):
        """Released space is only reusable once the oldest allocation is released."""
        buffer_size = 128
        ring_buffer = RingBuffer(buffer_size)
        # Fill buffer for debugging
        for idx in range(buffer_size):
            ring_buffer.buffer_tensor[idx] = idx

        allocated_ids = []
        for byte_size in [32, 32, 64]:
            buffer_id, tensor = ring_buffer.get_buffer(byte_size)
            assert buffer_id is not None, f"Failed to get buffer for size {byte_size}"
            assert tensor is not None, f"Failed to get tensor for size {byte_size}"
            assert (
                tensor.nbytes == byte_size
            ), f"Expected buffer of size {byte_size}, got {tensor.nbytes}"
            allocated_ids.append(buffer_id)

        # Release buffers except the first one, ring buffer will not actually reuse the released space
        # until the oldest allocated buffer is released, to maintain a simple implementation.
        # |-32-|*32*|*64*| (released but not claimed space marked with *)
        # | id1|    |    |
        # NOTE: every buffer after the first must be released here (96 bytes in
        # total); releasing only one 32-byte buffer would make the 64-byte
        # request below fail regardless of the reclaim policy.
        for buffer_id in allocated_ids[1:]:
            ring_buffer.release_buffer(buffer_id)

        failed_id, failed_tensor = ring_buffer.get_buffer(64)
        assert (
            failed_id is None
        ), "Expected None when requesting buffer larger than remaining capacity"
        assert (
            failed_tensor is None
        ), "Expected None when requesting buffer larger than remaining capacity"

        # Release the first allocated buffer to make sure the ring buffer can reuse the released space.
        ring_buffer.release_buffer(allocated_ids[0])

        # Now we should be able to allocate a buffer of size 64 again
        buffer_id, tensor = ring_buffer.get_buffer(64)
        assert buffer_id is not None, "Failed to get buffer after releasing space"
        assert tensor is not None, "Failed to get tensor after releasing space"
        assert tensor.nbytes == 64, f"Expected buffer of size 64, got {tensor.nbytes}"

    def test_wrap_around(self):
        """Allocations wrap to the start of the buffer once the head is freed."""
        buffer_size = 128
        ring_buffer = RingBuffer(buffer_size)
        # Fill buffer for debugging
        for idx in range(buffer_size):
            ring_buffer.buffer_tensor[idx] = idx

        # 32 bytes remaining after allocating 96 bytes, so this should succeed
        # |-32-|-32-|-32-| 32 |
        # | id1| id2| id3|    |
        allocated_id1, tensor1 = ring_buffer.get_buffer(32)
        allocated_id2, tensor2 = ring_buffer.get_buffer(32)
        allocated_id3, tensor3 = ring_buffer.get_buffer(32)
        assert (
            allocated_id1 is not None
            and allocated_id2 is not None
            and allocated_id3 is not None
        ), "Failed to allocate initial buffers"
        assert (
            tensor1.nbytes == 32 and tensor2.nbytes == 32 and tensor3.nbytes == 32
        ), "Expected buffers of size 32"

        # Out of space
        failed_allocation_id, failed_allocation_tensor = ring_buffer.get_buffer(64)
        assert (
            failed_allocation_id is None
        ), "Expected None when requesting buffer larger than remaining capacity"
        assert (
            failed_allocation_tensor is None
        ), "Expected None when requesting buffer larger than remaining capacity"

        # Release the first buffer to create free space at the beginning.
        # A 64-byte allocation would still not fit, because the two free
        # 32-byte regions are not contiguous.
        # | 32 |-32-|-32-| 32 |
        # |    | id2| id3|    |
        ring_buffer.release_buffer(allocated_id1)

        # small allocation okay, and should occupy part of the last 32 bytes
        # | 32 |-32-|-32-|-16-| 16 |
        # |    | id2| id3| id4|    |
        allocated_id4, tensor4 = ring_buffer.get_buffer(16)
        assert (
            allocated_id4 is not None
        ), "Failed to allocate buffer after releasing space"
        assert tensor4.nbytes == 16, f"Expected buffer of size 16, got {tensor4.nbytes}"

        # Make room for large allocation
        # Implementation detail: after wrap around, the tailing free space is marked allocated
        # |-64-|-32-|-16-|*16*|
        # | id5| id3| id4|    |
        ring_buffer.release_buffer(allocated_id2)
        allocated_id5, tensor5 = ring_buffer.get_buffer(64)
        assert (
            allocated_id5 is not None
        ), "Failed to allocate buffer after releasing space"
        assert tensor5.nbytes == 64, f"Expected buffer of size 64, got {tensor5.nbytes}"

        failed_allocation_id, failed_allocation_tensor = ring_buffer.get_buffer(8)
        assert (
            failed_allocation_id is None
        ), "Expected None when requesting buffer larger than remaining capacity"
        assert (
            failed_allocation_tensor is None
        ), "Expected None when requesting buffer larger than remaining capacity"

        # Release all and make sure we have full capacity again
        ring_buffer.release_buffer(allocated_id3)
        ring_buffer.release_buffer(allocated_id4)
        ring_buffer.release_buffer(allocated_id5)
        allocated_id6, tensor6 = ring_buffer.get_buffer(buffer_size)
        assert (
            allocated_id6 is not None
        ), "Failed to allocate buffer for full capacity after releasing all buffers"
        assert (
            tensor6.nbytes == buffer_size
        ), f"Expected buffer of size {buffer_size}, got {tensor6.nbytes}"

    def test_looping(self):
        """Repeated allocate/release batches keep succeeding across many wrap-arounds."""
        buffer_size = 64 * 3
        ring_buffer = RingBuffer(buffer_size)
        # Fill buffer for debugging
        for idx in range(buffer_size):
            ring_buffer.buffer_tensor[idx] = idx % 128  # int8 max value

        allocated_batches: list[int] = []
        for _ in range(10):
            # On each batch, allocate buffers with total size of 64, afterwards
            # release previous batch if any.
            # Implementation detail: Each batch takes 1/3 of the buffer to avoid not enough
            # space with possible waste of tailing free space after wrap around.
            current_batch_ids: list[int] = []
            allocated_bytes = 0
            while allocated_bytes < 64:
                # Random chunk size, clamped so the batch totals exactly 64 bytes.
                new_byte_size = min(randint(8, 64), 64 - allocated_bytes)
                allocated_id, tensor = ring_buffer.get_buffer(new_byte_size)
                assert (
                    allocated_id is not None
                ), "Failed to allocate buffer in looping test"
                assert (
                    tensor.nbytes == new_byte_size
                ), f"Expected buffer of size {new_byte_size} in looping test"
                allocated_bytes += new_byte_size
                current_batch_ids.append(allocated_id)

            # Release previous batch
            for allocated_id in allocated_batches:
                ring_buffer.release_buffer(allocated_id)
            allocated_batches = current_batch_ids
......@@ -11,7 +11,7 @@ from dynamo.common.configuration.config_base import ConfigBase
from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument
from . import __version__
from .constants import DisaggregationMode
from .constants import DisaggregationMode, EmbeddingTransferMode
class DynamoVllmArgGroup(ArgGroup):
......@@ -136,6 +136,16 @@ class DynamoVllmArgGroup(ArgGroup):
),
)
# Selects how the worker transfers embeddings.
add_argument(
    g,
    flag_name="--embedding-transfer-mode",
    env_var="DYN_VLLM_EMBEDDING_TRANSFER_MODE",
    default=EmbeddingTransferMode.NIXL_WRITE.value,
    # The "(default)" marker in the help text must stay in sync with `default=` above.
    help="Worker embedding transfer mode: 'local' (local file system), "
    "'nixl-write' (default, NIXL transfer with WRITE), or 'nixl-read' (NIXL transfer with READ).",
    choices=[m.value for m in EmbeddingTransferMode],
)
# vLLM-Omni
add_negatable_bool_argument(
g,
......@@ -325,6 +335,9 @@ class DynamoVllmConfig(ConfigBase):
enable_multimodal: bool
mm_prompt_template: str
frontend_decoding: bool
embedding_transfer_mode: Union[
str, EmbeddingTransferMode
] # resolved to enum in validate()
# vLLM-Omni
omni: bool
......@@ -362,10 +375,18 @@ class DynamoVllmConfig(ConfigBase):
def validate(self) -> None:
"""Validate vLLM wrapper configuration."""
self._resolve_disaggregation_mode()
self._resolve_embedding_transfer_mode()
self._validate_multimodal_role_exclusivity()
self._validate_multimodal_requires_flag()
self._validate_omni_stage_config()
def _resolve_embedding_transfer_mode(self) -> None:
    """Convert a string ``embedding_transfer_mode`` into its enum member.

    Values that are not ``str`` instances are left untouched.
    """
    raw_mode = self.embedding_transfer_mode
    if not isinstance(raw_mode, str):
        return
    # Enum lookup by value; the enum constructor raises for an unknown value.
    self.embedding_transfer_mode = EmbeddingTransferMode(raw_mode)
def _resolve_disaggregation_mode(self) -> None:
"""Resolve disaggregation_mode from new enum or legacy boolean flags.
......
......@@ -7,6 +7,6 @@ DisaggregationMode is defined in dynamo.common.constants and re-exported here
so that existing imports from dynamo.vllm.constants continue to work.
"""
from dynamo.common.constants import DisaggregationMode
from dynamo.common.constants import DisaggregationMode, EmbeddingTransferMode
__all__ = ["DisaggregationMode"]
__all__ = ["DisaggregationMode", "EmbeddingTransferMode"]
......@@ -13,9 +13,14 @@ from transformers import AutoImageProcessor
from vllm.engine.arg_utils import AsyncEngineArgs
import dynamo.nixl_connect as connect
from dynamo.common.multimodal import LocalEmbeddingSender, NixlPersistentEmbeddingSender
from dynamo.common.multimodal import (
LocalEmbeddingSender,
NixlReadEmbeddingSender,
NixlWriteEmbeddingSender,
)
from dynamo.runtime import DistributedRuntime
from ..constants import EmbeddingTransferMode
from ..multimodal_utils import (
ImageLoader,
encode_image_embeddings,
......@@ -30,10 +35,10 @@ logger = logging.getLogger(__name__)
CACHE_SIZE_MAXIMUM = 8
# Both embedding transmitter suffers from increasing latency as
# number of concurrent requests increases, NixlPersistentEmbedding transmitters
# [gluo WIP] now it's time to revisit
# Both embedding transfers suffer from increasing latency as the
# number of concurrent requests increases; NixlPersistentEmbedding transfers
# scale worse than local. Need to investigate why.
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
# [gluo NOTE] default off to benchmark standalone encoder
ENABLE_ENCODER_CACHE = int(os.getenv("ENABLE_ENCODER_CACHE", 1))
......@@ -49,6 +54,7 @@ class EncodeWorkerHandler:
def __init__(
self,
engine_args: AsyncEngineArgs,
embedding_transfer_mode: EmbeddingTransferMode,
) -> None:
self.engine_args = engine_args
self.model = self.engine_args.model
......@@ -75,11 +81,17 @@ class EncodeWorkerHandler:
self._processed_requests = 0
self.readables = []
self.embedding_cache = EmbeddingCache() if ENABLE_ENCODER_CACHE else None
self.embedding_sender = (
LocalEmbeddingSender()
if TRANSFER_LOCAL
else NixlPersistentEmbeddingSender()
)
if embedding_transfer_mode == EmbeddingTransferMode.LOCAL:
self.embedding_sender = LocalEmbeddingSender()
elif embedding_transfer_mode == EmbeddingTransferMode.NIXL_WRITE:
self.embedding_sender = NixlWriteEmbeddingSender()
elif embedding_transfer_mode == EmbeddingTransferMode.NIXL_READ:
self.embedding_sender = NixlReadEmbeddingSender()
else:
raise ValueError(
f"Invalid embedding transfer mode: {embedding_transfer_mode}"
)
self.send_complete_queue = asyncio.Queue()
self.send_complete_checker_task = asyncio.create_task(
self.check_complete(self.send_complete_queue)
......
......@@ -3,7 +3,6 @@
import copy
import logging
import os
import uuid
from collections import defaultdict
from typing import Any
......@@ -18,12 +17,13 @@ from dynamo.common.memory.multimodal_embedding_cache_manager import (
)
from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver,
NixlPersistentEmbeddingReceiver,
NixlReadEmbeddingReceiver,
NixlWriteEmbeddingReceiver,
)
from dynamo.runtime import Client, DistributedRuntime
from ..args import Config
from ..constants import DisaggregationMode
from ..constants import DisaggregationMode, EmbeddingTransferMode
from ..handlers import BaseWorkerHandler, build_sampling_params
from ..multimodal_utils import (
MyRequestOutput,
......@@ -36,7 +36,6 @@ from ..multimodal_utils.prefill_worker_utils import load_multimodal_embeddings
logger = logging.getLogger(__name__)
IMAGE_URL_KEY = "image_url"
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
class MultimodalPDWorkerHandler(BaseWorkerHandler):
......@@ -95,13 +94,18 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self._connector: connect.Connector | None = (
None # Will be initialized in async_init
)
# [gluo FIXME] can't use pre-registered tensor as NIXL requires descriptors
# to be at matching size, need to overwrite nixl connect library
self.embedding_receiver = (
LocalEmbeddingReceiver()
if TRANSFER_LOCAL
else NixlPersistentEmbeddingReceiver(max_items=0)
)
if config.embedding_transfer_mode == EmbeddingTransferMode.LOCAL:
self.embedding_receiver = LocalEmbeddingReceiver()
elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_WRITE:
self.embedding_receiver = NixlWriteEmbeddingReceiver()
elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_READ:
# [gluo FIXME] can't use pre-registered tensor as NIXL requires descriptors
# to be at matching size, need to overwrite nixl connect library
self.embedding_receiver = NixlReadEmbeddingReceiver(max_items=0)
else:
raise ValueError(
f"Invalid embedding transfer mode: {config.embedding_transfer_mode}"
)
logger.info("Multimodal PD Worker has been initialized")
......
......@@ -37,7 +37,7 @@ def _make_config(
multimodal_embedding_cache_capacity_gb: float = 0,
) -> MagicMock:
"""Create a mock Config with the fields used by MultimodalPDWorkerHandler."""
from dynamo.vllm.constants import DisaggregationMode
from dynamo.vllm.constants import DisaggregationMode, EmbeddingTransferMode
config = MagicMock()
config.model = model
......@@ -47,6 +47,9 @@ def _make_config(
if is_prefill_worker
else DisaggregationMode.AGGREGATED
)
# NIXL_WRITE / NIXL_READ modes require GPU, the tests may run in CPU-only environments,
# so set to LOCAL mode.
config.embedding_transfer_mode = EmbeddingTransferMode.LOCAL
config.enable_multimodal = enable_multimodal
config.multimodal_embedding_cache_capacity_gb = (
multimodal_embedding_cache_capacity_gb
......
......@@ -252,7 +252,9 @@ class WorkerFactory:
)
shutdown_endpoints[:] = [generate_endpoint]
handler = EncodeWorkerHandler(config.engine_args)
handler = EncodeWorkerHandler(
config.engine_args, config.embedding_transfer_mode
)
await handler.async_init(runtime)
logger.info("Starting to serve the encode worker endpoint...")
......
......@@ -70,8 +70,11 @@ python -m dynamo.frontend &
EXTRA_ARGS=""
# Embedding transfer: 1 = local file (safetensors), 0 = NIXL RDMA
export TRANSFER_LOCAL=${TRANSFER_LOCAL:-1}
# Embedding transfer mode (default: "local"):
#   "local"      = local file (safetensors)
#   "nixl-write" = NIXL WRITE transfer
#   "nixl-read"  = NIXL READ transfer
export DYN_VLLM_EMBEDDING_TRANSFER_MODE=${DYN_VLLM_EMBEDDING_TRANSFER_MODE:-"local"}
# GPU assignments (override via environment variables)
if [[ "$SINGLE_GPU" == "true" ]]; then
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment