Unverified Commit 2cab0f7f authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

fix(perf): add embedding transfer implementation with NIXL WRITE initiation (#6651)


Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
parent da98f6a0
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from pydantic import BaseModel
from dynamo.common.constants import EmbeddingTransferMode
from dynamo.common.multimodal.embedding_transfer import TransferRequest
class TransferConfig(BaseModel):
use_gpu: bool = False
tensor_count_per_request: int = 30
# EmbeddingTransferMode.LOCAL: use local file implementation
# EmbeddingTransferMode.NIXL_WRITE: use NIXL writer as initiator (direct NIXL API calls)
# EmbeddingTransferMode.NIXL_READ: use NIXL reader as initiator (nixl_connect)
transfer_type: EmbeddingTransferMode = EmbeddingTransferMode.LOCAL
class BatchTransferRequest(BaseModel):
requests: list[TransferRequest]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import uvloop
from protocol import BatchTransferRequest, EmbeddingTransferMode, TransferConfig
from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver,
NixlReadEmbeddingReceiver,
NixlWriteEmbeddingReceiver,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
logger = logging.getLogger(__name__)
configure_dynamo_logging()
class Receiver:
def __init__(self, runtime: DistributedRuntime):
self.runtime = runtime
self.local_receiver = LocalEmbeddingReceiver()
self.write_receiver = NixlWriteEmbeddingReceiver(2 * 8 * 1024 * 256 * 1024 * 3)
self.read_receiver = NixlReadEmbeddingReceiver(
embedding_hidden_size=8 * 1024, max_item_mm_token=1024
)
self.config = TransferConfig()
def get_run_config(self):
# Select the variant of sender/receiver based on config
if self.config.transfer_type == EmbeddingTransferMode.LOCAL:
receiver = self.local_receiver
elif self.config.transfer_type == EmbeddingTransferMode.NIXL_WRITE:
receiver = self.write_receiver
elif self.config.transfer_type == EmbeddingTransferMode.NIXL_READ:
receiver = self.read_receiver
else:
raise ValueError(f"Invalid transfer type: {self.config.transfer_type}")
# other fields in self.config are sender-side config, receiver only
# relies on BatchTransferRequest for completing the transfer.
return receiver
async def async_init(self):
self.sender_write_endpoint = self.runtime.endpoint(
"embedding_transfer.sender.write"
)
self.send_client = await self.sender_write_endpoint.client()
# await self.send_client.wait_for_instances()
async def batch_receive(self, batch_transfer_request: BatchTransferRequest):
receiver = self.get_run_config()
tasks = [
asyncio.create_task(receiver.receive_embeddings(tr))
for tr in batch_transfer_request.requests
]
responses = await asyncio.gather(*tasks, return_exceptions=True)
first_error = None
for result in responses:
if isinstance(result, Exception):
first_error = first_error or result
continue
tensor_id, _ = result
receiver.release_tensor(tensor_id)
if first_error:
raise first_error
async def generate(self, request):
stream = await self.send_client.round_robin("send_request")
async for response in stream:
await self.batch_receive(
BatchTransferRequest.model_validate_json(response.data())
)
yield "done"
async def read(self, request):
await self.batch_receive(BatchTransferRequest.model_validate_json(request))
yield "done"
async def update_config(self, request):
request = TransferConfig.model_validate_json(request)
self.config = request
yield "config updated"
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
namespace_name = "embedding_transfer"
component_name = "receiver"
worker = Receiver(runtime)
await worker.async_init()
logger.info(f"Created service {namespace_name}/{component_name}")
logger.info(f"Serving endpoint {namespace_name}.{component_name}.generate")
logger.info(f"Serving endpoint {namespace_name}.{component_name}.read")
logger.info(f"Serving endpoint {namespace_name}.{component_name}.update_config")
generate_endpoint = runtime.endpoint(f"{namespace_name}.{component_name}.generate")
read_endpoint = runtime.endpoint(f"{namespace_name}.{component_name}.read")
update_config_endpoint = runtime.endpoint(
f"{namespace_name}.{component_name}.update_config"
)
await asyncio.gather(
*[
generate_endpoint.serve_endpoint(worker.generate),
read_endpoint.serve_endpoint(worker.read),
update_config_endpoint.serve_endpoint(worker.update_config),
]
)
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import torch
import uvloop
from protocol import BatchTransferRequest, EmbeddingTransferMode, TransferConfig
from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingSender,
NixlReadEmbeddingSender,
NixlWriteEmbeddingSender,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
logger = logging.getLogger(__name__)
configure_dynamo_logging()
class Sender:
def __init__(self, runtime: DistributedRuntime):
self.runtime = runtime
self.local_sender = LocalEmbeddingSender()
self.read_sender = NixlReadEmbeddingSender()
self.write_sender = NixlWriteEmbeddingSender()
# GPU tensor to mimic encoder output
self.cpu_tensor = torch.randn([256, 8 * 1024], dtype=torch.float16)
self.gpu_tensor = (
torch.randn([256, 8 * 1024], dtype=torch.float16, device="cuda")
if torch.cuda.is_available()
else None
)
self.config = TransferConfig()
def get_run_config(self):
if self.config.use_gpu and self.gpu_tensor is None:
raise RuntimeError("GPU mode requested but CUDA is not available.")
# Select the variant of sender/receiver based on config
if self.config.transfer_type == EmbeddingTransferMode.LOCAL:
sender = self.local_sender
elif self.config.transfer_type == EmbeddingTransferMode.NIXL_WRITE:
sender = self.write_sender
elif self.config.transfer_type == EmbeddingTransferMode.NIXL_READ:
sender = self.read_sender
else:
raise ValueError(f"Invalid transfer type: {self.config.transfer_type}")
tensor = self.gpu_tensor if self.config.use_gpu else self.cpu_tensor
tensor_count = self.config.tensor_count_per_request
return sender, tensor, tensor_count
async def async_init(self):
self.receiver_read_endpoint = self.runtime.endpoint(
"embedding_transfer.receiver.read"
)
self.read_client = await self.receiver_read_endpoint.client()
# await self.read_client.wait_for_instances()
async def generate(self, request: str):
# Select the variant of sender/receiver based on config
sender, tensor, tensor_count = self.get_run_config()
request = BatchTransferRequest(requests=[])
futures = []
for _ in range(tensor_count):
transfer_request, send_future = await sender.send_embeddings(
tensor, stage_embeddings=True
)
request.requests.append(transfer_request)
futures.append(send_future)
stream = await self.read_client.round_robin(request.model_dump_json())
async for response in stream:
continue
await asyncio.gather(*futures)
yield "done"
async def write(self, request: str):
# Select the variant of sender/receiver based on config
sender, tensor, tensor_count = self.get_run_config()
response = BatchTransferRequest(requests=[])
futures = []
for _ in range(tensor_count):
transfer_request, send_future = await sender.send_embeddings(
tensor, stage_embeddings=True
)
response.requests.append(transfer_request)
futures.append(send_future)
yield response.model_dump_json()
await asyncio.gather(*futures)
async def update_config(self, request: str):
request = TransferConfig.model_validate_json(request)
self.config = request
yield "config updated"
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
namespace_name = "embedding_transfer"
component_name = "sender"
worker = Sender(runtime)
await worker.async_init()
logger.info(f"Created service {namespace_name}/{component_name}")
logger.info(f"Serving endpoint {namespace_name}.{component_name}.generate")
logger.info(f"Serving endpoint {namespace_name}.{component_name}.write")
logger.info(f"Serving endpoint {namespace_name}.{component_name}.update_config")
generate_endpoint = runtime.endpoint(f"{namespace_name}.{component_name}.generate")
write_endpoint = runtime.endpoint(f"{namespace_name}.{component_name}.write")
update_config_endpoint = runtime.endpoint(
f"{namespace_name}.{component_name}.update_config"
)
await asyncio.gather(
*[
generate_endpoint.serve_endpoint(worker.generate),
write_endpoint.serve_endpoint(worker.write),
update_config_endpoint.serve_endpoint(worker.update_config),
]
)
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import time
import uvloop
from protocol import EmbeddingTransferMode, TransferConfig
from dynamo.runtime import DistributedRuntime, dynamo_worker
NUM_REQUESTS = 100
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
# Get endpoint (sender -> receiver)
sender_endpoint = runtime.endpoint("embedding_transfer.sender.generate")
receiver_endpoint = runtime.endpoint("embedding_transfer.receiver.generate")
sender_update_config_endpoint = runtime.endpoint(
"embedding_transfer.sender.update_config"
)
receiver_update_config_endpoint = runtime.endpoint(
"embedding_transfer.receiver.update_config"
)
# Create client and wait for service to be ready
sender_client = await sender_endpoint.client()
await sender_client.wait_for_instances()
receiver_client = await receiver_endpoint.client()
await receiver_client.wait_for_instances()
sender_update_config_client = await sender_update_config_endpoint.client()
await sender_update_config_client.wait_for_instances()
receiver_update_config_client = await receiver_update_config_endpoint.client()
await receiver_update_config_client.wait_for_instances()
# NOTE From CPU is not the same as E/PD, E/PD originates from GPU and has
# GPU to CPU copy
for transfer_type in [
EmbeddingTransferMode.LOCAL,
EmbeddingTransferMode.NIXL_WRITE,
EmbeddingTransferMode.NIXL_READ,
]:
for workflow_string, client in [
("receiver-first", receiver_client),
("sender-first", sender_client),
]:
for use_gpu in [False, True]:
# Update sender/receiver config before each run
config = TransferConfig(
use_gpu=use_gpu,
tensor_count_per_request=30,
transfer_type=transfer_type,
)
async for res in await sender_update_config_client.round_robin(
config.model_dump_json()
):
pass
async for res in await receiver_update_config_client.round_robin(
config.model_dump_json()
):
pass
if transfer_type == EmbeddingTransferMode.NIXL_READ and use_gpu:
print(
f"Skipping: use_gpu={use_gpu} with transfer type: {transfer_type}"
)
print(
"Reason: nixl_connect errors out on GPU tensor, i.e. NIXL_ERR_NOT_ALLOWED"
)
continue
num_requests = NUM_REQUESTS
try:
print(
f"Workflow: {workflow_string}, From GPU: {use_gpu}, Transfer Type: {transfer_type}"
)
# warm up
async for response in await client.round_robin(
"world,sun,moon,star"
):
continue
start_time = time.perf_counter()
streams = [
await client.round_robin("world,sun,moon,star")
for _ in range(num_requests)
]
for stream in streams:
async for response in stream:
continue
end_time = time.perf_counter()
print(f"Time taken: {end_time - start_time:.2f} seconds")
except Exception as e:
# Log the exception with context
print(f"Error in worker: {type(e).__name__}: {e}")
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
...@@ -12,3 +12,11 @@ class DisaggregationMode(Enum): ...@@ -12,3 +12,11 @@ class DisaggregationMode(Enum):
AGGREGATED = "agg" AGGREGATED = "agg"
PREFILL = "prefill" PREFILL = "prefill"
DECODE = "decode" DECODE = "decode"
class EmbeddingTransferMode(Enum):
"""Embedding transfer mode for LLM workers."""
LOCAL = "local"
NIXL_WRITE = "nixl-write"
NIXL_READ = "nixl-read"
...@@ -7,8 +7,10 @@ from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache ...@@ -7,8 +7,10 @@ from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache
from dynamo.common.multimodal.embedding_transfer import ( from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver, LocalEmbeddingReceiver,
LocalEmbeddingSender, LocalEmbeddingSender,
NixlPersistentEmbeddingReceiver, NixlReadEmbeddingReceiver,
NixlPersistentEmbeddingSender, NixlReadEmbeddingSender,
NixlWriteEmbeddingReceiver,
NixlWriteEmbeddingSender,
TransferRequest, TransferRequest,
) )
from dynamo.common.multimodal.image_loader import ImageLoader from dynamo.common.multimodal.image_loader import ImageLoader
...@@ -16,8 +18,10 @@ from dynamo.common.multimodal.image_loader import ImageLoader ...@@ -16,8 +18,10 @@ from dynamo.common.multimodal.image_loader import ImageLoader
__all__ = [ __all__ = [
"AsyncEncoderCache", "AsyncEncoderCache",
"ImageLoader", "ImageLoader",
"NixlPersistentEmbeddingReceiver", "NixlReadEmbeddingReceiver",
"NixlPersistentEmbeddingSender", "NixlReadEmbeddingSender",
"NixlWriteEmbeddingSender",
"NixlWriteEmbeddingReceiver",
"TransferRequest", "TransferRequest",
"LocalEmbeddingReceiver", "LocalEmbeddingReceiver",
"LocalEmbeddingSender", "LocalEmbeddingSender",
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import asyncio import asyncio
import logging import logging
import time import time
from random import randint
import pytest import pytest
import torch import torch
...@@ -13,18 +14,25 @@ import torch ...@@ -13,18 +14,25 @@ import torch
from dynamo.common.multimodal.embedding_transfer import ( from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver, LocalEmbeddingReceiver,
LocalEmbeddingSender, LocalEmbeddingSender,
NixlEmbeddingReceiver, NixlReadEmbeddingReceiver,
NixlEmbeddingSender, NixlReadEmbeddingSender,
NixlPersistentEmbeddingReceiver, NixlWriteEmbeddingReceiver,
NixlPersistentEmbeddingSender, NixlWriteEmbeddingSender,
RingBuffer,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
EMBEDDING_SIZE = 8 * 1024
async def benchmark(sender, receiver, tensors=None):
async def benchmark(sender, receiver, tensors=None, from_cuda=False):
if tensors is None: if tensors is None:
tensors = [torch.randn(256, 8 * 1024) for _ in range(30)] tensors = [
torch.randn(256, EMBEDDING_SIZE, device="cuda" if from_cuda else "cpu")
for _ in range(30)
]
# warmup # warmup
request, send_future = await sender.send_embeddings(tensors[0]) request, send_future = await sender.send_embeddings(tensors[0])
tensor_id, response = await receiver.receive_embeddings(request) tensor_id, response = await receiver.receive_embeddings(request)
...@@ -45,6 +53,7 @@ async def benchmark(sender, receiver, tensors=None): ...@@ -45,6 +53,7 @@ async def benchmark(sender, receiver, tensors=None):
asyncio.create_task(receiver.receive_embeddings(request[0])) asyncio.create_task(receiver.receive_embeddings(request[0]))
for request in requests for request in requests
] ]
responses = await asyncio.gather(*receive_tasks) responses = await asyncio.gather(*receive_tasks)
receive_end = time.perf_counter() receive_end = time.perf_counter()
logger.info( logger.info(
...@@ -52,7 +61,7 @@ async def benchmark(sender, receiver, tensors=None): ...@@ -52,7 +61,7 @@ async def benchmark(sender, receiver, tensors=None):
) )
for tensor, request, response in zip(tensors, requests, responses): for tensor, request, response in zip(tensors, requests, responses):
tensor_id, received_tensor = response tensor_id, received_tensor = response
assert torch.equal(received_tensor, tensor) assert torch.equal(received_tensor, tensor.cpu())
receiver.release_tensor(tensor_id) receiver.release_tensor(tensor_id)
await request[1] await request[1]
...@@ -86,32 +95,226 @@ class TestLocalEmbeddingTransfer: ...@@ -86,32 +95,226 @@ class TestLocalEmbeddingTransfer:
receiver = LocalEmbeddingReceiver() receiver = LocalEmbeddingReceiver()
await benchmark(sender, receiver) await benchmark(sender, receiver)
@pytest.mark.asyncio
@pytest.mark.gpu_1
async def test_gpu_benchmark(self):
sender = LocalEmbeddingSender()
receiver = LocalEmbeddingReceiver()
await benchmark(sender, receiver, from_cuda=True)
@pytest.mark.xfail(run=False, reason="slow")
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.gpu_0 # Echo tensor worker is CPU-only (no GPU required) @pytest.mark.gpu_1 # NIXL init requires proper CUDA environment
class TestNixlEmbeddingTransfer: class TestNixlWriteEmbeddingTransfer:
async def test_correctness(self): async def test_correctness(self):
sender = NixlEmbeddingSender() sender = NixlWriteEmbeddingSender()
receiver = NixlEmbeddingReceiver() receiver = NixlWriteEmbeddingReceiver()
await correctness(sender, receiver) await correctness(sender, receiver)
async def test_benchmark(self): async def test_benchmark(self):
sender = NixlEmbeddingSender() sender = NixlWriteEmbeddingSender()
receiver = NixlEmbeddingReceiver() receiver = NixlWriteEmbeddingReceiver()
await benchmark(sender, receiver) await benchmark(sender, receiver)
async def test_gpu_benchmark(self):
sender = NixlWriteEmbeddingSender()
receiver = NixlWriteEmbeddingReceiver()
await benchmark(sender, receiver, from_cuda=True)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.gpu_0 # Echo tensor worker is CPU-only (no GPU required) @pytest.mark.gpu_1 # NIXL init requires proper CUDA environment
class TestNixlPersistentEmbeddingTransfer: class TestNixlReadEmbeddingTransfer:
async def test_correctness(self): async def test_correctness(self):
sender = NixlPersistentEmbeddingSender() sender = NixlReadEmbeddingSender()
receiver = NixlPersistentEmbeddingReceiver() receiver = NixlReadEmbeddingReceiver()
await correctness(sender, receiver) await correctness(sender, receiver)
async def test_benchmark(self): async def test_benchmark(self):
sender = NixlPersistentEmbeddingSender() sender = NixlReadEmbeddingSender()
receiver = NixlPersistentEmbeddingReceiver() receiver = NixlReadEmbeddingReceiver(embedding_hidden_size=EMBEDDING_SIZE)
await benchmark(sender, receiver) await benchmark(sender, receiver)
async def test_gpu_benchmark(self):
sender = NixlReadEmbeddingSender()
receiver = NixlReadEmbeddingReceiver(embedding_hidden_size=EMBEDDING_SIZE)
await benchmark(sender, receiver, from_cuda=True)
@pytest.mark.gpu_0 # Echo tensor worker is CPU-only (no GPU required)
class TestRingBuffer:
def test_simple(self):
buffer_size = 128
ring_buffer = RingBuffer(buffer_size)
# Fill buffer for debugging
for idx in range(buffer_size):
ring_buffer.buffer_tensor[idx] = idx
for byte_size in [32, 64, 128]:
id, tensor = ring_buffer.get_buffer(byte_size)
assert id is not None, f"Failed to get buffer for size {byte_size}"
assert tensor is not None, f"Failed to get tensor for size {byte_size}"
assert (
tensor.nbytes == byte_size
), f"Expected buffer of size {byte_size}, got {tensor.nbytes}"
ring_buffer.release_buffer(id)
# Test allocation that exceeds buffer size
id, tensor = ring_buffer.get_buffer(buffer_size + 1)
assert id is None, "Expected None when requesting buffer larger than capacity"
assert (
tensor is None
), "Expected None when requesting buffer larger than capacity"
def test_release(self):
buffer_size = 128
ring_buffer = RingBuffer(buffer_size)
# Fill buffer for debugging
for idx in range(buffer_size):
ring_buffer.buffer_tensor[idx] = idx
allocated_ids = []
for byte_size in [32, 32, 64]:
id, tensor = ring_buffer.get_buffer(byte_size)
assert id is not None, f"Failed to get buffer for size {byte_size}"
assert tensor is not None, f"Failed to get tensor for size {byte_size}"
assert (
tensor.nbytes == byte_size
), f"Expected buffer of size {byte_size}, got {tensor.nbytes}"
allocated_ids.append(id)
# Release buffers except the first one, ring buffer will not actually reuse the released space
# until the oldest allocated buffer is released, to maintain a simple implementation.
# |-32-|*32*|*64*| (released but not claimed space marked with *)
# | id1| | |
for id in allocated_ids[1:2]:
ring_buffer.release_buffer(id)
failed_id, failed_tensor = ring_buffer.get_buffer(64)
assert (
failed_id is None
), "Expected None when requesting buffer larger than remaining capacity"
assert (
failed_tensor is None
), "Expected None when requesting buffer larger than remaining capacity"
# Release the first allocated buffer to make sure the ring buffer can reuse the released space.
ring_buffer.release_buffer(allocated_ids[0])
# Now we should be able to allocate a buffer of size 64 again
id, tensor = ring_buffer.get_buffer(64)
assert id is not None, "Failed to get buffer after releasing space"
assert tensor is not None, "Failed to get tensor after releasing space"
assert tensor.nbytes == 64, f"Expected buffer of size 64, got {tensor.nbytes}"
def test_wrap_around(self):
buffer_size = 128
ring_buffer = RingBuffer(buffer_size)
# Fill buffer for debugging
for idx in range(buffer_size):
ring_buffer.buffer_tensor[idx] = idx
# 32 bytes remaining after allocating 96 bytes, so this should succeed
# |-32-|-32-|-32-| 32 |
# | id1| id2| id3| |
allocated_id1, tensor1 = ring_buffer.get_buffer(32)
allocated_id2, tensor2 = ring_buffer.get_buffer(32)
allocated_id3, tensor3 = ring_buffer.get_buffer(32)
assert (
allocated_id1 is not None
and allocated_id2 is not None
and allocated_id3 is not None
), "Failed to allocate initial buffers"
assert (
tensor1.nbytes == 32 and tensor2.nbytes == 32 and tensor3.nbytes == 32
), "Expected buffers of size 32"
# Out of space
failed_allocation_id, failed_allocation_tensor = ring_buffer.get_buffer(64)
assert (
failed_allocation_id is None
), "Expected None when requesting buffer larger than remaining capacity"
assert (
failed_allocation_tensor is None
), "Expected None when requesting buffer larger than remaining capacity"
# Release the first buffer to create free space at the beginning,
# but the 64 bytes allocation will fail as we don't allocate
# | 32 |-32-|-32-| 32 |
# | | id2| id3| |
ring_buffer.release_buffer(allocated_id1)
# small allocation okay, and should occupy part of the last 32 bytes
# | 32 |-32-|-32-|-16-| 16 |
# | | id2| id3| id4| |
allocated_id4, tensor4 = ring_buffer.get_buffer(16)
assert (
allocated_id4 is not None
), "Failed to allocate buffer after releasing space"
assert tensor4.nbytes == 16, f"Expected buffer of size 16, got {tensor4.nbytes}"
# Make room for large allocation
# Implementation detail: after wrap around, the tailing free space is marked allocated
# |-64-|-32-|-16-|*16*|
# | id5| id3| id4| |
ring_buffer.release_buffer(allocated_id2)
allocated_id5, tensor5 = ring_buffer.get_buffer(64)
assert (
allocated_id5 is not None
), "Failed to allocate buffer after releasing space"
assert tensor5.nbytes == 64, f"Expected buffer of size 64, got {tensor5.nbytes}"
failed_allocation_id, failed_allocation_tensor = ring_buffer.get_buffer(8)
assert (
failed_allocation_id is None
), "Expected None when requesting buffer larger than remaining capacity"
assert (
failed_allocation_tensor is None
), "Expected None when requesting buffer larger than remaining capacity"
# Release all and make sure we have full capacity again
ring_buffer.release_buffer(allocated_id3)
ring_buffer.release_buffer(allocated_id4)
ring_buffer.release_buffer(allocated_id5)
print(ring_buffer)
allocated_id6, tensor6 = ring_buffer.get_buffer(buffer_size)
assert (
allocated_id6 is not None
), "Failed to allocate buffer for full capacity after releasing all buffers"
assert (
tensor6.nbytes == buffer_size
), f"Expected buffer of size {buffer_size}, got {tensor6.nbytes}"
def test_looping(self):
buffer_size = 64 * 3
ring_buffer = RingBuffer(buffer_size)
# Fill buffer for debugging
for idx in range(buffer_size):
ring_buffer.buffer_tensor[idx] = idx % 128 # int8 max value
allocated_batches: list[int] = []
for _ in range(10):
# On each batch, allocate buffers with total size of 64, afterwards
# release previous batch if any.
# Implementation detail: Each batch takes 1/3 of the buffer to avoid not enough
# space with possible waste of tailing free space after wrap around.
current_batch_ids: list[int] = []
allocated_bytes = 0
while allocated_bytes < 64:
new_byte_size = min(randint(8, 64), 64 - allocated_bytes)
allocated_id, tensor = ring_buffer.get_buffer(new_byte_size)
assert (
allocated_id is not None
), "Failed to allocate buffer in looping test"
assert (
tensor.nbytes == new_byte_size
), f"Expected buffer of size {new_byte_size} in looping test"
allocated_bytes += new_byte_size
current_batch_ids.append(allocated_id)
# Release previous batch
for allocated_id in allocated_batches:
ring_buffer.release_buffer(allocated_id)
allocated_batches = current_batch_ids
...@@ -11,7 +11,7 @@ from dynamo.common.configuration.config_base import ConfigBase ...@@ -11,7 +11,7 @@ from dynamo.common.configuration.config_base import ConfigBase
from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument
from . import __version__ from . import __version__
from .constants import DisaggregationMode from .constants import DisaggregationMode, EmbeddingTransferMode
class DynamoVllmArgGroup(ArgGroup): class DynamoVllmArgGroup(ArgGroup):
...@@ -136,6 +136,16 @@ class DynamoVllmArgGroup(ArgGroup): ...@@ -136,6 +136,16 @@ class DynamoVllmArgGroup(ArgGroup):
), ),
) )
add_argument(
g,
flag_name="--embedding-transfer-mode",
env_var="DYN_VLLM_EMBEDDING_TRANSFER_MODE",
default=EmbeddingTransferMode.NIXL_WRITE.value,
help="Worker embedding transfer mode: 'local' (default, local file system), "
"'nixl-write' (NIXL transfer with WRITE), or 'nixl-read' (NIXL transfer with READ).",
choices=[m.value for m in EmbeddingTransferMode],
)
# vLLM-Omni # vLLM-Omni
add_negatable_bool_argument( add_negatable_bool_argument(
g, g,
...@@ -325,6 +335,9 @@ class DynamoVllmConfig(ConfigBase): ...@@ -325,6 +335,9 @@ class DynamoVllmConfig(ConfigBase):
enable_multimodal: bool enable_multimodal: bool
mm_prompt_template: str mm_prompt_template: str
frontend_decoding: bool frontend_decoding: bool
embedding_transfer_mode: Union[
str, EmbeddingTransferMode
] # resolved to enum in validate()
# vLLM-Omni # vLLM-Omni
omni: bool omni: bool
...@@ -362,10 +375,18 @@ class DynamoVllmConfig(ConfigBase): ...@@ -362,10 +375,18 @@ class DynamoVllmConfig(ConfigBase):
def validate(self) -> None: def validate(self) -> None:
"""Validate vLLM wrapper configuration.""" """Validate vLLM wrapper configuration."""
self._resolve_disaggregation_mode() self._resolve_disaggregation_mode()
self._resolve_embedding_transfer_mode()
self._validate_multimodal_role_exclusivity() self._validate_multimodal_role_exclusivity()
self._validate_multimodal_requires_flag() self._validate_multimodal_requires_flag()
self._validate_omni_stage_config() self._validate_omni_stage_config()
def _resolve_embedding_transfer_mode(self) -> None:
"""Resolve embedding_transfer_mode from string to enum."""
if isinstance(self.embedding_transfer_mode, str):
self.embedding_transfer_mode = EmbeddingTransferMode(
self.embedding_transfer_mode
)
def _resolve_disaggregation_mode(self) -> None: def _resolve_disaggregation_mode(self) -> None:
"""Resolve disaggregation_mode from new enum or legacy boolean flags. """Resolve disaggregation_mode from new enum or legacy boolean flags.
......
...@@ -7,6 +7,6 @@ DisaggregationMode is defined in dynamo.common.constants and re-exported here ...@@ -7,6 +7,6 @@ DisaggregationMode is defined in dynamo.common.constants and re-exported here
so that existing imports from dynamo.vllm.constants continue to work. so that existing imports from dynamo.vllm.constants continue to work.
""" """
from dynamo.common.constants import DisaggregationMode from dynamo.common.constants import DisaggregationMode, EmbeddingTransferMode
__all__ = ["DisaggregationMode"] __all__ = ["DisaggregationMode", "EmbeddingTransferMode"]
...@@ -13,9 +13,14 @@ from transformers import AutoImageProcessor ...@@ -13,9 +13,14 @@ from transformers import AutoImageProcessor
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
import dynamo.nixl_connect as connect import dynamo.nixl_connect as connect
from dynamo.common.multimodal import LocalEmbeddingSender, NixlPersistentEmbeddingSender from dynamo.common.multimodal import (
LocalEmbeddingSender,
NixlReadEmbeddingSender,
NixlWriteEmbeddingSender,
)
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from ..constants import EmbeddingTransferMode
from ..multimodal_utils import ( from ..multimodal_utils import (
ImageLoader, ImageLoader,
encode_image_embeddings, encode_image_embeddings,
...@@ -30,10 +35,10 @@ logger = logging.getLogger(__name__) ...@@ -30,10 +35,10 @@ logger = logging.getLogger(__name__)
CACHE_SIZE_MAXIMUM = 8 CACHE_SIZE_MAXIMUM = 8
# Both embedding transmitter suffers from increasing latency as # [gluo WIP] now it's time to revisit
# number of concurrent requests increases, NixlPersistentEmbedding transmitters # Both embedding transfer suffers from increasing latency as
# number of concurrent requests increases, NixlPersistentEmbedding transfers
# scale worse than local. Need to investigate why. # scale worse than local. Need to investigate why.
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
# [gluo NOTE] default off to benchmark standalone encoder # [gluo NOTE] default off to benchmark standalone encoder
ENABLE_ENCODER_CACHE = int(os.getenv("ENABLE_ENCODER_CACHE", 1)) ENABLE_ENCODER_CACHE = int(os.getenv("ENABLE_ENCODER_CACHE", 1))
...@@ -49,6 +54,7 @@ class EncodeWorkerHandler: ...@@ -49,6 +54,7 @@ class EncodeWorkerHandler:
def __init__( def __init__(
self, self,
engine_args: AsyncEngineArgs, engine_args: AsyncEngineArgs,
embedding_transfer_mode: EmbeddingTransferMode,
) -> None: ) -> None:
self.engine_args = engine_args self.engine_args = engine_args
self.model = self.engine_args.model self.model = self.engine_args.model
...@@ -75,11 +81,17 @@ class EncodeWorkerHandler: ...@@ -75,11 +81,17 @@ class EncodeWorkerHandler:
self._processed_requests = 0 self._processed_requests = 0
self.readables = [] self.readables = []
self.embedding_cache = EmbeddingCache() if ENABLE_ENCODER_CACHE else None self.embedding_cache = EmbeddingCache() if ENABLE_ENCODER_CACHE else None
self.embedding_sender = ( if embedding_transfer_mode == EmbeddingTransferMode.LOCAL:
LocalEmbeddingSender() self.embedding_sender = LocalEmbeddingSender()
if TRANSFER_LOCAL elif embedding_transfer_mode == EmbeddingTransferMode.NIXL_WRITE:
else NixlPersistentEmbeddingSender() self.embedding_sender = NixlWriteEmbeddingSender()
) elif embedding_transfer_mode == EmbeddingTransferMode.NIXL_READ:
self.embedding_sender = NixlReadEmbeddingSender()
else:
raise ValueError(
f"Invalid embedding transfer mode: {embedding_transfer_mode}"
)
self.send_complete_queue = asyncio.Queue() self.send_complete_queue = asyncio.Queue()
self.send_complete_checker_task = asyncio.create_task( self.send_complete_checker_task = asyncio.create_task(
self.check_complete(self.send_complete_queue) self.check_complete(self.send_complete_queue)
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import copy import copy
import logging import logging
import os
import uuid import uuid
from collections import defaultdict from collections import defaultdict
from typing import Any from typing import Any
...@@ -18,12 +17,13 @@ from dynamo.common.memory.multimodal_embedding_cache_manager import ( ...@@ -18,12 +17,13 @@ from dynamo.common.memory.multimodal_embedding_cache_manager import (
) )
from dynamo.common.multimodal.embedding_transfer import ( from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver, LocalEmbeddingReceiver,
NixlPersistentEmbeddingReceiver, NixlReadEmbeddingReceiver,
NixlWriteEmbeddingReceiver,
) )
from dynamo.runtime import Client, DistributedRuntime from dynamo.runtime import Client, DistributedRuntime
from ..args import Config from ..args import Config
from ..constants import DisaggregationMode from ..constants import DisaggregationMode, EmbeddingTransferMode
from ..handlers import BaseWorkerHandler, build_sampling_params from ..handlers import BaseWorkerHandler, build_sampling_params
from ..multimodal_utils import ( from ..multimodal_utils import (
MyRequestOutput, MyRequestOutput,
...@@ -36,7 +36,6 @@ from ..multimodal_utils.prefill_worker_utils import load_multimodal_embeddings ...@@ -36,7 +36,6 @@ from ..multimodal_utils.prefill_worker_utils import load_multimodal_embeddings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
IMAGE_URL_KEY = "image_url" IMAGE_URL_KEY = "image_url"
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
class MultimodalPDWorkerHandler(BaseWorkerHandler): class MultimodalPDWorkerHandler(BaseWorkerHandler):
...@@ -95,13 +94,18 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -95,13 +94,18 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self._connector: connect.Connector | None = ( self._connector: connect.Connector | None = (
None # Will be initialized in async_init None # Will be initialized in async_init
) )
# [gluo FIXME] can't use pre-registered tensor as NIXL requires descriptors if config.embedding_transfer_mode == EmbeddingTransferMode.LOCAL:
# to be at matching size, need to overwrite nixl connect library self.embedding_receiver = LocalEmbeddingReceiver()
self.embedding_receiver = ( elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_WRITE:
LocalEmbeddingReceiver() self.embedding_receiver = NixlWriteEmbeddingReceiver()
if TRANSFER_LOCAL elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_READ:
else NixlPersistentEmbeddingReceiver(max_items=0) # [gluo FIXME] can't use pre-registered tensor as NIXL requires descriptors
) # to be at matching size, need to overwrite nixl connect library
self.embedding_receiver = NixlReadEmbeddingReceiver(max_items=0)
else:
raise ValueError(
f"Invalid embedding transfer mode: {config.embedding_transfer_mode}"
)
logger.info("Multimodal PD Worker has been initialized") logger.info("Multimodal PD Worker has been initialized")
......
...@@ -37,7 +37,7 @@ def _make_config( ...@@ -37,7 +37,7 @@ def _make_config(
multimodal_embedding_cache_capacity_gb: float = 0, multimodal_embedding_cache_capacity_gb: float = 0,
) -> MagicMock: ) -> MagicMock:
"""Create a mock Config with the fields used by MultimodalPDWorkerHandler.""" """Create a mock Config with the fields used by MultimodalPDWorkerHandler."""
from dynamo.vllm.constants import DisaggregationMode from dynamo.vllm.constants import DisaggregationMode, EmbeddingTransferMode
config = MagicMock() config = MagicMock()
config.model = model config.model = model
...@@ -47,6 +47,9 @@ def _make_config( ...@@ -47,6 +47,9 @@ def _make_config(
if is_prefill_worker if is_prefill_worker
else DisaggregationMode.AGGREGATED else DisaggregationMode.AGGREGATED
) )
# NIXL_WRITE / NIXL_READ modes require GPU, the tests may run in CPU-only environments,
# so set to LOCAL mode.
config.embedding_transfer_mode = EmbeddingTransferMode.LOCAL
config.enable_multimodal = enable_multimodal config.enable_multimodal = enable_multimodal
config.multimodal_embedding_cache_capacity_gb = ( config.multimodal_embedding_cache_capacity_gb = (
multimodal_embedding_cache_capacity_gb multimodal_embedding_cache_capacity_gb
......
...@@ -252,7 +252,9 @@ class WorkerFactory: ...@@ -252,7 +252,9 @@ class WorkerFactory:
) )
shutdown_endpoints[:] = [generate_endpoint] shutdown_endpoints[:] = [generate_endpoint]
handler = EncodeWorkerHandler(config.engine_args) handler = EncodeWorkerHandler(
config.engine_args, config.embedding_transfer_mode
)
await handler.async_init(runtime) await handler.async_init(runtime)
logger.info("Starting to serve the encode worker endpoint...") logger.info("Starting to serve the encode worker endpoint...")
......
...@@ -70,8 +70,11 @@ python -m dynamo.frontend & ...@@ -70,8 +70,11 @@ python -m dynamo.frontend &
EXTRA_ARGS="" EXTRA_ARGS=""
# Embedding transfer: 1 = local file (safetensors), 0 = NIXL RDMA # Embedding transfer:
export TRANSFER_LOCAL=${TRANSFER_LOCAL:-1} # "local" = local file (safetensors),
# "nixl-write" = NIXL WRITE transfer
# "nixl-read" = NIXL READ transfer (default: "local")
export DYN_VLLM_EMBEDDING_TRANSFER_MODE=${DYN_VLLM_EMBEDDING_TRANSFER_MODE:-"local"}
# GPU assignments (override via environment variables) # GPU assignments (override via environment variables)
if [[ "$SINGLE_GPU" == "true" ]]; then if [[ "$SINGLE_GPU" == "true" ]]; then
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment