Unverified Commit 75e774d4 authored by J Wyman's avatar J Wyman Committed by GitHub
Browse files

feat: NIXL Based RDMA Support w/ Multimodal Example (#1060)

parent 9acaa8d1
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
exclude: ^(src/grpc_generated|.*\.patch$)
exclude: ^(src/grpc_generated|.*\.patch$|.*/connect/.*\.py)
repos:
- repo: https://github.com/timothycrosley/isort
rev: 5.12.0
......@@ -82,4 +82,4 @@ repos:
# NOTE: pyright may be able to find other classes of errors not covered above,
# but would require some configuring and venv setup to properly eliminate noise
# and give it visiblity into all the local and third_party packages expected.
\ No newline at end of file
# and give it visiblity into all the local and third_party packages expected.
......@@ -19,10 +19,11 @@ import os
import signal
from typing import Optional
import connect
import torch
from components.disagg_router import PyDisaggregatedRouter
from components.encode_worker import EncodeWorker
from components.prefill_worker import PrefillWorker
from components.encode_worker import VllmEncodeWorker
from components.prefill_worker import VllmPrefillWorker
from transformers import LlavaForConditionalGeneration
from utils.logging import check_required_workers
from utils.nixl import NixlMetadataStore
......@@ -53,11 +54,11 @@ logger = logging.getLogger(__name__)
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmWorker:
class VllmDecodeWorker:
# For disaggregated serving, we need to link the prefill worker to the vllm worker
prefill_worker = depends(PrefillWorker)
prefill_worker = depends(VllmPrefillWorker)
# For aggregated serving, we need to link the encode worker to the vllm worker.
encode_worker = depends(EncodeWorker)
encode_worker = depends(VllmEncodeWorker)
def __init__(self):
self.client = None
......@@ -141,7 +142,11 @@ class VllmWorker:
vision_tower.vision_model.embeddings.position_embedding.num_embeddings
)
else:
enc_comp_ns, enc_comp_name = EncodeWorker.dynamo_address() # type: ignore
EMBEDDINGS_SHAPE = (1, 577, 4096)
EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cuda"
enc_comp_ns, enc_comp_name = VllmEncodeWorker.dynamo_address() # type: ignore
self.encode_worker_client = (
await runtime.namespace(enc_comp_ns)
.component(enc_comp_name)
......@@ -149,9 +154,22 @@ class VllmWorker:
.client()
)
self._connector = connect.Connector(runtime=runtime, namespace=enc_comp_ns)
await self._connector.initialize()
# Create a longer-lived buffer for receiving the image embeddings.
embeddings = torch.empty(
EMBEDDINGS_SHAPE, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
)
descriptor = connect.Descriptor(embeddings)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
descriptor.register_memory(self._connector)
self._embeddings_descriptor = (embeddings, descriptor)
await check_required_workers(self.encode_worker_client, self.min_workers)
self.disaggregated_router = None
logger.info("VllmWorker has been initialized")
logger.info("Initialization complete.")
def shutdown_vllm_engine(self, signum, frame):
"""Shutdown the background loop"""
......@@ -159,7 +177,7 @@ class VllmWorker:
loop = asyncio.get_event_loop()
try:
self.engine_client.close()
logger.info("VllmWorker shutdown complete")
logger.info("Shutdown complete.")
except Exception as e:
logger.error(f"Error during shutdown: {e}")
finally:
......@@ -177,8 +195,18 @@ class VllmWorker:
@endpoint()
async def generate(self, request: vLLMMultimodalRequest):
image_features = None
request_id = request.request_id
image_url = request.image_url
logger.info(
f"Received multimodal request {{ id: {request_id}, image_url: '{image_url}' }}."
)
embeddings = None
if self.do_remote_prefill:
logger.debug(
f"Disaggregated: request {{ id: {request_id}, image_url: '{image_url}' }}"
" prefill worker will populate the decode model's key-value cache ahead of time;"
" no direct encode worker interaction required."
)
if self.disaggregated_router is not None:
async with PrefillQueue.get_instance(
nats_server=self._prefill_queue_nats_server,
......@@ -195,21 +223,21 @@ class VllmWorker:
disagg_router_decision = True
if self.do_remote_prefill and disagg_router_decision:
logger.debug(
f"Prefilling remotely for request {{ id: {request_id}, image_url: '{image_url}' }} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
remote_prefill_params = RemotePrefillParams(
is_remote_prefill=True,
remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
# Pass the image url as part of the RemotePrefillParams, which will be passed to the prefill worker via RemotePrefillRequest
multimodal_data_source={
"image_url": request.image_url,
"image_url": image_url,
},
)
logger.info(
f"Prefilling remotely for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
else:
remote_prefill_params = None
logger.info(
f"Prefilling locally for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}"
logger.debug(
f"Prefilling locally for request {{ id: {request_id}, image_url: '{image_url}' }} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
# The decode worker will pre-allocate the memory based on the prompt token length for the prefill worker to transfer the kv cache.
......@@ -231,33 +259,61 @@ class VllmWorker:
)
else:
# For aggregated serving, the vllm worker will directly send the encode request to the encode worker.
encode_generator = await self.encode_worker_client.round_robin(
EncodeRequest(
image_url=request.image_url,
).model_dump_json()
logger.debug(
f"Aggregated: request {{ id: {request_id}, image_url: '{image_url}' }}"
" no prefill worker available, embeddings directly from encode worker."
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
async for encode_response in encode_generator:
encode_output = EncodeResponse.model_validate_json(
encode_response.data()
# Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
# Doing this avoids unnessesary memory de/registration with NIXL.
embeddings, descriptor = self._embeddings_descriptor
with self._connector.create_writable(descriptor) as writable:
# Extract serialized metadata about the operation from the writable operation,
# and use it to create a new EncodeRequest.
encode_request = EncodeRequest(
request_id=request_id,
image_url=image_url,
serialized_request=writable.to_serialized(),
)
image_features = torch.tensor(
encode_output.image_features, device=device, dtype=torch.float16
logger.debug(f"Encode request: {encode_request.model_dump_json()}")
encode_generator = await self.encode_worker_client.round_robin(
encode_request.model_dump_json()
)
async for encode_response in encode_generator:
encode_output = EncodeResponse.model_validate_json(
encode_response.data()
)
logger.info(
f"Received response: {{ id: {encode_output.request_id} }}"
)
# Wait for the write operation to complete.
# This will block until the write operation is complete.
# This await should be a no-op since we've already received a response from the encode worker.
await writable.wait_for_completion()
# At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker.
remote_prefill_params = None
logger.info(
f"Prefilling locally for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}"
f"Prefilling locally for request {{ id: {request_id}, image_url: '{image_url}' }} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
prompt_ids = request.engine_prompt["prompt_token_ids"]
# rust HTTP requires Delta streaming
request.sampling_params.output_kind = RequestOutputKind.DELTA
if image_features is not None:
multi_modal_data = {"image": image_features}
# When using aggregated serving, the encode worker will have provided the key-value cache updates via the prefill worker.
# When using disaggregated serving, the encode worker will have provided the key-value cache updates via the encode worker.
if embeddings is not None:
logger.debug(
"Aggregated: embedding data from encode worker provided via multi-modal data to decode model."
)
multi_modal_data = {"image": embeddings}
else:
logger.debug(
"Disaggregated: no embedding data required as prefill will have provided key-value cache updates via encode worker."
)
multi_modal_data = None
async for response in self.engine_client.generate(
......@@ -269,6 +325,9 @@ class VllmWorker:
request_id=request.request_id,
remote_prefill_params=remote_prefill_params,
):
logger.debug(
f"Yeilding response {{ id: {response.request_id}, prompt: '{response.prompt}' }}"
)
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
......
......@@ -15,8 +15,10 @@
import logging
from io import BytesIO
from queue import Queue
from typing import AsyncIterator
import connect
import requests
import torch
from PIL import Image
......@@ -24,10 +26,25 @@ from transformers import AutoImageProcessor, LlavaForConditionalGeneration
from utils.protocol import EncodeRequest, EncodeResponse
from utils.vllm import parse_vllm_args
from dynamo.sdk import endpoint, service
from dynamo.sdk import async_on_start, endpoint, service
logger = logging.getLogger(__name__)
try:
import cupy as array_module
if not array_module.cuda.is_available():
raise ImportError("CUDA is not available.")
DEVICE = "cuda"
logger.info("Using cupy for array operations (GPU mode).")
except ImportError as e:
logger.warning(f"Failed to import cupy, falling back to numpy: {e}.")
import numpy as array_module
DEVICE = "cpu"
CACHE_SIZE_MAXIMUM = 8
@service(
dynamo={
......@@ -36,7 +53,7 @@ logger = logging.getLogger(__name__)
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class EncodeWorker:
class VllmEncodeWorker:
def __init__(self) -> None:
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
......@@ -50,9 +67,50 @@ class EncodeWorker:
self.MODEL_ID, device_map="auto", torch_dtype=torch.float16
).eval()
self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: Queue[str] = Queue(maxsize=CACHE_SIZE_MAXIMUM)
@endpoint()
async def encode(self, request: EncodeRequest) -> AsyncIterator[EncodeResponse]:
image = self.open_image(request.image_url)
logger.debug(
f"Received encode request: {{ id: {request.request_id}, image_url: '{request.image_url}' }}."
)
request_id = request.request_id
image_url = request.image_url.lower()
# The following steps encode the requested image and provided useful embeddings.
# 1. Open the image from the provided URL.
# 2. Process the image using the image processor.
# 3. Run the image through the vision model's vision tower.
# 4. Run the results of the vision tower through the multi-modal projector.
# 5. Create a descriptor for the embeddings.
# 6. Create a write operation using the serialized request and the descriptor.
# 7. Await for the write operation to complete.
# 8. Yield the encode response.
# Either retrieve the image from the cache or download it and then cache it.
if request.image_url in self._image_cache:
image = self._image_cache[image_url]
logger.debug(
f"Image found in cache for request: {{ id: {request_id}, image_url: '{image_url}' }}."
)
else:
image = self.open_image(image_url)
logger.debug(
f"Downloading/opening image for request: {{ id: {request_id}, image_url: '{image_url}' }}."
)
# Cache the image for future use, and evict the oldest image if the cache is full.
if self._cache_queue.full():
oldest_image_url = self._cache_queue.get()
del self._image_cache[oldest_image_url]
self._image_cache[request.image_url] = image
self._cache_queue.put(request.image_url)
logger.debug(
f"Processing image for request: {{ id: {request_id}, image_url: '{image_url}' }}"
)
image_embeds = self.image_processor(images=image, return_tensors="pt")
with torch.no_grad():
......@@ -60,22 +118,56 @@ class EncodeWorker:
vision_outputs = self.vision_model.vision_tower(
image_embeds["pixel_values"].to(self.vision_model.device)
)
logger.debug("Vision model completed.")
embeddings = vision_outputs.last_hidden_state
embeddings = self.vision_model.multi_modal_projector(embeddings)
logger.debug(
f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
)
if request.serialized_request is None:
logger.error(
f"Request serialized_request is None for request: {{ id: {request_id}, image_url: '{image_url}' }}."
)
# Create a descriptor for the embeddings, this will register the memory with the connector (and the NIXL runtime).
descriptor = connect.Descriptor(embeddings)
# Create a write operation using the serialized request and the descriptor.
# This will begin the RDMA transfer of the embeddings to the remote worker.
write_op = await self._connector.begin_write(
descriptor,
request.serialized_request,
)
# Await for the write operation to complete.
# This will block until the data has been written to the remote worker or an error occurs.
await write_op.wait_for_completion()
image_features = vision_outputs.last_hidden_state
image_features = self.vision_model.multi_modal_projector(image_features)
yield EncodeResponse(
image_features=image_features.tolist()
request_id=request.request_id,
).model_dump_json()
@async_on_start()
async def on_start(self):
logger.info("Startup started.")
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Startup completed.")
def open_image(self, image: str) -> Image.Image:
# TODO: Have a seperate field for url and non url - and avoid auto detection
try:
# Acquire the image and convert it to the format (RGB) the image processor model expects.
if image.startswith("http") or image.startswith("https"):
response = requests.get(image)
image_data = Image.open(BytesIO(response.content)).convert("RGB")
else:
image_data = Image.open(image).convert("RGB")
return image_data
except Exception as e:
logger.error(f"Error opening image: {e}")
raise e
return image_data
......@@ -20,8 +20,9 @@ import os
import signal
import sys
import connect
import torch
from components.encode_worker import EncodeWorker
from components.encode_worker import VllmEncodeWorker
from pydantic import BaseModel
from utils.logging import check_required_workers
from utils.nixl import NixlMetadataStore
......@@ -38,6 +39,11 @@ from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, servic
logger = logging.getLogger(__name__)
# Constants for the shape and dtype of the embeddings tensor.
EMBEDDINGS_SHAPE = (1, 577, 4096)
EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cuda"
class RequestType(BaseModel):
text: str
......@@ -50,8 +56,8 @@ class RequestType(BaseModel):
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class PrefillWorker:
encode_worker = depends(EncodeWorker)
class VllmPrefillWorker:
encode_worker = depends(VllmEncodeWorker)
def __init__(self):
class_name = self.__class__.__name__
......@@ -95,7 +101,7 @@ class PrefillWorker:
raise RuntimeError("Failed to initialize engine client")
runtime = dynamo_context["runtime"]
enc_comp_ns, enc_comp_name = EncodeWorker.dynamo_address() # type: ignore
enc_comp_ns, enc_comp_name = VllmEncodeWorker.dynamo_address() # type: ignore
self.encode_worker_client = (
await runtime.namespace(enc_comp_ns)
.component(enc_comp_name)
......@@ -103,6 +109,20 @@ class PrefillWorker:
.client()
)
self._connector = connect.Connector(runtime=runtime, namespace=enc_comp_ns)
await self._connector.initialize()
# Create a longer-lived buffer for receiving the image embeddings.
embeddings = torch.empty(
EMBEDDINGS_SHAPE,
dtype=EMBEDDINGS_DTYPE,
device=EMBEDDINGS_DEVICE,
)
descriptor = connect.Descriptor(embeddings)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
descriptor.register_memory(self._connector)
self._embeddings_descriptor = (embeddings, descriptor)
await check_required_workers(self.encode_worker_client, self.min_workers)
metadata = self.engine_client.nixl_metadata
......@@ -119,19 +139,19 @@ class PrefillWorker:
sys.exit(1)
task.add_done_callback(prefill_queue_handler_cb)
logger.info("PrefillWorker initialized")
logger.info("Initialization complete.")
def shutdown_vllm_engine(self, signum, frame):
"""Shutdown the background loop"""
logger.info(f"Received signal {signum}, shutting down")
logger.info(f"Shutdown started, signal {signum} received.")
loop = asyncio.get_event_loop()
try:
self.engine_client.close()
logger.info("PrefillWorker shutdown complete")
except Exception as e:
logger.error(f"Error during shutdown: {e}")
finally:
loop.stop()
logger.info("Shutdown complete.")
async def prefill_queue_handler(self):
logger.info("Prefill queue handler entered")
......@@ -166,62 +186,88 @@ class PrefillWorker:
if request.multimodal_data_source["image_url"] is None:
raise ValueError("No image url provided for prefill request")
encode_generator = await self.encode_worker_client.round_robin(
EncodeRequest(
image_url=request.multimodal_data_source["image_url"],
).model_dump_json()
request_id = request.request_id
engine_id = request.engine_id
image_url = request.multimodal_data_source["image_url"]
logger.info(
f"Received prefill request {{ id: {request_id}, engine_id: {engine_id}, image_url: '{image_url}' }}."
)
async for encode_response in encode_generator:
encode_output = EncodeResponse.model_validate_json(encode_response.data())
image_features = torch.tensor(
encode_output.image_features, device="cpu", dtype=torch.float16
# Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
# Doing this avoids unnessesary memory de/registration with NIXL.
embeddings, descriptor = self._embeddings_descriptor
# Create a new writable operation from the descriptor.
with self._connector.create_writable(descriptor) as writable:
# Extract serialized metadata about the operation from the writable operation,
# and use it to create a new EncodeRequest.
encode_generator = await self.encode_worker_client.round_robin(
EncodeRequest(
request_id=request_id,
image_url=image_url,
serialized_request=writable.to_serialized(),
).model_dump_json()
)
async for encode_response in encode_generator:
encode_output = EncodeResponse.model_validate_json(
encode_response.data(),
)
logger.debug(
f"Received response: {{ id: {encode_output.request_id} }}."
)
sampling_params = request.sampling_params
sampling_params.max_tokens = 1
sampling_params.min_tokens = 1
# Wait for the write operation to complete.
# This will block until the write operation is complete.
# This await should be a no-op since we've already received a response from the encode worker.
await writable.wait_for_completion()
# At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker.
remote_prefill_params = RemotePrefillParams(
is_remote_decode=True,
decode_block_ids=request.block_ids,
decode_engine_id=request.engine_id,
decode_computed_block_ids=request.computed_block_ids,
)
sampling_params = request.sampling_params
sampling_params.max_tokens = 1
sampling_params.min_tokens = 1
# TODO check if metadata has changed
# and reload - currently only loading once
if request.engine_id not in self._loaded_metadata:
remote_metadata = await self._metadata_store.get(request.engine_id)
await self.engine_client.add_remote_nixl_metadata(remote_metadata)
logger.info(
f"Loaded nixl metadata from engine {request.engine_id} into "
f"engine {self.engine_client.nixl_metadata.engine_id}"
remote_prefill_params = RemotePrefillParams(
is_remote_decode=True,
decode_block_ids=request.block_ids,
decode_engine_id=engine_id,
decode_computed_block_ids=request.computed_block_ids,
)
# TODO check if metadata has changed
# and reload - currently only loading once
if engine_id not in self._loaded_metadata:
remote_metadata = await self._metadata_store.get(request.engine_id)
await self.engine_client.add_remote_nixl_metadata(remote_metadata)
logger.info(
f"Loaded nixl metadata from engine {engine_id} into "
f"engine {self.engine_client.nixl_metadata.engine_id}"
)
self._loaded_metadata.add(engine_id)
# To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
# some placeholder dummy tokens were inserted based on the embedding size in the worker.py.
# The structure of the prompt is "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:", need to remove the dummy tokens after the image token.
IMAGE_TOKEN_ID = 32000
embedding_size = embeddings.shape[1]
padding_size = embedding_size - 1
image_token_index = request.prompt_token_ids.index(IMAGE_TOKEN_ID)
dummy_token_index = image_token_index + 1
prompt_token_ids = (
request.prompt_token_ids[:dummy_token_index]
+ request.prompt_token_ids[dummy_token_index + padding_size :]
)
self._loaded_metadata.add(request.engine_id)
# To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
# some placeholder dummy tokens were inserted based on the embedding size in the worker.py.
# The structure of the prompt is "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:", need to remove the dummy tokens after the image token.
IMAGE_TOKEN_ID = 32000
embedding_size = image_features.shape[1]
padding_size = embedding_size - 1
image_token_index = request.prompt_token_ids.index(IMAGE_TOKEN_ID)
dummy_token_index = image_token_index + 1
prompt_token_ids = (
request.prompt_token_ids[:dummy_token_index]
+ request.prompt_token_ids[dummy_token_index + padding_size :]
)
async for _ in self.engine_client.generate(
request_id=request.request_id,
prompt=TokensPrompt(
prompt_token_ids=prompt_token_ids,
multi_modal_data={"image": image_features},
),
sampling_params=sampling_params,
remote_prefill_params=remote_prefill_params,
):
yield
async for _ in self.engine_client.generate(
request_id=request_id,
prompt=TokensPrompt(
prompt_token_ids=prompt_token_ids,
multi_modal_data={"image": embeddings},
),
sampling_params=sampling_params,
remote_prefill_params=remote_prefill_params,
):
yield
@endpoint()
async def mock(self, req: RequestType):
......
......@@ -19,7 +19,7 @@ import uuid
from enum import Enum
from typing import AsyncIterator, Tuple, Union
from components.worker import VllmWorker
from components.decode_worker import VllmDecodeWorker
from transformers import AutoTokenizer
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.logging import check_required_workers
......@@ -53,7 +53,7 @@ class Processor(ProcessMixIn):
vLLM pre and post processing
"""
worker = depends(VllmWorker)
worker = depends(VllmDecodeWorker)
def __init__(self):
class_name = self.__class__.__name__
......@@ -83,7 +83,7 @@ class Processor(ProcessMixIn):
@async_on_start
async def async_init(self):
runtime = dynamo_context["runtime"]
comp_ns, comp_name = VllmWorker.dynamo_address() # type: ignore
comp_ns, comp_name = VllmDecodeWorker.dynamo_address() # type: ignore
self.worker_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
......
......@@ -21,7 +21,7 @@ Processor:
router: round-robin
common-configs: [model, block-size, max-model-len]
VllmWorker:
VllmDecodeWorker:
enforce-eager: true
max-num-batched-tokens: 16384
enable-prefix-caching: true
......@@ -33,7 +33,7 @@ VllmWorker:
gpu: 1
common-configs: [model, block-size, max-model-len]
EncodeWorker:
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
......
......@@ -22,7 +22,7 @@ Processor:
router: round-robin
common-configs: [model, block-size]
VllmWorker:
VllmDecodeWorker:
remote-prefill: true
conditional-disagg: true
max-local-prefill-length: 10
......@@ -33,7 +33,7 @@ VllmWorker:
gpu: 1
common-configs: [model, block-size, max-model-len, kv-transfer-config]
PrefillWorker:
VllmPrefillWorker:
max-num-batched-tokens: 16384
ServiceArgs:
workers: 1
......@@ -41,7 +41,7 @@ PrefillWorker:
gpu: 1
common-configs: [model, block-size, max-model-len, kv-transfer-config]
EncodeWorker:
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import asyncio
import logging
import socket
import uuid
import zlib
from abc import ABC, abstractmethod
from enum import IntEnum
from functools import cached_property
from typing import Any, List, Optional
import nixl._api as nixl_api
import nixl._bindings as nixl_bindings
import torch
from pydantic import BaseModel, ConfigDict, field_validator
from dynamo.runtime import DistributedRuntime
from dynamo.sdk import dynamo_context
logger = logging.getLogger(__name__)
try:
import cupy as array_module
from cupy_backends.cuda.api.runtime import CUDARuntimeError
logger.info("Utilizing cupy to enable GPU acceleration.")
except ImportError:
try:
import numpy as array_module
logger.warning("Failed to load cupy for GPU acceleration, utilizing numpy to provide CPU based operations.")
except ImportError as e:
raise ImportError("Numpy or cupy must be installed to use this module.") from e
class AbstractOperation(ABC):
"""
Abstract base class for awaitable NIXL based RDMA operations.
"""
def __init__(
self,
connector: Connector,
operation_kind: OperationKind,
local_descriptors: Descriptor | list[Descriptor],
remote_descriptors: Optional[Descriptor | list[Descriptor]],
notification_key: Optional[str],
) -> None:
if not isinstance(connector, Connector):
raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
if operation_kind is not OperationKind.READ and operation_kind is not OperationKind.WRITE:
raise ValueError("Argument `operation_kind` must be either `READ` or `WRITE`.")
if not (
isinstance(local_descriptors, (Descriptor, list))
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
if (
remote_descriptors is not None
and not (
isinstance(remote_descriptors, Descriptor)
or (isinstance(remote_descriptors, list) and all(isinstance(d, Descriptor) for d in remote_descriptors))
)
):
raise TypeError("Argument `remote_descriptors` must be dynamo.connect.Descriptor`, `list[dynamo.connect.Descriptor]`, or `None`.")
if isinstance(local_descriptors, list) and len(local_descriptors) == 0:
raise ValueError("Argument `local_descriptors` must not be an empty list.")
if (
remote_descriptors is not None
and isinstance(remote_descriptors, list)
and len(remote_descriptors) == 0
):
raise ValueError("Argument `remote_descriptors` must not be an empty list.")
notification_key = str(uuid.uuid4()) if notification_key is None else notification_key
if not isinstance(notification_key, str):
raise TypeError("Argument `notification_key` must be `str` or `None`.")
if len(notification_key) == 0:
raise ValueError("Argument `notification_key` must not be an empty string.")
self._notification_key: str = "" if notification_key is None else notification_key
self._connector: Connector = connector
self._operation_kind: OperationKind = operation_kind
self._local_descriptors: Descriptor | list[Descriptor] = local_descriptors
self._local_dlist: Optional[list[tuple[int, int, int]]] = None
self._local_memtype: DeviceKind = DeviceKind.UNSPECIFIED
self._remote_descriptors: Optional[Descriptor | list[Descriptor]] = None if remote_descriptors is None else remote_descriptors
self._remote_dlist: Optional[list[tuple[int, int, int]]] = None
self._remote_memtype: DeviceKind = DeviceKind.UNSPECIFIED
# Register local descriptors with NIXL.
# Note: Only local descriptors should be registered with NIXL,
if isinstance(local_descriptors, list):
for d in local_descriptors:
d.register_memory(self._connector)
else:
local_descriptors.register_memory(self._connector)
# Record local descriptors.
memtype, dtlist = self._create_dlist(local_descriptors)
self._local_dlist = dtlist
self._local_memtype = memtype
# Record remote descriptors when provided.
if remote_descriptors is not None:
memtype, dtlist = self._create_dlist(remote_descriptors)
self._remote_dlist = dtlist
self._remote_memtype = memtype
def __del__(self) -> None:
self._release()
def __enter__(self) -> AbstractOperation:
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self._release()
def _release(self) -> None:
"""
Private method to release resources. Only to be called by `self`.
"""
pass
@property
def connector(self) -> Connector:
"""
Gets the local associated with this operation.
"""
return self._connector
@property
def operation_kind(self) -> OperationKind:
"""
Gets the kind of operation.
"""
return self._operation_kind
@abstractmethod
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
raise NotImplementedError("Abstract method not implemented by derived class.")
# Private Methods
def _create_dlist(
self,
descriptors: Descriptor | list[Descriptor],
) -> tuple[DeviceKind, list[tuple[int, int, int]]]:
"""
Helper function to create a list of tuples (ptr, size, device) from descriptors.
"""
dlist: list[tuple[int, int, int]] = []
memtype: DeviceKind = DeviceKind.UNSPECIFIED
if isinstance(descriptors, list):
memtype = descriptors[0].device.kind
for desc in descriptors:
if memtype != desc.device.kind:
raise ValueError("All local descriptors must have the same memory type.")
dlist.append((desc.ptr, desc.size, desc.device.id))
else:
memtype = descriptors.device.kind
dlist.append((descriptors.ptr, descriptors.size, descriptors.device.id))
return (memtype, dlist)
class ActiveOperation(AbstractOperation):
"""
Abstract class for active operations that initiates a NIXL based RDMA transfer based `SerializedRequest`
provided by the remote worker's corresponding `PassiveOperation`.
"""
def __init__(
self,
remote: Remote,
operation_kind: OperationKind,
local_descriptors: Descriptor | list[Descriptor],
remote_descriptors: Descriptor | list[Descriptor],
notification_key: str,
) -> None:
if not isinstance(remote, Remote) or remote._connector is None:
raise TypeError("Argument `remote` must be valid `dynamo.connect.RemoteAgent`.")
if not isinstance(operation_kind, OperationKind):
raise TypeError("Argument `operation_kind` must `dynamo.connect.OperationKind`.")
if operation_kind is not OperationKind.READ and operation_kind is not OperationKind.WRITE:
raise ValueError("Argument `operation_kind` must be either `READ` or `WRITE`.")
if not (
isinstance(local_descriptors, Descriptor)
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
if not (
isinstance(remote_descriptors, Descriptor)
or (isinstance(remote_descriptors, list) and all(isinstance(d, Descriptor) for d in remote_descriptors))
):
raise TypeError("Argument `remote_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
# Unpack single descriptors from lists if they are provided as single descriptors.
if isinstance(local_descriptors, list) and len(local_descriptors) == 1:
local_descriptors = local_descriptors[0]
if isinstance(remote_descriptors, list) and len(remote_descriptors) == 1:
remote_descriptors = remote_descriptors[0]
if (isinstance(local_descriptors, list) and isinstance(remote_descriptors, list) and len(local_descriptors) != len(remote_descriptors)):
raise ValueError("When `local_descriptors` and `remote_descriptors` are lists, they must have the same length.")
elif isinstance(local_descriptors, list) != isinstance(remote_descriptors, list):
raise ValueError("Both `local_descriptors` and `remote_descriptors` must be either lists or single descriptors.")
if not isinstance(notification_key, str):
raise TypeError("Argument `notification_key` must be `str`.")
if len(notification_key) == 0:
raise ValueError("Argument `notification_key` must not be an empty string.")
self._remote = remote
self._status = OperationStatus.UNINTIALIZED
super().__init__(remote.connector, operation_kind, local_descriptors, remote_descriptors, notification_key)
# Quick check to ensure remote descriptors are not None to make static analysis happy.
if self._local_dlist is None or self._remote_dlist is None:
raise RuntimeError("NIXL descriptor list(s) not bound to operation.")
self._local_xfer_descs: Optional[nixl_bindings.nixlXferDList] = None
self._remote_xfer_descs: Optional[nixl_bindings.nixlXferDList] = None
self._xfer_hndl: Optional[nixl_api.nixl_xfer_handle] = None
self._local_xfer_descs = self._connector._nixl.get_xfer_descs(
descs=self._local_dlist,
mem_type=str(self._local_memtype),
)
logger.debug(f"Created local NIXL xfer descs: {self._local_xfer_descs}")
self._remote_xfer_descs = self._connector._nixl.get_xfer_descs(
descs=self._remote_dlist,
mem_type=str(self._remote_memtype),
)
logger.debug(f"Created remote NIXL xfer descs: {self._remote_xfer_descs}")
self._xfer_hndl = self._connector._nixl.initialize_xfer(
operation=str(operation_kind),
local_descs=self._local_xfer_descs,
remote_descs=self._remote_xfer_descs,
remote_agent=self._remote.name,
notif_msg=self._notification_key.encode("utf-8"),
)
logger.debug(f"Created NIXL transfer handle: {self._xfer_hndl}")
def __del__(self) -> None:
super().__del__()
self._release()
def __enter__(self) -> ActiveOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
match self.status:
case OperationStatus.IN_PROGRESS | OperationStatus.INITIALIZED:
self._status = OperationStatus.CANCELLED
self._release()
def __repr__(self) -> str:
return str(
f"{self.__class__.__name__}("
f"operation_kind={self._operation_kind}, "
f"local_descriptors={self._local_descriptors}, "
f"remote_descriptors={self._remote_descriptors}, "
f"notification_key='{self._notification_key}', "
f"remote='{self._remote.name}', "
f"status='{self._status}'"
f")"
)
def _release(self) -> None:
"""
Private method to release resources.
"""
error: Optional[Exception] = None
if self._xfer_hndl is not None:
try:
logger.debug(f"NIXL transfer handle {self._xfer_hndl} released.")
self._connector._nixl.release_xfer_handle(self._xfer_hndl)
except Exception as e:
logger.error(f"Failed to release resources: {e}")
error = e
finally:
self._xfer_hndl = None
try:
super()._release()
except Exception as e:
logger.error(f"Failed to release WaitableOperation resources: {e}")
if error is not None:
e.__cause__ = error
error = e
if error is not None:
raise error
def _cancel_(self) -> None:
if self._xfer_hndl is None:
return
if self.status == OperationStatus.ERRORED:
raise RuntimeError("Operation is errored, unable to cancel the operation.")
logger.info(f"Cancellation requested for operation {{ kind={self._operation_kind}, remote='{self._remote.name}', status={self._status} }}.")
# NIXL will cancel the transfer if it is in progress when the handle is released.
self._connector._nixl.release_xfer_handle(self._xfer_hndl)
self._status = OperationStatus.CANCELLED
self._xfer_hndl = None
async def _wait_for_completion_(self) -> None:
# Loop until the operation is no longer in progress (or "initalized"),
# yielding control to the event loop to allow other operations to run.
iteration_count = 0
while True:
if iteration_count & 10 == 0:
logger.debug(f"Waiting for operation {{ kind={self._operation_kind}, remote='{self._remote.name}', duration={iteration_count / 10}s }}.")
match self.status:
# "in progress" or "initialized" means the operation is ongoing.
case OperationStatus.INITIALIZED:
await asyncio.sleep(0.1)
case OperationStatus.IN_PROGRESS:
await asyncio.sleep(0.1)
# Any other state indicates completion or error.
case _:
return
@abstractmethod
def cancel(self) -> None:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or has been cancelled.
"""
raise NotImplementedError("Abstract method not implemented by derived class.")
@property
def remote(self) -> Remote:
"""
Gets the remote agent associated with this operation.
"""
return self._remote
@property
def status(self) -> OperationStatus:
"""
Gets the status of the operation.
"""
# Early return if the operation is already complete, errored, or cancelled.
match self._status:
case OperationStatus.COMPLETE | OperationStatus.ERRORED | OperationStatus.CANCELLED:
return self._status
if self._xfer_hndl is None:
raise RuntimeError("NIXL transfer handle is invalid.")
old_status = self._status
if self._status == OperationStatus.UNINTIALIZED:
state = self._connector._nixl.transfer(self._xfer_hndl, self._notification_key.encode("utf-8"))
logger.debug(f"NIXL reported transfer state: {state}")
if state == "ERR":
self._status = OperationStatus.ERRORED
elif state == "DONE":
self._status = OperationStatus.COMPLETE
else:
self._status = OperationStatus.INITIALIZED
else:
state = self._connector._nixl.check_xfer_state(self._xfer_hndl)
logger.debug(f"NIXL reported transfer state: {state}")
if state == "ERR":
self._status = OperationStatus.ERRORED
elif state == "DONE":
self._status = OperationStatus.COMPLETE
else:
self._status = OperationStatus.IN_PROGRESS
if self._status != old_status:
logger.debug(f"{self.__class__.__name__} {{ remote: '{self._remote.name}' status: '{old_status}' => '{self._status}' }}.")
return self._status
class Connector:
"""
Core class for managing the connection between agents in a distributed environment.
Use this class to create readable and writable operations, or read and write data to remote agents.
"""
def __init__(
self,
namespace: Optional[str] = None,
runtime: Optional[DistributedRuntime] = None,
worker_id: Optional[str] = None,
) -> None:
"""
Creates a new Connector instance.
Parameters
----------
namespace : Optional[str], optional
Dynamo namespace of the component, defaults to "dynamo" when `None`.
runtime : Optional[DistributedRuntime], optional
Reference the dynamo runtime used by the compenent, attempts to use the current runtime when `None`.
worker_id : Optional[str], optional
Unique identifier of the worker, defaults to a new UUID when `None`.
Raises
------
TypeError
When `namespace` is provied and not of type 'str'.
TypeError
When `runtime` iis provied and not of type `dynamo.runtime.DistributedRuntime`.
TypeError
When `worker_id` is provied and not of type `uuid.UUID`.
"""
namespace = "dynamo" if namespace is None else namespace
if not isinstance(namespace, str):
raise TypeError("Argument `namespace` must be `str` or `None`.")
if dynamo_context is not None and "runtime" in dynamo_context:
runtime = dynamo_context["runtime"] if runtime is None else runtime
if not isinstance(runtime, DistributedRuntime) or runtime is None:
raise TypeError("Argument `runtime` must be `dynamo.runtime.DistributedRuntime` or `None`.")
worker_id = worker_id if worker_id is not None else str(uuid.uuid4())
if not isinstance(worker_id, str) or len(worker_id) == 0:
raise TypeError("Argument `worker_id` must be a non-empty `str` or `None`.")
self._worker_id = worker_id
self._is_initialized = False
self._runtime = runtime
self._namespace = namespace
self._nixl = nixl_api.nixl_agent(self._worker_id)
self._hostname = socket.gethostname()
self._agent_metadata: Optional[bytes] = None
logger.debug(f"Created {self.__repr__()}.")
def __repr__(self) -> str:
return str(
f"{self.__class__.__name__}("
f"worker_id='{self._worker_id}', "
f"namespace={self._namespace}, "
f"hostname={self._hostname}, "
f"metadata=<{0 if self._agent_metadata is None else len(self._agent_metadata)} bytes>"
")"
)
def __str__(self) -> str:
return self._worker_id
@cached_property
def is_cuda_available(self) -> bool:
# Note: cuda.is_avalailable initializes cuda
# and can't be called when forking subprocesses
# care should be taken to only call it within
# subprocesses or use 'spawn'
try:
return array_module.cuda is not None and array_module.cuda.is_available()
except CUDARuntimeError:
return False
@property
def metadata(self) -> bytes:
"""
Get the metadata of the agent.
"""
return self._nixl.get_agent_metadata()
@property
def name(self) -> str | None:
"""
Get the name of the agent.
"""
return self._worker_id
@property
def namespace(self) -> str:
"""
Get the namespace of the local.
"""
return self._namespace
@property
def runtime(self) -> DistributedRuntime:
"""
Get the runtime of the local.
"""
if self._runtime is None:
raise RuntimeError("Runtime is not set. This Connector was not initialized with a runtime.")
return self._runtime
async def begin_read(
self,
remote_request: SerializedRequest,
local_descriptors: Descriptor | list[Descriptor],
) -> ReadOperation:
"""
Creates a read operation for fulfilling a remote readable operation.
Parameters
----------
remote_request : SerializedRequest
Serialized request from a remote worker that has created a readable operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to receive data from the remote worker described by `remote_request`.
Returns
-------
ReadOperation
Awaitable read operation that can be used to transfer data from a remote agent.
Raises
------
TypeError
When `remote_request` is not of type `SerializedRequest`.
TypeError
When `local_descriptors` is not of type `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
if remote_request is None or not isinstance(remote_request, SerializedRequest):
raise TypeError("Argument `remote_request` must be `SerializedRequest`.")
if not (
isinstance(local_descriptors, Descriptor)
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
if remote_request.operation_kind != OperationKind.READ.value:
raise RuntimeError("Cannot create a `dynamo.connect.ReadOperation` to read from a remote `dynamo.connect.WritableOperation`.")
if not self._is_initialized:
raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
op = ReadOperation(self, remote_request, local_descriptors)
return op
async def begin_write(
self,
local_descriptors: Descriptor | list[Descriptor],
remote_request: SerializedRequest,
) -> WriteOperation:
"""
Creates a write operation for transferring data to a remote agent.
Parameters
----------
remote_request : SerializedRequest
Serialized request from a remote worker that has created a readable operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptors of one or more data objects to be transferred to the remote agent.
"""
if remote_request is None or not isinstance(remote_request, SerializedRequest):
raise TypeError("Argument `remote_request` must be `SerializedRequest`.")
if not (
isinstance(local_descriptors, Descriptor)
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `Descriptor` or `list[Descriptor]`.")
if remote_request.operation_kind != OperationKind.WRITE:
raise RuntimeError("Cannot create a `WriteOperation` to write to a remote `ReadableOperation`.")
if not isinstance(remote_request.nixl_metadata, str):
raise TypeError("Argument `remote_request.nixl_metadata` must be `str`.")
if not self._is_initialized:
raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
op = WriteOperation(self, local_descriptors, remote_request)
return op
def create_readable(
self,
local_descriptors: Descriptor | list[Descriptor],
) -> ReadableOperation:
"""
Creates a readable operation for transferring data from a remote agent.
Returns
-------
ReadableOperation
A readable operation that can be used to transfer data from a remote agent.
"""
if not self._is_initialized:
raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
op = ReadableOperation(self, local_descriptors)
return op
def create_writable(
self,
local_descriptors: Descriptor | list[Descriptor],
) -> WritableOperation:
"""
Creates a writable operation for transferring data to a remote agent.
Returns
-------
WritableOperation
A writable operation that can be used to transfer data to a remote agent.
"""
if not self._is_initialized:
raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
op = WritableOperation(self, local_descriptors)
return op
async def initialize(self) -> None:
# Only initialize the connector once.
if self._is_initialized:
return
self._is_initialized = True
# This method is a no-op for now, in the future it may be used to initialize the connector.
logger.debug(f"Initialized Connector {{ name: '{self._worker_id}', namespace '{self._namespace}' }} completed.")
class Descriptor:
"""
Memory descriptor that ensures memory is registered w/ NIXL, used for transferring data between workers.
"""
def __init__(
self,
data: torch.Tensor | tuple[array_module.ndarray, Device|str] | bytes | tuple[int, int, Device|str, Any],
) -> None:
"""
Memory descriptor for transferring data between agents.
Parameters
----------
data : torch.Tensor | tuple[ndarray, Device|str] | bytes | tuple[int, int, Device|str, Any]
The data to be transferred.
When `torch.Tensor` is provided, the attributes of the tensor will be used to create the descriptor.
When `tuple[ndarray, Device]` is provided, the tuple must contain:
- `ndarray`: The CuPy or NumPy array to be transferred.
- `Device`: Either a `dynamo.connect.Device` or a string representing the device type (e.g., "cuda" or "cpu").
When `bytes` is provided, the pointer and size derived from the bytes object and memory type will be assumed to be CPU.
When `tuple[int, int, Device|str, Any]` is provided, the tuple must contain the following elements:
- `int`: Pointer to the data in memory.
- `int`: Size of the data in bytes.
- `Device`: Either a `dynamo.connect.Device` or a string representing the device type (e.g., "cuda" or "cpu").
- `Any`: Optional reference to the data (e.g., the original tensor or bytes object).
This is useful for keeping a reference to the data in memory, but it is not required.
Raises
------
ValueError
When `data` is `None`.
TypeError
When `data` is not a valid type (i.e., not `torch.Tensor`, `bytes`, or a valid tuple).
TypeError
When `data` is a tuple but the elements are not of the expected types (i.e., [`ndarray`, `Device|str`] OR [`int`, `int`, `Device|str`, `Any`]).
"""
TYPE_ERROR_MESSAGE = "Argument `data` must be `torch.Tensor`, `tuple[ndarray, Device|str]`, `bytes`, or `tuple[int, int, Device|str, Any]`."
if data is None:
raise ValueError("Argument `data` cannot be `None`.")
if not (isinstance(data, torch.Tensor) or isinstance(data, bytes) or isinstance(data, tuple)):
raise TypeError(TYPE_ERROR_MESSAGE)
self._data_device: Device = Device("cpu")
self._data_ptr: int = 0
self._data_ref: Optional[Any] = None
self._data_size: int = 0
# Member fields for managing NIXL memory registration.
# Note: ONLY local descriptors should be registered with NIXL,
# remote descriptors do not have a valid memory address and registration will fault.
self._connector: Optional[Connector] = None
self._nixl_hndl: Optional[nixl_bindings.nixlRegDList] = None
# Initially `None` cached serialized descriptor reference, populated when `to_serialized()` is called.
self._serialized: Optional[SerializedDescriptor] = None
# Data is `torch.Tensor`.
if isinstance(data, torch.Tensor):
self._data_ptr = data.data_ptr()
self._data_size = data.numel() * data.element_size()
if data.is_cuda:
self._data_device = Device((DeviceKind.CUDA, data.get_device()))
self._data_ref = data
logger.debug(f"Created {self.__repr__()} from `torch.Tensor`.")
# Data is `tuple[ndarray, Device]`.
elif (
isinstance(data, tuple)
and len(data) == 2
and isinstance(data[0], array_module.ndarray)
and (isinstance(data[1], Device) or isinstance(data[1], str))
):
if hasattr(data[0], "__array_interface__"):
self._data_ptr = data[0].__array_interface__["data"][0]
elif hasattr(data[0], "__cuda_array_interface__"):
self._data_ptr = data[0].__cuda_array_interface__["data"][0]
else:
raise TypeError("Argument `data[0]` must be a `ndarray` with a valid array interface.")
self._data_size = data[0].nbytes
self._data_device = data[1] if isinstance(data[1], Device) else Device(data[1])
self._data_ref = data[0]
logger.debug(f"Created {self.__repr__()} from `tuple[ndarray, Device|str]`.")
# Data is `bytes`.
elif isinstance(data, bytes):
self._data_ptr = id(data)
self._data_size = len(data)
self._data_ref = data
logger.debug(f"Created {self.__repr__()} from `bytes`.")
# Data is `tuple[int, int, Device, dtype, tuple, Any]`.
elif isinstance(data, tuple) and len(data) >= 2 and isinstance(data[0], int) and isinstance(data[1], int):
if len(data) >= 3 and not (isinstance(data[2], Device) or isinstance(data[2], str)):
raise TypeError("Argument `data` must be a `tuple[int, int, Device|str, Any]`.")
self._data_ptr = data[0]
self._data_size = data[1]
if len(data) >= 3:
self._data_device = data[2] if isinstance(data[2], Device) else Device(data[2])
self._data_ref = data[3] if len(data) >=4 else None
logger.debug(f"Created {self.__repr__()} from `tuple[int, int, Device|str, Any]`.")
else:
raise TypeError(TYPE_ERROR_MESSAGE)
def __del__(self) -> None:
if self._nixl_hndl is not None and self._connector is not None:
# Unregister the memory with NIXL.
self._connector._nixl.deregister_memory(self._nixl_hndl)
self._nixl_hndl = None
if self._data_ref is not None:
# Release the reference to the data.
del self._data_ref
logger.debug(f"Deleted {self.__repr__()}.")
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self})"
def __str__(self) -> str:
return f"ptr={hex(self._data_ptr)}, size={self._data_size}, device={self._data_device}"
@property
def device(self) -> Device:
"""
Gets the device the of the descriptor.
"""
return self._data_device
@property
def ptr(self) -> int:
"""
Gets the pointer of the descriptor.
"""
return self._data_ptr
@property
def size(self) -> int:
"""
Gets the size of the descriptor.
"""
return self._data_size
@staticmethod
def from_serialized(
serialized: SerializedDescriptor,
) -> Descriptor:
"""
Deserializes a `SerializedDescriptor` into a `Descriptor` object.
Parameters
----------
serialized : SerializedDescriptor
The serialized descriptor to deserialize.
Returns
-------
Descriptor
The deserialized descriptor.
"""
if not isinstance(serialized, SerializedDescriptor):
raise TypeError("Argument `serialized` must be `SerializedDescriptor`.")
return serialized.to_descriptor()
def register_memory(
self,
connector: Connector,
) -> None:
"""
Registers the memory of the descriptor with NIXL.
"""
if not isinstance(connector, Connector):
raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
if self._data_ptr == 0:
raise ValueError("Cannot register memory with a null pointer.")
if not (self._nixl_hndl is None and self._connector is None):
return
# Register the memory with NIXL.
self._connector = connector
if isinstance(self._data_ref, torch.Tensor):
self._nixl_hndl = connector._nixl.register_memory(self._data_ref)
else:
mem_type = str(self._data_device.kind)
reg_list = [(self._data_ptr, self._data_size, self._data_device.id, mem_type)]
self._nixl_hndl = connector._nixl.register_memory(reg_list, mem_type)
logger.debug(f"Registered {self.__repr__()} with NIXL.")
def to_serialized(self) -> SerializedDescriptor:
"""
Serializes the descriptor into a `SerializedDescriptor` object.
"""
if self._serialized is None:
self._serialized = SerializedDescriptor(
device=f"{self._data_device}",
ptr=self._data_ptr,
size=self._data_size,
)
return self._serialized
class Device:
"""
Represents a device in the system.
"""
def __init__(
self,
metadata: str | tuple[DeviceKind, int],
) -> None:
if metadata is None:
raise ValueError("Argument `metadata` cannot be `None`.")
if isinstance(metadata, tuple) and len(metadata) == 2 and isinstance(metadata[0], DeviceKind) and isinstance(metadata[1], int):
kind, device_id = metadata
elif isinstance(metadata, str):
metadata = metadata.strip().lower()
if metadata.startswith("cuda") or metadata.startswith("gpu"):
kind = DeviceKind.CUDA
device_id = 0 if metadata.find(":") == -1 else int(metadata.split(":")[1])
elif metadata.startswith("cpu") or metadata.startswith("host"):
kind = DeviceKind.HOST
device_id = 0
else:
raise ValueError("Argument `metadata` must be in the format 'cuda:<device_id>' or 'cpu'.")
else:
raise TypeError("Argument `metadata` must be a `tuple[MemoryKind, int]` or a `str`.")
self._device_id = device_id
self._kind = kind
def __repr__(self) -> str:
return f"{self.__class__.__name__}(kind={self._kind}, id={self._device_id})"
def __str__(self) -> str:
return f"{self._kind}:{self._device_id}" if self._kind is DeviceKind.CUDA else f"{self._kind}"
@property
def id(self) -> int:
"""
Gets the device ID of the device.
"""
return self._device_id
@property
def kind(self) -> DeviceKind:
"""
Gets the memory kind of the device.
"""
return self._kind
class DeviceKind(IntEnum):
"""
Type of memory a descriptor has been allocated to.
"""
UNSPECIFIED = 0
HOST = 1
CUDA = 2
def __str__(self) -> str:
if self == DeviceKind.HOST:
return "cpu"
elif self == DeviceKind.CUDA:
return "cuda"
else:
return "<invalid>"
class OperationKind(IntEnum):
"""
Kind of an operation.
"""
UNSPECIFIED = 0
READ = 1
WRITE = 2
def __str__(self) -> str:
if self == OperationKind.READ:
return "READ"
elif self == OperationKind.WRITE:
return "WRITE"
else:
return "<invalid>"
class OperationStatus(IntEnum):
"""
Status of an operation.
"""
UNINTIALIZED = 0
INITIALIZED = 1
IN_PROGRESS = 2
COMPLETE = 3
CANCELLED = 4
ERRORED = 5
def __str__(self) -> str:
if self == OperationStatus.INITIALIZED:
return "INIT"
elif self == OperationStatus.IN_PROGRESS:
return "PROC"
elif self == OperationStatus.COMPLETE:
return "DONE"
elif self == OperationStatus.ERRORED:
return "ERR"
elif self == OperationStatus.CANCELLED:
return "STOP"
else:
return "<invalid>"
class PassiveOperation(AbstractOperation):
"""
Abstract class for common functionality of passive operations.
"""
def __init__(
self,
connector: Connector,
operation_kind: OperationKind,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
if operation_kind is not OperationKind.READ and operation_kind is not OperationKind.WRITE:
raise ValueError("Argument `operation_kind` must be either `READ` or `WRITE`.")
self._status = OperationStatus.UNINTIALIZED
super().__init__(connector, operation_kind, local_descriptors, None, None)
self._serialized_request: Optional[SerializedRequest] = None
self._status = OperationStatus.INITIALIZED
def __del__(self) -> None:
super().__del__()
def __enter__(self) -> AbstractOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return str(
f"{self.__class__.__name__}("
f"operation_kind={self._operation_kind}, "
f"local_descriptors={self._local_descriptors}, "
f"notification_key='{self._notification_key}', "
f"status='{self._status}'"
f")"
)
async def _wait_for_completion_(self) -> None:
# Loop until the operation is no longer in progress (or "initalized"),
# yielding control to the event loop to allow other operations to run.
while True:
match self.status:
# "in progress" or "initialized" means the operation is ongoing.
case OperationStatus.INITIALIZED:
await asyncio.sleep(0.1)
case OperationStatus.IN_PROGRESS:
await asyncio.sleep(0.1)
# Any other state indicates completion or error.
case _:
return
@property
def status(self) -> OperationStatus:
"""
Gets the status of the operation.
"""
# Early return if the operation is already complete, errored, or cancelled.
match self._status:
case OperationStatus.COMPLETE | OperationStatus.ERRORED | OperationStatus.CANCELLED:
return self._status
old_status = self._status
# Query NIXL for any notifications.
notifications = self._connector._nixl.update_notifs()
if isinstance(notifications, dict):
remote_state = OperationStatus.IN_PROGRESS
logger.debug(f"NIXL reported notifications: {len(notifications)}.")
for key, values in notifications.items():
if not isinstance(values, list):
raise TypeError(f"Expected `dict[str, list[bytes]]` from NIXL notification query; got {type(notifications)}.")
for value in values:
if not isinstance(value, bytes):
continue
notification_key = value.decode("utf-8")
# Once we've found the notification key, we know the operation is complete.
if notification_key == self._notification_key:
remote_state = OperationStatus.COMPLETE
break
if remote_state == OperationStatus.COMPLETE:
self._status = remote_state
logger.debug(f"{self.__class__.__name__} {{ remote: '{self._connector.name}' status: '{old_status}' => '{self._status}' }}.")
return self._status
def to_serialized(self) -> SerializedRequest:
"""
Gets the request descriptor for the operation.
"""
if self._serialized_request is None:
# When we've not yet cached the serialized request, we need to generate one before returning it.
# Handle both cases: multiple and single descriptors.
if isinstance(self._local_descriptors, list):
descriptors = [desc.to_serialized() for desc in self._local_descriptors]
else:
descriptors = [self._local_descriptors.to_serialized()]
original_len = len(self._connector.metadata)
nixl_metadata = self._connector.metadata
nixl_metadata = zlib.compress(nixl_metadata, level=6)
compressed_len = len(nixl_metadata)
logger.debug(f"Compressed NIXL metadata from {original_len} bytes to {compressed_len} bytes.")
if compressed_len > original_len:
logger.warning(f"Compressed NIXL metadata is larger than original ({compressed_len} > {original_len}).")
self._serialized_request = SerializedRequest(
descriptors=descriptors,
nixl_metadata=nixl_metadata.hex(),
notification_key=self._notification_key,
operation_kind=int(self._operation_kind),
)
return self._serialized_request
@abstractmethod
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
raise NotImplementedError("Abstract method not implemented by derived class.")
class ReadOperation(ActiveOperation):
"""
Operation that initiates an RDMA read operation to transfer data from a remote worker's `ReadableOperation`,
as described by `remote_request`, to local buffers.
"""
def __init__(
self,
connector: Connector,
remote_request: SerializedRequest,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
"""
Creates a new instance of `ReadOperation`, registers `local_descriptors` with NIXL,
and begins an RDMA read operation which will transfer data described by `remote_request`
to `local_descriptors`.
Parameters
----------
connector : Connector
Connector instance to use for the operation.
remote_request : SerializedRequest
Serialized request from the remote worker.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to to receive the data from the remote agent.
"""
if not isinstance(connector, Connector):
raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
if not isinstance(remote_request, SerializedRequest):
raise TypeError("Argument `remote_request` must be `dynamo.connect.RequestDescriptor`.")
if remote_request.operation_kind != OperationKind.READ.value:
raise ValueError("Argument `remote_request` must be of kind `READ`.")
remote = Remote(connector, remote_request.nixl_metadata)
remote_descriptors = remote_request.to_descriptors()
if not (
isinstance(local_descriptors, Descriptor)
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor`, `list[dynamo.connect.Descriptor]`.")
super().__init__(remote, OperationKind.READ, local_descriptors, remote_descriptors, remote_request.notification_key)
logger.debug(f"Created {self.__repr__()}")
def __del__(self) -> None:
super().__del__()
logger.debug(f"Deleted {self.__repr__()}")
def __enter__(self) -> ReadOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return super().__repr__()
def cancel(self) -> None:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or been cancelled.
"""
super()._cancel_()
def results(self) -> list[Descriptor]:
"""
Gets the results of the operation.
Returns a single descriptor if only one was requested, or a list of descriptors if multiple were requested.
"""
if self._status != OperationStatus.COMPLETE:
raise RuntimeError("Operation has not completed yet, cannot get results.")
return self._local_descriptors if isinstance(self._local_descriptors, list) else [self._local_descriptors]
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await super()._wait_for_completion_()
class ReadableOperation(PassiveOperation):
"""
Operation that can be awaited until a remote worker has completed a `ReadOperation`.
"""
def __init__(
self,
connector: Connector,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
super().__init__(connector, OperationKind.READ, local_descriptors)
logger.debug(f"Created {self.__repr__()}")
def __del__(self) -> None:
super().__del__()
logger.debug(f"Deleted {self.__repr__()}")
def __enter__(self) -> ReadableOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return super().__repr__()
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await super()._wait_for_completion_()
class Remote:
"""
Identifies a remote NIXL agent relative to a local NIXL agent.
"""
def __init__(
self,
connector: Connector,
nixl_metadata: bytes | str,
) -> None:
if not isinstance(connector, Connector):
raise TypeError("Argument `local` must be `dynamo.connect.Connector`.")
if not (isinstance(nixl_metadata, bytes) or isinstance(nixl_metadata, str)):
raise TypeError("Argument `nixl_metadata` must be `bytes` or `str`.")
if len(nixl_metadata) == 0:
raise ValueError("Argument `nixl_metadata` cannot be empty.")
self._connector = connector
# When `nixl_metadata` is a string, it is assumed to have come from a remote worker
# via a `SerializedRequest` object and therefore can assumed be a hex-encoded, compressed
# representation of the NIXL metadata.
if isinstance(nixl_metadata, str):
# Decode the hex-encoded string into bytes.
nixl_metadata = bytes.fromhex(nixl_metadata)
# Decompress the NIXL metadata.
nixl_metadata = zlib.decompress(nixl_metadata)
self._name = connector._nixl.add_remote_agent(nixl_metadata)
if isinstance(self._name, bytes):
self._name = self._name.decode("utf-8")
logger.debug(f"Created {self.__repr__()}.")
def __del__(self) -> None:
self._release()
def __enter__(self) -> Remote:
"""
Context manager entry method. Returns the current instance.
"""
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
"""
Context manager exit method. Cleans up the instance.
"""
self._release()
def __repr__(self) -> str:
return f"RemoteAgent(name={self._name}, connector={self._connector.name})"
def __str__(self) -> str:
return self._name
def _release(self) -> None:
"""
Private method for releasing NIXL resources. Not intended for public use.
"""
pass
@property
def connector(self) -> Connector:
"""
Gets the local connector associated with this remote agent.
"""
return self._connector
@property
def name(self) -> str:
"""
Gets the name of the remote agent.
"""
return self._name
class SerializedDescriptor(BaseModel):
"""
Pydantic serialization type for memory descriptors.
"""
model_config = ConfigDict(
extra="forbid",
frozen=True,
arbitrary_types_allowed=True,
)
device: str = "cpu"
ptr: int = 0
size: int = 0
def to_descriptor(self) -> Descriptor:
"""
Deserialize the serialized descriptor into a `Descriptor` object.
"""
return Descriptor(data=(self.ptr, self.size, self.device, None))
@field_validator("device")
def validate_memtype(cls, v: str) -> str:
if not isinstance(v, str):
raise TypeError("Argument `device` must be `str`.")
v = v.strip().lower()
if not (v.startswith("cuda") or v == "cpu"):
raise ValueError("Argument `device` must be one of 'cpu' or 'cuda:<device_id>'.")
return v
@field_validator("ptr")
def validate_ptr(cls, v: int) -> int:
if v == 0:
raise ValueError("Argument `ptr` cannot be zero (aka `null` or `None`).")
return v
@field_validator("size")
def validate_size(cls, v: int) -> int:
if v < 0:
raise ValueError("Argument `size` must be an integer greater than or equal to zero.")
return v
class SerializedRequest(BaseModel):
"""
Pydantic serialization type for describing the passive side of a transfer.
"""
model_config = ConfigDict(
extra="forbid",
frozen=True,
arbitrary_types_allowed=True,
)
descriptors: List[SerializedDescriptor] = []
nixl_metadata: str = ""
notification_key: str = ""
operation_kind: int = 0
def to_descriptors(self) -> Descriptor | list[Descriptor]:
"""
Deserializes the request descriptor into a `dynamo.connect.Descriptor` or list of `dynamo.connect.Descriptor` objects.
"""
if len(self.descriptors) == 0:
raise ValueError("Request descriptor must contain at least one serialized descriptor.")
if len(self.descriptors) == 1:
return self.descriptors[0].to_descriptor()
return [item.to_descriptor() for item in self.descriptors]
@field_validator("operation_kind")
def validate_operation_kind(cls, v: int) -> int:
if v < 1 or v > 3:
raise TypeError("Argument `operation_kind` must be an integer value of `dynamo.connect.OperationKind`.")
return v
class WritableOperation(PassiveOperation):
"""
Operation which can be awaited until written to by a `WriteOperation` from a remote worker.
"""
def __init__(
self,
connector: Connector,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
"""
Creates a new instance of `WritableOperation`, registers the operation and descriptors w/ NIXL,
and enables an RDMA write operation to occur.
Parameters
----------
connector : Connector
Connector instance to use for the operation.
local_descriptors : Descriptor | list[Descriptor]
Descriptors to receive data from a remote worker.
Raises
TypeError
When `local` is not a `dynamo.connect.Connector`.
TypeError
When `local_descriptors` is not a `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
super().__init__(connector, OperationKind.WRITE, local_descriptors)
logger.debug(f"Created {self.__repr__()}")
def __del__(self) -> None:
super().__del__()
logger.debug(f"Deleted {self.__repr__()}")
def __enter__(self) -> WritableOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return super().__repr__()
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await super()._wait_for_completion_()
class WriteOperation(ActiveOperation):
"""
Awaitable write operation which initiates an RDMA write operation to a remote worker
which provided a `SerializedRequest` object from a `WritableOperation`.
"""
def __init__(
self,
connector: Connector,
local_descriptors: Descriptor | list[Descriptor],
remote_request: SerializedRequest,
) -> None:
"""
Creates a new instance of `WriteOperation`, registers `local_descriptors` with NIXL,
and begins an RDMA write operation which will transfer from `local_descriptors` to
remote target(s) described by `remote_request`
Parameters
----------
connector : Connector
Connector instance to use for the operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to send from, to the remote agent.
remote_request : SerializedRequest
Serialized request from the remote worker that describes the target(s) to send to.
Raises
TypeError
When `connector` is not a `dynamo.connect.Connector`.
TypeError
When `remote_request` is not a `dynamo.connect.RequestDescriptor`.
ValueError
When `remote_request` is not of kind `WRITE`.
ValueError
When `remote_request.nixl_metadata` is not a non-empty `str`.
TypeError
When `local_descriptors` is not a `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
if not isinstance(connector, Connector):
raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
if not isinstance(remote_request, SerializedRequest):
raise TypeError("Argument `remote_request` must be `dynamo.connect.RequestDescriptor`.")
if remote_request.operation_kind != OperationKind.WRITE.value:
raise ValueError("Argument `remote_request` must be of kind `WRITE`.")
remote = Remote(connector, remote_request.nixl_metadata)
remote_descriptors = remote_request.to_descriptors()
super().__init__(remote, OperationKind.WRITE, local_descriptors, remote_descriptors, remote_request.notification_key)
logger.debug(f"Created {self.__repr__()}")
def __del__(self) -> None:
super().__del__()
logger.debug(f"Deleted {self.__repr__()}")
def __enter__(self) -> WriteOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return super().__repr__()
def cancel(self) -> None:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or has been cancelled.
"""
super()._cancel_()
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await super()._wait_for_completion_()
......@@ -6,6 +6,7 @@
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -13,9 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from components.encode_worker import EncodeWorker
from components.decode_worker import VllmDecodeWorker
from components.encode_worker import VllmEncodeWorker
from components.frontend import Frontend
from components.processor import Processor
from components.worker import VllmWorker
Frontend.link(Processor).link(VllmWorker).link(EncodeWorker)
Frontend.link(Processor).link(VllmDecodeWorker).link(VllmEncodeWorker)
......@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from components.encode_worker import EncodeWorker
from components.decode_worker import VllmDecodeWorker
from components.encode_worker import VllmEncodeWorker
from components.frontend import Frontend
from components.prefill_worker import PrefillWorker
from components.prefill_worker import VllmPrefillWorker
from components.processor import Processor
from components.worker import VllmWorker
Frontend.link(Processor).link(VllmWorker).link(PrefillWorker).link(EncodeWorker)
Frontend.link(Processor).link(VllmDecodeWorker).link(VllmPrefillWorker).link(
VllmEncodeWorker
)
......@@ -17,6 +17,7 @@
import json
from typing import Any, List, Optional
import connect
import msgspec
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core import core_schema
......@@ -111,12 +112,13 @@ class EncodeRequest(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
image_url: str
request_id: str
serialized_request: Optional[connect.SerializedRequest] = None
class EncodeResponse(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
image_features: List[List[List[float]]]
request_id: str
class MyRequestOutput(BaseModel):
......@@ -129,7 +131,6 @@ class MyRequestOutput(BaseModel):
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
prompt: Optional[str] = None
prompt_token_ids: Optional[List[int]] = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment