Unverified Commit 75e774d4 authored by J Wyman's avatar J Wyman Committed by GitHub
Browse files

feat: NIXL Based RDMA Support w/ Multimodal Example (#1060)

parent 9acaa8d1
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
exclude: ^(src/grpc_generated|.*\.patch$) exclude: ^(src/grpc_generated|.*\.patch$|.*/connect/.*\.py)
repos: repos:
- repo: https://github.com/timothycrosley/isort - repo: https://github.com/timothycrosley/isort
rev: 5.12.0 rev: 5.12.0
...@@ -82,4 +82,4 @@ repos: ...@@ -82,4 +82,4 @@ repos:
# NOTE: pyright may be able to find other classes of errors not covered above, # NOTE: pyright may be able to find other classes of errors not covered above,
# but would require some configuring and venv setup to properly eliminate noise # but would require some configuring and venv setup to properly eliminate noise
# and give it visiblity into all the local and third_party packages expected. # and give it visiblity into all the local and third_party packages expected.
\ No newline at end of file
...@@ -19,10 +19,11 @@ import os ...@@ -19,10 +19,11 @@ import os
import signal import signal
from typing import Optional from typing import Optional
import connect
import torch import torch
from components.disagg_router import PyDisaggregatedRouter from components.disagg_router import PyDisaggregatedRouter
from components.encode_worker import EncodeWorker from components.encode_worker import VllmEncodeWorker
from components.prefill_worker import PrefillWorker from components.prefill_worker import VllmPrefillWorker
from transformers import LlavaForConditionalGeneration from transformers import LlavaForConditionalGeneration
from utils.logging import check_required_workers from utils.logging import check_required_workers
from utils.nixl import NixlMetadataStore from utils.nixl import NixlMetadataStore
...@@ -53,11 +54,11 @@ logger = logging.getLogger(__name__) ...@@ -53,11 +54,11 @@ logger = logging.getLogger(__name__)
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"}, resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1, workers=1,
) )
class VllmWorker: class VllmDecodeWorker:
# For disaggregated serving, we need to link the prefill worker to the vllm worker # For disaggregated serving, we need to link the prefill worker to the vllm worker
prefill_worker = depends(PrefillWorker) prefill_worker = depends(VllmPrefillWorker)
# For aggregated serving, we need to link the encode worker to the vllm worker. # For aggregated serving, we need to link the encode worker to the vllm worker.
encode_worker = depends(EncodeWorker) encode_worker = depends(VllmEncodeWorker)
def __init__(self): def __init__(self):
self.client = None self.client = None
...@@ -141,7 +142,11 @@ class VllmWorker: ...@@ -141,7 +142,11 @@ class VllmWorker:
vision_tower.vision_model.embeddings.position_embedding.num_embeddings vision_tower.vision_model.embeddings.position_embedding.num_embeddings
) )
else: else:
enc_comp_ns, enc_comp_name = EncodeWorker.dynamo_address() # type: ignore EMBEDDINGS_SHAPE = (1, 577, 4096)
EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cuda"
enc_comp_ns, enc_comp_name = VllmEncodeWorker.dynamo_address() # type: ignore
self.encode_worker_client = ( self.encode_worker_client = (
await runtime.namespace(enc_comp_ns) await runtime.namespace(enc_comp_ns)
.component(enc_comp_name) .component(enc_comp_name)
...@@ -149,9 +154,22 @@ class VllmWorker: ...@@ -149,9 +154,22 @@ class VllmWorker:
.client() .client()
) )
self._connector = connect.Connector(runtime=runtime, namespace=enc_comp_ns)
await self._connector.initialize()
# Create a longer-lived buffer for receiving the image embeddings.
embeddings = torch.empty(
EMBEDDINGS_SHAPE, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
)
descriptor = connect.Descriptor(embeddings)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
descriptor.register_memory(self._connector)
self._embeddings_descriptor = (embeddings, descriptor)
await check_required_workers(self.encode_worker_client, self.min_workers) await check_required_workers(self.encode_worker_client, self.min_workers)
self.disaggregated_router = None self.disaggregated_router = None
logger.info("VllmWorker has been initialized")
logger.info("Initialization complete.")
def shutdown_vllm_engine(self, signum, frame): def shutdown_vllm_engine(self, signum, frame):
"""Shutdown the background loop""" """Shutdown the background loop"""
...@@ -159,7 +177,7 @@ class VllmWorker: ...@@ -159,7 +177,7 @@ class VllmWorker:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: try:
self.engine_client.close() self.engine_client.close()
logger.info("VllmWorker shutdown complete") logger.info("Shutdown complete.")
except Exception as e: except Exception as e:
logger.error(f"Error during shutdown: {e}") logger.error(f"Error during shutdown: {e}")
finally: finally:
...@@ -177,8 +195,18 @@ class VllmWorker: ...@@ -177,8 +195,18 @@ class VllmWorker:
@endpoint() @endpoint()
async def generate(self, request: vLLMMultimodalRequest): async def generate(self, request: vLLMMultimodalRequest):
image_features = None request_id = request.request_id
image_url = request.image_url
logger.info(
f"Received multimodal request {{ id: {request_id}, image_url: '{image_url}' }}."
)
embeddings = None
if self.do_remote_prefill: if self.do_remote_prefill:
logger.debug(
f"Disaggregated: request {{ id: {request_id}, image_url: '{image_url}' }}"
" prefill worker will populate the decode model's key-value cache ahead of time;"
" no direct encode worker interaction required."
)
if self.disaggregated_router is not None: if self.disaggregated_router is not None:
async with PrefillQueue.get_instance( async with PrefillQueue.get_instance(
nats_server=self._prefill_queue_nats_server, nats_server=self._prefill_queue_nats_server,
...@@ -195,21 +223,21 @@ class VllmWorker: ...@@ -195,21 +223,21 @@ class VllmWorker:
disagg_router_decision = True disagg_router_decision = True
if self.do_remote_prefill and disagg_router_decision: if self.do_remote_prefill and disagg_router_decision:
logger.debug(
f"Prefilling remotely for request {{ id: {request_id}, image_url: '{image_url}' }} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
remote_prefill_params = RemotePrefillParams( remote_prefill_params = RemotePrefillParams(
is_remote_prefill=True, is_remote_prefill=True,
remote_prefill_request_callback=self.get_remote_prefill_request_callback(), remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
# Pass the image url as part of the RemotePrefillParams, which will be passed to the prefill worker via RemotePrefillRequest # Pass the image url as part of the RemotePrefillParams, which will be passed to the prefill worker via RemotePrefillRequest
multimodal_data_source={ multimodal_data_source={
"image_url": request.image_url, "image_url": image_url,
}, },
) )
logger.info(
f"Prefilling remotely for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
else: else:
remote_prefill_params = None remote_prefill_params = None
logger.info( logger.debug(
f"Prefilling locally for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}" f"Prefilling locally for request {{ id: {request_id}, image_url: '{image_url}' }} with length {len(request.engine_prompt['prompt_token_ids'])}"
) )
# The decode worker will pre-allocate the memory based on the prompt token length for the prefill worker to transfer the kv cache. # The decode worker will pre-allocate the memory based on the prompt token length for the prefill worker to transfer the kv cache.
...@@ -231,33 +259,61 @@ class VllmWorker: ...@@ -231,33 +259,61 @@ class VllmWorker:
) )
else: else:
# For aggregated serving, the vllm worker will directly send the encode request to the encode worker. logger.debug(
encode_generator = await self.encode_worker_client.round_robin( f"Aggregated: request {{ id: {request_id}, image_url: '{image_url}' }}"
EncodeRequest( " no prefill worker available, embeddings directly from encode worker."
image_url=request.image_url,
).model_dump_json()
) )
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
async for encode_response in encode_generator: # Doing this avoids unnessesary memory de/registration with NIXL.
encode_output = EncodeResponse.model_validate_json( embeddings, descriptor = self._embeddings_descriptor
encode_response.data()
with self._connector.create_writable(descriptor) as writable:
# Extract serialized metadata about the operation from the writable operation,
# and use it to create a new EncodeRequest.
encode_request = EncodeRequest(
request_id=request_id,
image_url=image_url,
serialized_request=writable.to_serialized(),
) )
image_features = torch.tensor( logger.debug(f"Encode request: {encode_request.model_dump_json()}")
encode_output.image_features, device=device, dtype=torch.float16 encode_generator = await self.encode_worker_client.round_robin(
encode_request.model_dump_json()
) )
async for encode_response in encode_generator:
encode_output = EncodeResponse.model_validate_json(
encode_response.data()
)
logger.info(
f"Received response: {{ id: {encode_output.request_id} }}"
)
# Wait for the write operation to complete.
# This will block until the write operation is complete.
# This await should be a no-op since we've already received a response from the encode worker.
await writable.wait_for_completion()
# At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker.
remote_prefill_params = None remote_prefill_params = None
logger.info( logger.info(
f"Prefilling locally for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}" f"Prefilling locally for request {{ id: {request_id}, image_url: '{image_url}' }} with length {len(request.engine_prompt['prompt_token_ids'])}"
) )
prompt_ids = request.engine_prompt["prompt_token_ids"] prompt_ids = request.engine_prompt["prompt_token_ids"]
# rust HTTP requires Delta streaming # rust HTTP requires Delta streaming
request.sampling_params.output_kind = RequestOutputKind.DELTA request.sampling_params.output_kind = RequestOutputKind.DELTA
if image_features is not None: # When using aggregated serving, the encode worker will have provided the key-value cache updates via the prefill worker.
multi_modal_data = {"image": image_features} # When using disaggregated serving, the encode worker will have provided the key-value cache updates via the encode worker.
if embeddings is not None:
logger.debug(
"Aggregated: embedding data from encode worker provided via multi-modal data to decode model."
)
multi_modal_data = {"image": embeddings}
else: else:
logger.debug(
"Disaggregated: no embedding data required as prefill will have provided key-value cache updates via encode worker."
)
multi_modal_data = None multi_modal_data = None
async for response in self.engine_client.generate( async for response in self.engine_client.generate(
...@@ -269,6 +325,9 @@ class VllmWorker: ...@@ -269,6 +325,9 @@ class VllmWorker:
request_id=request.request_id, request_id=request.request_id,
remote_prefill_params=remote_prefill_params, remote_prefill_params=remote_prefill_params,
): ):
logger.debug(
f"Yeilding response {{ id: {response.request_id}, prompt: '{response.prompt}' }}"
)
yield MyRequestOutput( yield MyRequestOutput(
request_id=response.request_id, request_id=response.request_id,
prompt=response.prompt, prompt=response.prompt,
......
...@@ -15,8 +15,10 @@ ...@@ -15,8 +15,10 @@
import logging import logging
from io import BytesIO from io import BytesIO
from queue import Queue
from typing import AsyncIterator from typing import AsyncIterator
import connect
import requests import requests
import torch import torch
from PIL import Image from PIL import Image
...@@ -24,10 +26,25 @@ from transformers import AutoImageProcessor, LlavaForConditionalGeneration ...@@ -24,10 +26,25 @@ from transformers import AutoImageProcessor, LlavaForConditionalGeneration
from utils.protocol import EncodeRequest, EncodeResponse from utils.protocol import EncodeRequest, EncodeResponse
from utils.vllm import parse_vllm_args from utils.vllm import parse_vllm_args
from dynamo.sdk import endpoint, service from dynamo.sdk import async_on_start, endpoint, service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try:
import cupy as array_module
if not array_module.cuda.is_available():
raise ImportError("CUDA is not available.")
DEVICE = "cuda"
logger.info("Using cupy for array operations (GPU mode).")
except ImportError as e:
logger.warning(f"Failed to import cupy, falling back to numpy: {e}.")
import numpy as array_module
DEVICE = "cpu"
CACHE_SIZE_MAXIMUM = 8
@service( @service(
dynamo={ dynamo={
...@@ -36,7 +53,7 @@ logger = logging.getLogger(__name__) ...@@ -36,7 +53,7 @@ logger = logging.getLogger(__name__)
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"}, resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1, workers=1,
) )
class EncodeWorker: class VllmEncodeWorker:
def __init__(self) -> None: def __init__(self) -> None:
class_name = self.__class__.__name__ class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "") self.engine_args = parse_vllm_args(class_name, "")
...@@ -50,9 +67,50 @@ class EncodeWorker: ...@@ -50,9 +67,50 @@ class EncodeWorker:
self.MODEL_ID, device_map="auto", torch_dtype=torch.float16 self.MODEL_ID, device_map="auto", torch_dtype=torch.float16
).eval() ).eval()
self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: Queue[str] = Queue(maxsize=CACHE_SIZE_MAXIMUM)
@endpoint() @endpoint()
async def encode(self, request: EncodeRequest) -> AsyncIterator[EncodeResponse]: async def encode(self, request: EncodeRequest) -> AsyncIterator[EncodeResponse]:
image = self.open_image(request.image_url) logger.debug(
f"Received encode request: {{ id: {request.request_id}, image_url: '{request.image_url}' }}."
)
request_id = request.request_id
image_url = request.image_url.lower()
# The following steps encode the requested image and provided useful embeddings.
# 1. Open the image from the provided URL.
# 2. Process the image using the image processor.
# 3. Run the image through the vision model's vision tower.
# 4. Run the results of the vision tower through the multi-modal projector.
# 5. Create a descriptor for the embeddings.
# 6. Create a write operation using the serialized request and the descriptor.
# 7. Await for the write operation to complete.
# 8. Yield the encode response.
# Either retrieve the image from the cache or download it and then cache it.
if request.image_url in self._image_cache:
image = self._image_cache[image_url]
logger.debug(
f"Image found in cache for request: {{ id: {request_id}, image_url: '{image_url}' }}."
)
else:
image = self.open_image(image_url)
logger.debug(
f"Downloading/opening image for request: {{ id: {request_id}, image_url: '{image_url}' }}."
)
# Cache the image for future use, and evict the oldest image if the cache is full.
if self._cache_queue.full():
oldest_image_url = self._cache_queue.get()
del self._image_cache[oldest_image_url]
self._image_cache[request.image_url] = image
self._cache_queue.put(request.image_url)
logger.debug(
f"Processing image for request: {{ id: {request_id}, image_url: '{image_url}' }}"
)
image_embeds = self.image_processor(images=image, return_tensors="pt") image_embeds = self.image_processor(images=image, return_tensors="pt")
with torch.no_grad(): with torch.no_grad():
...@@ -60,22 +118,56 @@ class EncodeWorker: ...@@ -60,22 +118,56 @@ class EncodeWorker:
vision_outputs = self.vision_model.vision_tower( vision_outputs = self.vision_model.vision_tower(
image_embeds["pixel_values"].to(self.vision_model.device) image_embeds["pixel_values"].to(self.vision_model.device)
) )
logger.debug("Vision model completed.")
embeddings = vision_outputs.last_hidden_state
embeddings = self.vision_model.multi_modal_projector(embeddings)
logger.debug(
f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
)
if request.serialized_request is None:
logger.error(
f"Request serialized_request is None for request: {{ id: {request_id}, image_url: '{image_url}' }}."
)
# Create a descriptor for the embeddings, this will register the memory with the connector (and the NIXL runtime).
descriptor = connect.Descriptor(embeddings)
# Create a write operation using the serialized request and the descriptor.
# This will begin the RDMA transfer of the embeddings to the remote worker.
write_op = await self._connector.begin_write(
descriptor,
request.serialized_request,
)
# Await for the write operation to complete.
# This will block until the data has been written to the remote worker or an error occurs.
await write_op.wait_for_completion()
image_features = vision_outputs.last_hidden_state
image_features = self.vision_model.multi_modal_projector(image_features)
yield EncodeResponse( yield EncodeResponse(
image_features=image_features.tolist() request_id=request.request_id,
).model_dump_json() ).model_dump_json()
@async_on_start()
async def on_start(self):
logger.info("Startup started.")
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Startup completed.")
def open_image(self, image: str) -> Image.Image: def open_image(self, image: str) -> Image.Image:
# TODO: Have a seperate field for url and non url - and avoid auto detection # TODO: Have a seperate field for url and non url - and avoid auto detection
try: try:
# Acquire the image and convert it to the format (RGB) the image processor model expects.
if image.startswith("http") or image.startswith("https"): if image.startswith("http") or image.startswith("https"):
response = requests.get(image) response = requests.get(image)
image_data = Image.open(BytesIO(response.content)).convert("RGB") image_data = Image.open(BytesIO(response.content)).convert("RGB")
else: else:
image_data = Image.open(image).convert("RGB") image_data = Image.open(image).convert("RGB")
return image_data
except Exception as e: except Exception as e:
logger.error(f"Error opening image: {e}") logger.error(f"Error opening image: {e}")
raise e raise e
return image_data
...@@ -20,8 +20,9 @@ import os ...@@ -20,8 +20,9 @@ import os
import signal import signal
import sys import sys
import connect
import torch import torch
from components.encode_worker import EncodeWorker from components.encode_worker import VllmEncodeWorker
from pydantic import BaseModel from pydantic import BaseModel
from utils.logging import check_required_workers from utils.logging import check_required_workers
from utils.nixl import NixlMetadataStore from utils.nixl import NixlMetadataStore
...@@ -38,6 +39,11 @@ from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, servic ...@@ -38,6 +39,11 @@ from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, servic
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Constants for the shape and dtype of the embeddings tensor.
EMBEDDINGS_SHAPE = (1, 577, 4096)
EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cuda"
class RequestType(BaseModel): class RequestType(BaseModel):
text: str text: str
...@@ -50,8 +56,8 @@ class RequestType(BaseModel): ...@@ -50,8 +56,8 @@ class RequestType(BaseModel):
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"}, resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1, workers=1,
) )
class PrefillWorker: class VllmPrefillWorker:
encode_worker = depends(EncodeWorker) encode_worker = depends(VllmEncodeWorker)
def __init__(self): def __init__(self):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
...@@ -95,7 +101,7 @@ class PrefillWorker: ...@@ -95,7 +101,7 @@ class PrefillWorker:
raise RuntimeError("Failed to initialize engine client") raise RuntimeError("Failed to initialize engine client")
runtime = dynamo_context["runtime"] runtime = dynamo_context["runtime"]
enc_comp_ns, enc_comp_name = EncodeWorker.dynamo_address() # type: ignore enc_comp_ns, enc_comp_name = VllmEncodeWorker.dynamo_address() # type: ignore
self.encode_worker_client = ( self.encode_worker_client = (
await runtime.namespace(enc_comp_ns) await runtime.namespace(enc_comp_ns)
.component(enc_comp_name) .component(enc_comp_name)
...@@ -103,6 +109,20 @@ class PrefillWorker: ...@@ -103,6 +109,20 @@ class PrefillWorker:
.client() .client()
) )
self._connector = connect.Connector(runtime=runtime, namespace=enc_comp_ns)
await self._connector.initialize()
# Create a longer-lived buffer for receiving the image embeddings.
embeddings = torch.empty(
EMBEDDINGS_SHAPE,
dtype=EMBEDDINGS_DTYPE,
device=EMBEDDINGS_DEVICE,
)
descriptor = connect.Descriptor(embeddings)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
descriptor.register_memory(self._connector)
self._embeddings_descriptor = (embeddings, descriptor)
await check_required_workers(self.encode_worker_client, self.min_workers) await check_required_workers(self.encode_worker_client, self.min_workers)
metadata = self.engine_client.nixl_metadata metadata = self.engine_client.nixl_metadata
...@@ -119,19 +139,19 @@ class PrefillWorker: ...@@ -119,19 +139,19 @@ class PrefillWorker:
sys.exit(1) sys.exit(1)
task.add_done_callback(prefill_queue_handler_cb) task.add_done_callback(prefill_queue_handler_cb)
logger.info("PrefillWorker initialized") logger.info("Initialization complete.")
def shutdown_vllm_engine(self, signum, frame): def shutdown_vllm_engine(self, signum, frame):
"""Shutdown the background loop""" """Shutdown the background loop"""
logger.info(f"Received signal {signum}, shutting down") logger.info(f"Shutdown started, signal {signum} received.")
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: try:
self.engine_client.close() self.engine_client.close()
logger.info("PrefillWorker shutdown complete")
except Exception as e: except Exception as e:
logger.error(f"Error during shutdown: {e}") logger.error(f"Error during shutdown: {e}")
finally: finally:
loop.stop() loop.stop()
logger.info("Shutdown complete.")
async def prefill_queue_handler(self): async def prefill_queue_handler(self):
logger.info("Prefill queue handler entered") logger.info("Prefill queue handler entered")
...@@ -166,62 +186,88 @@ class PrefillWorker: ...@@ -166,62 +186,88 @@ class PrefillWorker:
if request.multimodal_data_source["image_url"] is None: if request.multimodal_data_source["image_url"] is None:
raise ValueError("No image url provided for prefill request") raise ValueError("No image url provided for prefill request")
encode_generator = await self.encode_worker_client.round_robin( request_id = request.request_id
EncodeRequest( engine_id = request.engine_id
image_url=request.multimodal_data_source["image_url"], image_url = request.multimodal_data_source["image_url"]
).model_dump_json()
logger.info(
f"Received prefill request {{ id: {request_id}, engine_id: {engine_id}, image_url: '{image_url}' }}."
) )
async for encode_response in encode_generator:
encode_output = EncodeResponse.model_validate_json(encode_response.data()) # Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
image_features = torch.tensor( # Doing this avoids unnessesary memory de/registration with NIXL.
encode_output.image_features, device="cpu", dtype=torch.float16 embeddings, descriptor = self._embeddings_descriptor
# Create a new writable operation from the descriptor.
with self._connector.create_writable(descriptor) as writable:
# Extract serialized metadata about the operation from the writable operation,
# and use it to create a new EncodeRequest.
encode_generator = await self.encode_worker_client.round_robin(
EncodeRequest(
request_id=request_id,
image_url=image_url,
serialized_request=writable.to_serialized(),
).model_dump_json()
) )
async for encode_response in encode_generator:
encode_output = EncodeResponse.model_validate_json(
encode_response.data(),
)
logger.debug(
f"Received response: {{ id: {encode_output.request_id} }}."
)
sampling_params = request.sampling_params # Wait for the write operation to complete.
sampling_params.max_tokens = 1 # This will block until the write operation is complete.
sampling_params.min_tokens = 1 # This await should be a no-op since we've already received a response from the encode worker.
await writable.wait_for_completion()
# At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker.
remote_prefill_params = RemotePrefillParams( sampling_params = request.sampling_params
is_remote_decode=True, sampling_params.max_tokens = 1
decode_block_ids=request.block_ids, sampling_params.min_tokens = 1
decode_engine_id=request.engine_id,
decode_computed_block_ids=request.computed_block_ids,
)
# TODO check if metadata has changed remote_prefill_params = RemotePrefillParams(
# and reload - currently only loading once is_remote_decode=True,
if request.engine_id not in self._loaded_metadata: decode_block_ids=request.block_ids,
remote_metadata = await self._metadata_store.get(request.engine_id) decode_engine_id=engine_id,
await self.engine_client.add_remote_nixl_metadata(remote_metadata) decode_computed_block_ids=request.computed_block_ids,
logger.info( )
f"Loaded nixl metadata from engine {request.engine_id} into "
f"engine {self.engine_client.nixl_metadata.engine_id}" # TODO check if metadata has changed
# and reload - currently only loading once
if engine_id not in self._loaded_metadata:
remote_metadata = await self._metadata_store.get(request.engine_id)
await self.engine_client.add_remote_nixl_metadata(remote_metadata)
logger.info(
f"Loaded nixl metadata from engine {engine_id} into "
f"engine {self.engine_client.nixl_metadata.engine_id}"
)
self._loaded_metadata.add(engine_id)
# To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
# some placeholder dummy tokens were inserted based on the embedding size in the worker.py.
# The structure of the prompt is "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:", need to remove the dummy tokens after the image token.
IMAGE_TOKEN_ID = 32000
embedding_size = embeddings.shape[1]
padding_size = embedding_size - 1
image_token_index = request.prompt_token_ids.index(IMAGE_TOKEN_ID)
dummy_token_index = image_token_index + 1
prompt_token_ids = (
request.prompt_token_ids[:dummy_token_index]
+ request.prompt_token_ids[dummy_token_index + padding_size :]
) )
self._loaded_metadata.add(request.engine_id)
# To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
# some placeholder dummy tokens were inserted based on the embedding size in the worker.py.
# The structure of the prompt is "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:", need to remove the dummy tokens after the image token.
IMAGE_TOKEN_ID = 32000
embedding_size = image_features.shape[1]
padding_size = embedding_size - 1
image_token_index = request.prompt_token_ids.index(IMAGE_TOKEN_ID)
dummy_token_index = image_token_index + 1
prompt_token_ids = (
request.prompt_token_ids[:dummy_token_index]
+ request.prompt_token_ids[dummy_token_index + padding_size :]
)
async for _ in self.engine_client.generate( async for _ in self.engine_client.generate(
request_id=request.request_id, request_id=request_id,
prompt=TokensPrompt( prompt=TokensPrompt(
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
multi_modal_data={"image": image_features}, multi_modal_data={"image": embeddings},
), ),
sampling_params=sampling_params, sampling_params=sampling_params,
remote_prefill_params=remote_prefill_params, remote_prefill_params=remote_prefill_params,
): ):
yield yield
@endpoint() @endpoint()
async def mock(self, req: RequestType): async def mock(self, req: RequestType):
......
...@@ -19,7 +19,7 @@ import uuid ...@@ -19,7 +19,7 @@ import uuid
from enum import Enum from enum import Enum
from typing import AsyncIterator, Tuple, Union from typing import AsyncIterator, Tuple, Union
from components.worker import VllmWorker from components.decode_worker import VllmDecodeWorker
from transformers import AutoTokenizer from transformers import AutoTokenizer
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.logging import check_required_workers from utils.logging import check_required_workers
...@@ -53,7 +53,7 @@ class Processor(ProcessMixIn): ...@@ -53,7 +53,7 @@ class Processor(ProcessMixIn):
vLLM pre and post processing vLLM pre and post processing
""" """
worker = depends(VllmWorker) worker = depends(VllmDecodeWorker)
def __init__(self): def __init__(self):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
...@@ -83,7 +83,7 @@ class Processor(ProcessMixIn): ...@@ -83,7 +83,7 @@ class Processor(ProcessMixIn):
@async_on_start @async_on_start
async def async_init(self): async def async_init(self):
runtime = dynamo_context["runtime"] runtime = dynamo_context["runtime"]
comp_ns, comp_name = VllmWorker.dynamo_address() # type: ignore comp_ns, comp_name = VllmDecodeWorker.dynamo_address() # type: ignore
self.worker_client = ( self.worker_client = (
await runtime.namespace(comp_ns) await runtime.namespace(comp_ns)
.component(comp_name) .component(comp_name)
......
...@@ -21,7 +21,7 @@ Processor: ...@@ -21,7 +21,7 @@ Processor:
router: round-robin router: round-robin
common-configs: [model, block-size, max-model-len] common-configs: [model, block-size, max-model-len]
VllmWorker: VllmDecodeWorker:
enforce-eager: true enforce-eager: true
max-num-batched-tokens: 16384 max-num-batched-tokens: 16384
enable-prefix-caching: true enable-prefix-caching: true
...@@ -33,7 +33,7 @@ VllmWorker: ...@@ -33,7 +33,7 @@ VllmWorker:
gpu: 1 gpu: 1
common-configs: [model, block-size, max-model-len] common-configs: [model, block-size, max-model-len]
EncodeWorker: VllmEncodeWorker:
tensor-parallel-size: 1 tensor-parallel-size: 1
router: random router: random
ServiceArgs: ServiceArgs:
......
...@@ -22,7 +22,7 @@ Processor: ...@@ -22,7 +22,7 @@ Processor:
router: round-robin router: round-robin
common-configs: [model, block-size] common-configs: [model, block-size]
VllmWorker: VllmDecodeWorker:
remote-prefill: true remote-prefill: true
conditional-disagg: true conditional-disagg: true
max-local-prefill-length: 10 max-local-prefill-length: 10
...@@ -33,7 +33,7 @@ VllmWorker: ...@@ -33,7 +33,7 @@ VllmWorker:
gpu: 1 gpu: 1
common-configs: [model, block-size, max-model-len, kv-transfer-config] common-configs: [model, block-size, max-model-len, kv-transfer-config]
PrefillWorker: VllmPrefillWorker:
max-num-batched-tokens: 16384 max-num-batched-tokens: 16384
ServiceArgs: ServiceArgs:
workers: 1 workers: 1
...@@ -41,7 +41,7 @@ PrefillWorker: ...@@ -41,7 +41,7 @@ PrefillWorker:
gpu: 1 gpu: 1
common-configs: [model, block-size, max-model-len, kv-transfer-config] common-configs: [model, block-size, max-model-len, kv-transfer-config]
EncodeWorker: VllmEncodeWorker:
tensor-parallel-size: 1 tensor-parallel-size: 1
router: random router: random
ServiceArgs: ServiceArgs:
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import asyncio
import logging
import socket
import uuid
import zlib
from abc import ABC, abstractmethod
from enum import IntEnum
from functools import cached_property
from typing import Any, List, Optional
import nixl._api as nixl_api
import nixl._bindings as nixl_bindings
import torch
from pydantic import BaseModel, ConfigDict, field_validator
from dynamo.runtime import DistributedRuntime
from dynamo.sdk import dynamo_context
logger = logging.getLogger(__name__)
try:
import cupy as array_module
from cupy_backends.cuda.api.runtime import CUDARuntimeError
logger.info("Utilizing cupy to enable GPU acceleration.")
except ImportError:
try:
import numpy as array_module
logger.warning("Failed to load cupy for GPU acceleration, utilizing numpy to provide CPU based operations.")
except ImportError as e:
raise ImportError("Numpy or cupy must be installed to use this module.") from e
class AbstractOperation(ABC):
"""
Abstract base class for awaitable NIXL based RDMA operations.
"""
def __init__(
self,
connector: Connector,
operation_kind: OperationKind,
local_descriptors: Descriptor | list[Descriptor],
remote_descriptors: Optional[Descriptor | list[Descriptor]],
notification_key: Optional[str],
) -> None:
if not isinstance(connector, Connector):
raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
if operation_kind is not OperationKind.READ and operation_kind is not OperationKind.WRITE:
raise ValueError("Argument `operation_kind` must be either `READ` or `WRITE`.")
if not (
isinstance(local_descriptors, (Descriptor, list))
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
if (
remote_descriptors is not None
and not (
isinstance(remote_descriptors, Descriptor)
or (isinstance(remote_descriptors, list) and all(isinstance(d, Descriptor) for d in remote_descriptors))
)
):
raise TypeError("Argument `remote_descriptors` must be dynamo.connect.Descriptor`, `list[dynamo.connect.Descriptor]`, or `None`.")
if isinstance(local_descriptors, list) and len(local_descriptors) == 0:
raise ValueError("Argument `local_descriptors` must not be an empty list.")
if (
remote_descriptors is not None
and isinstance(remote_descriptors, list)
and len(remote_descriptors) == 0
):
raise ValueError("Argument `remote_descriptors` must not be an empty list.")
notification_key = str(uuid.uuid4()) if notification_key is None else notification_key
if not isinstance(notification_key, str):
raise TypeError("Argument `notification_key` must be `str` or `None`.")
if len(notification_key) == 0:
raise ValueError("Argument `notification_key` must not be an empty string.")
self._notification_key: str = "" if notification_key is None else notification_key
self._connector: Connector = connector
self._operation_kind: OperationKind = operation_kind
self._local_descriptors: Descriptor | list[Descriptor] = local_descriptors
self._local_dlist: Optional[list[tuple[int, int, int]]] = None
self._local_memtype: DeviceKind = DeviceKind.UNSPECIFIED
self._remote_descriptors: Optional[Descriptor | list[Descriptor]] = None if remote_descriptors is None else remote_descriptors
self._remote_dlist: Optional[list[tuple[int, int, int]]] = None
self._remote_memtype: DeviceKind = DeviceKind.UNSPECIFIED
# Register local descriptors with NIXL.
# Note: Only local descriptors should be registered with NIXL,
if isinstance(local_descriptors, list):
for d in local_descriptors:
d.register_memory(self._connector)
else:
local_descriptors.register_memory(self._connector)
# Record local descriptors.
memtype, dtlist = self._create_dlist(local_descriptors)
self._local_dlist = dtlist
self._local_memtype = memtype
# Record remote descriptors when provided.
if remote_descriptors is not None:
memtype, dtlist = self._create_dlist(remote_descriptors)
self._remote_dlist = dtlist
self._remote_memtype = memtype
def __del__(self) -> None:
self._release()
def __enter__(self) -> AbstractOperation:
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self._release()
def _release(self) -> None:
"""
Private method to release resources. Only to be called by `self`.
"""
pass
@property
def connector(self) -> Connector:
"""
Gets the local associated with this operation.
"""
return self._connector
@property
def operation_kind(self) -> OperationKind:
"""
Gets the kind of operation.
"""
return self._operation_kind
@abstractmethod
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
raise NotImplementedError("Abstract method not implemented by derived class.")
# Private Methods
def _create_dlist(
self,
descriptors: Descriptor | list[Descriptor],
) -> tuple[DeviceKind, list[tuple[int, int, int]]]:
"""
Helper function to create a list of tuples (ptr, size, device) from descriptors.
"""
dlist: list[tuple[int, int, int]] = []
memtype: DeviceKind = DeviceKind.UNSPECIFIED
if isinstance(descriptors, list):
memtype = descriptors[0].device.kind
for desc in descriptors:
if memtype != desc.device.kind:
raise ValueError("All local descriptors must have the same memory type.")
dlist.append((desc.ptr, desc.size, desc.device.id))
else:
memtype = descriptors.device.kind
dlist.append((descriptors.ptr, descriptors.size, descriptors.device.id))
return (memtype, dlist)
class ActiveOperation(AbstractOperation):
"""
Abstract class for active operations that initiates a NIXL based RDMA transfer based `SerializedRequest`
provided by the remote worker's corresponding `PassiveOperation`.
"""
def __init__(
self,
remote: Remote,
operation_kind: OperationKind,
local_descriptors: Descriptor | list[Descriptor],
remote_descriptors: Descriptor | list[Descriptor],
notification_key: str,
) -> None:
if not isinstance(remote, Remote) or remote._connector is None:
raise TypeError("Argument `remote` must be valid `dynamo.connect.RemoteAgent`.")
if not isinstance(operation_kind, OperationKind):
raise TypeError("Argument `operation_kind` must `dynamo.connect.OperationKind`.")
if operation_kind is not OperationKind.READ and operation_kind is not OperationKind.WRITE:
raise ValueError("Argument `operation_kind` must be either `READ` or `WRITE`.")
if not (
isinstance(local_descriptors, Descriptor)
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
if not (
isinstance(remote_descriptors, Descriptor)
or (isinstance(remote_descriptors, list) and all(isinstance(d, Descriptor) for d in remote_descriptors))
):
raise TypeError("Argument `remote_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
# Unpack single descriptors from lists if they are provided as single descriptors.
if isinstance(local_descriptors, list) and len(local_descriptors) == 1:
local_descriptors = local_descriptors[0]
if isinstance(remote_descriptors, list) and len(remote_descriptors) == 1:
remote_descriptors = remote_descriptors[0]
if (isinstance(local_descriptors, list) and isinstance(remote_descriptors, list) and len(local_descriptors) != len(remote_descriptors)):
raise ValueError("When `local_descriptors` and `remote_descriptors` are lists, they must have the same length.")
elif isinstance(local_descriptors, list) != isinstance(remote_descriptors, list):
raise ValueError("Both `local_descriptors` and `remote_descriptors` must be either lists or single descriptors.")
if not isinstance(notification_key, str):
raise TypeError("Argument `notification_key` must be `str`.")
if len(notification_key) == 0:
raise ValueError("Argument `notification_key` must not be an empty string.")
self._remote = remote
self._status = OperationStatus.UNINTIALIZED
super().__init__(remote.connector, operation_kind, local_descriptors, remote_descriptors, notification_key)
# Quick check to ensure remote descriptors are not None to make static analysis happy.
if self._local_dlist is None or self._remote_dlist is None:
raise RuntimeError("NIXL descriptor list(s) not bound to operation.")
self._local_xfer_descs: Optional[nixl_bindings.nixlXferDList] = None
self._remote_xfer_descs: Optional[nixl_bindings.nixlXferDList] = None
self._xfer_hndl: Optional[nixl_api.nixl_xfer_handle] = None
self._local_xfer_descs = self._connector._nixl.get_xfer_descs(
descs=self._local_dlist,
mem_type=str(self._local_memtype),
)
logger.debug(f"Created local NIXL xfer descs: {self._local_xfer_descs}")
self._remote_xfer_descs = self._connector._nixl.get_xfer_descs(
descs=self._remote_dlist,
mem_type=str(self._remote_memtype),
)
logger.debug(f"Created remote NIXL xfer descs: {self._remote_xfer_descs}")
self._xfer_hndl = self._connector._nixl.initialize_xfer(
operation=str(operation_kind),
local_descs=self._local_xfer_descs,
remote_descs=self._remote_xfer_descs,
remote_agent=self._remote.name,
notif_msg=self._notification_key.encode("utf-8"),
)
logger.debug(f"Created NIXL transfer handle: {self._xfer_hndl}")
def __del__(self) -> None:
super().__del__()
self._release()
def __enter__(self) -> ActiveOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
match self.status:
case OperationStatus.IN_PROGRESS | OperationStatus.INITIALIZED:
self._status = OperationStatus.CANCELLED
self._release()
def __repr__(self) -> str:
return str(
f"{self.__class__.__name__}("
f"operation_kind={self._operation_kind}, "
f"local_descriptors={self._local_descriptors}, "
f"remote_descriptors={self._remote_descriptors}, "
f"notification_key='{self._notification_key}', "
f"remote='{self._remote.name}', "
f"status='{self._status}'"
f")"
)
def _release(self) -> None:
"""
Private method to release resources.
"""
error: Optional[Exception] = None
if self._xfer_hndl is not None:
try:
logger.debug(f"NIXL transfer handle {self._xfer_hndl} released.")
self._connector._nixl.release_xfer_handle(self._xfer_hndl)
except Exception as e:
logger.error(f"Failed to release resources: {e}")
error = e
finally:
self._xfer_hndl = None
try:
super()._release()
except Exception as e:
logger.error(f"Failed to release WaitableOperation resources: {e}")
if error is not None:
e.__cause__ = error
error = e
if error is not None:
raise error
def _cancel_(self) -> None:
if self._xfer_hndl is None:
return
if self.status == OperationStatus.ERRORED:
raise RuntimeError("Operation is errored, unable to cancel the operation.")
logger.info(f"Cancellation requested for operation {{ kind={self._operation_kind}, remote='{self._remote.name}', status={self._status} }}.")
# NIXL will cancel the transfer if it is in progress when the handle is released.
self._connector._nixl.release_xfer_handle(self._xfer_hndl)
self._status = OperationStatus.CANCELLED
self._xfer_hndl = None
async def _wait_for_completion_(self) -> None:
# Loop until the operation is no longer in progress (or "initalized"),
# yielding control to the event loop to allow other operations to run.
iteration_count = 0
while True:
if iteration_count & 10 == 0:
logger.debug(f"Waiting for operation {{ kind={self._operation_kind}, remote='{self._remote.name}', duration={iteration_count / 10}s }}.")
match self.status:
# "in progress" or "initialized" means the operation is ongoing.
case OperationStatus.INITIALIZED:
await asyncio.sleep(0.1)
case OperationStatus.IN_PROGRESS:
await asyncio.sleep(0.1)
# Any other state indicates completion or error.
case _:
return
@abstractmethod
def cancel(self) -> None:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or has been cancelled.
"""
raise NotImplementedError("Abstract method not implemented by derived class.")
@property
def remote(self) -> Remote:
"""
Gets the remote agent associated with this operation.
"""
return self._remote
@property
def status(self) -> OperationStatus:
"""
Gets the status of the operation.
"""
# Early return if the operation is already complete, errored, or cancelled.
match self._status:
case OperationStatus.COMPLETE | OperationStatus.ERRORED | OperationStatus.CANCELLED:
return self._status
if self._xfer_hndl is None:
raise RuntimeError("NIXL transfer handle is invalid.")
old_status = self._status
if self._status == OperationStatus.UNINTIALIZED:
state = self._connector._nixl.transfer(self._xfer_hndl, self._notification_key.encode("utf-8"))
logger.debug(f"NIXL reported transfer state: {state}")
if state == "ERR":
self._status = OperationStatus.ERRORED
elif state == "DONE":
self._status = OperationStatus.COMPLETE
else:
self._status = OperationStatus.INITIALIZED
else:
state = self._connector._nixl.check_xfer_state(self._xfer_hndl)
logger.debug(f"NIXL reported transfer state: {state}")
if state == "ERR":
self._status = OperationStatus.ERRORED
elif state == "DONE":
self._status = OperationStatus.COMPLETE
else:
self._status = OperationStatus.IN_PROGRESS
if self._status != old_status:
logger.debug(f"{self.__class__.__name__} {{ remote: '{self._remote.name}' status: '{old_status}' => '{self._status}' }}.")
return self._status
class Connector:
"""
Core class for managing the connection between agents in a distributed environment.
Use this class to create readable and writable operations, or read and write data to remote agents.
"""
def __init__(
self,
namespace: Optional[str] = None,
runtime: Optional[DistributedRuntime] = None,
worker_id: Optional[str] = None,
) -> None:
"""
Creates a new Connector instance.
Parameters
----------
namespace : Optional[str], optional
Dynamo namespace of the component, defaults to "dynamo" when `None`.
runtime : Optional[DistributedRuntime], optional
Reference the dynamo runtime used by the compenent, attempts to use the current runtime when `None`.
worker_id : Optional[str], optional
Unique identifier of the worker, defaults to a new UUID when `None`.
Raises
------
TypeError
When `namespace` is provied and not of type 'str'.
TypeError
When `runtime` iis provied and not of type `dynamo.runtime.DistributedRuntime`.
TypeError
When `worker_id` is provied and not of type `uuid.UUID`.
"""
namespace = "dynamo" if namespace is None else namespace
if not isinstance(namespace, str):
raise TypeError("Argument `namespace` must be `str` or `None`.")
if dynamo_context is not None and "runtime" in dynamo_context:
runtime = dynamo_context["runtime"] if runtime is None else runtime
if not isinstance(runtime, DistributedRuntime) or runtime is None:
raise TypeError("Argument `runtime` must be `dynamo.runtime.DistributedRuntime` or `None`.")
worker_id = worker_id if worker_id is not None else str(uuid.uuid4())
if not isinstance(worker_id, str) or len(worker_id) == 0:
raise TypeError("Argument `worker_id` must be a non-empty `str` or `None`.")
self._worker_id = worker_id
self._is_initialized = False
self._runtime = runtime
self._namespace = namespace
self._nixl = nixl_api.nixl_agent(self._worker_id)
self._hostname = socket.gethostname()
self._agent_metadata: Optional[bytes] = None
logger.debug(f"Created {self.__repr__()}.")
def __repr__(self) -> str:
return str(
f"{self.__class__.__name__}("
f"worker_id='{self._worker_id}', "
f"namespace={self._namespace}, "
f"hostname={self._hostname}, "
f"metadata=<{0 if self._agent_metadata is None else len(self._agent_metadata)} bytes>"
")"
)
def __str__(self) -> str:
return self._worker_id
@cached_property
def is_cuda_available(self) -> bool:
# Note: cuda.is_avalailable initializes cuda
# and can't be called when forking subprocesses
# care should be taken to only call it within
# subprocesses or use 'spawn'
try:
return array_module.cuda is not None and array_module.cuda.is_available()
except CUDARuntimeError:
return False
@property
def metadata(self) -> bytes:
"""
Get the metadata of the agent.
"""
return self._nixl.get_agent_metadata()
@property
def name(self) -> str | None:
"""
Get the name of the agent.
"""
return self._worker_id
@property
def namespace(self) -> str:
"""
Get the namespace of the local.
"""
return self._namespace
@property
def runtime(self) -> DistributedRuntime:
"""
Get the runtime of the local.
"""
if self._runtime is None:
raise RuntimeError("Runtime is not set. This Connector was not initialized with a runtime.")
return self._runtime
async def begin_read(
self,
remote_request: SerializedRequest,
local_descriptors: Descriptor | list[Descriptor],
) -> ReadOperation:
"""
Creates a read operation for fulfilling a remote readable operation.
Parameters
----------
remote_request : SerializedRequest
Serialized request from a remote worker that has created a readable operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to receive data from the remote worker described by `remote_request`.
Returns
-------
ReadOperation
Awaitable read operation that can be used to transfer data from a remote agent.
Raises
------
TypeError
When `remote_request` is not of type `SerializedRequest`.
TypeError
When `local_descriptors` is not of type `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
if remote_request is None or not isinstance(remote_request, SerializedRequest):
raise TypeError("Argument `remote_request` must be `SerializedRequest`.")
if not (
isinstance(local_descriptors, Descriptor)
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
if remote_request.operation_kind != OperationKind.READ.value:
raise RuntimeError("Cannot create a `dynamo.connect.ReadOperation` to read from a remote `dynamo.connect.WritableOperation`.")
if not self._is_initialized:
raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
op = ReadOperation(self, remote_request, local_descriptors)
return op
async def begin_write(
self,
local_descriptors: Descriptor | list[Descriptor],
remote_request: SerializedRequest,
) -> WriteOperation:
"""
Creates a write operation for transferring data to a remote agent.
Parameters
----------
remote_request : SerializedRequest
Serialized request from a remote worker that has created a readable operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptors of one or more data objects to be transferred to the remote agent.
"""
if remote_request is None or not isinstance(remote_request, SerializedRequest):
raise TypeError("Argument `remote_request` must be `SerializedRequest`.")
if not (
isinstance(local_descriptors, Descriptor)
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `Descriptor` or `list[Descriptor]`.")
if remote_request.operation_kind != OperationKind.WRITE:
raise RuntimeError("Cannot create a `WriteOperation` to write to a remote `ReadableOperation`.")
if not isinstance(remote_request.nixl_metadata, str):
raise TypeError("Argument `remote_request.nixl_metadata` must be `str`.")
if not self._is_initialized:
raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
op = WriteOperation(self, local_descriptors, remote_request)
return op
def create_readable(
self,
local_descriptors: Descriptor | list[Descriptor],
) -> ReadableOperation:
"""
Creates a readable operation for transferring data from a remote agent.
Returns
-------
ReadableOperation
A readable operation that can be used to transfer data from a remote agent.
"""
if not self._is_initialized:
raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
op = ReadableOperation(self, local_descriptors)
return op
def create_writable(
self,
local_descriptors: Descriptor | list[Descriptor],
) -> WritableOperation:
"""
Creates a writable operation for transferring data to a remote agent.
Returns
-------
WritableOperation
A writable operation that can be used to transfer data to a remote agent.
"""
if not self._is_initialized:
raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
op = WritableOperation(self, local_descriptors)
return op
async def initialize(self) -> None:
# Only initialize the connector once.
if self._is_initialized:
return
self._is_initialized = True
# This method is a no-op for now, in the future it may be used to initialize the connector.
logger.debug(f"Initialized Connector {{ name: '{self._worker_id}', namespace '{self._namespace}' }} completed.")
class Descriptor:
"""
Memory descriptor that ensures memory is registered w/ NIXL, used for transferring data between workers.
"""
def __init__(
self,
data: torch.Tensor | tuple[array_module.ndarray, Device|str] | bytes | tuple[int, int, Device|str, Any],
) -> None:
"""
Memory descriptor for transferring data between agents.
Parameters
----------
data : torch.Tensor | tuple[ndarray, Device|str] | bytes | tuple[int, int, Device|str, Any]
The data to be transferred.
When `torch.Tensor` is provided, the attributes of the tensor will be used to create the descriptor.
When `tuple[ndarray, Device]` is provided, the tuple must contain:
- `ndarray`: The CuPy or NumPy array to be transferred.
- `Device`: Either a `dynamo.connect.Device` or a string representing the device type (e.g., "cuda" or "cpu").
When `bytes` is provided, the pointer and size derived from the bytes object and memory type will be assumed to be CPU.
When `tuple[int, int, Device|str, Any]` is provided, the tuple must contain the following elements:
- `int`: Pointer to the data in memory.
- `int`: Size of the data in bytes.
- `Device`: Either a `dynamo.connect.Device` or a string representing the device type (e.g., "cuda" or "cpu").
- `Any`: Optional reference to the data (e.g., the original tensor or bytes object).
This is useful for keeping a reference to the data in memory, but it is not required.
Raises
------
ValueError
When `data` is `None`.
TypeError
When `data` is not a valid type (i.e., not `torch.Tensor`, `bytes`, or a valid tuple).
TypeError
When `data` is a tuple but the elements are not of the expected types (i.e., [`ndarray`, `Device|str`] OR [`int`, `int`, `Device|str`, `Any`]).
"""
TYPE_ERROR_MESSAGE = "Argument `data` must be `torch.Tensor`, `tuple[ndarray, Device|str]`, `bytes`, or `tuple[int, int, Device|str, Any]`."
if data is None:
raise ValueError("Argument `data` cannot be `None`.")
if not (isinstance(data, torch.Tensor) or isinstance(data, bytes) or isinstance(data, tuple)):
raise TypeError(TYPE_ERROR_MESSAGE)
self._data_device: Device = Device("cpu")
self._data_ptr: int = 0
self._data_ref: Optional[Any] = None
self._data_size: int = 0
# Member fields for managing NIXL memory registration.
# Note: ONLY local descriptors should be registered with NIXL,
# remote descriptors do not have a valid memory address and registration will fault.
self._connector: Optional[Connector] = None
self._nixl_hndl: Optional[nixl_bindings.nixlRegDList] = None
# Initially `None` cached serialized descriptor reference, populated when `to_serialized()` is called.
self._serialized: Optional[SerializedDescriptor] = None
# Data is `torch.Tensor`.
if isinstance(data, torch.Tensor):
self._data_ptr = data.data_ptr()
self._data_size = data.numel() * data.element_size()
if data.is_cuda:
self._data_device = Device((DeviceKind.CUDA, data.get_device()))
self._data_ref = data
logger.debug(f"Created {self.__repr__()} from `torch.Tensor`.")
# Data is `tuple[ndarray, Device]`.
elif (
isinstance(data, tuple)
and len(data) == 2
and isinstance(data[0], array_module.ndarray)
and (isinstance(data[1], Device) or isinstance(data[1], str))
):
if hasattr(data[0], "__array_interface__"):
self._data_ptr = data[0].__array_interface__["data"][0]
elif hasattr(data[0], "__cuda_array_interface__"):
self._data_ptr = data[0].__cuda_array_interface__["data"][0]
else:
raise TypeError("Argument `data[0]` must be a `ndarray` with a valid array interface.")
self._data_size = data[0].nbytes
self._data_device = data[1] if isinstance(data[1], Device) else Device(data[1])
self._data_ref = data[0]
logger.debug(f"Created {self.__repr__()} from `tuple[ndarray, Device|str]`.")
# Data is `bytes`.
elif isinstance(data, bytes):
self._data_ptr = id(data)
self._data_size = len(data)
self._data_ref = data
logger.debug(f"Created {self.__repr__()} from `bytes`.")
# Data is `tuple[int, int, Device, dtype, tuple, Any]`.
elif isinstance(data, tuple) and len(data) >= 2 and isinstance(data[0], int) and isinstance(data[1], int):
if len(data) >= 3 and not (isinstance(data[2], Device) or isinstance(data[2], str)):
raise TypeError("Argument `data` must be a `tuple[int, int, Device|str, Any]`.")
self._data_ptr = data[0]
self._data_size = data[1]
if len(data) >= 3:
self._data_device = data[2] if isinstance(data[2], Device) else Device(data[2])
self._data_ref = data[3] if len(data) >=4 else None
logger.debug(f"Created {self.__repr__()} from `tuple[int, int, Device|str, Any]`.")
else:
raise TypeError(TYPE_ERROR_MESSAGE)
def __del__(self) -> None:
if self._nixl_hndl is not None and self._connector is not None:
# Unregister the memory with NIXL.
self._connector._nixl.deregister_memory(self._nixl_hndl)
self._nixl_hndl = None
if self._data_ref is not None:
# Release the reference to the data.
del self._data_ref
logger.debug(f"Deleted {self.__repr__()}.")
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self})"
def __str__(self) -> str:
return f"ptr={hex(self._data_ptr)}, size={self._data_size}, device={self._data_device}"
@property
def device(self) -> Device:
"""
Gets the device the of the descriptor.
"""
return self._data_device
@property
def ptr(self) -> int:
"""
Gets the pointer of the descriptor.
"""
return self._data_ptr
@property
def size(self) -> int:
"""
Gets the size of the descriptor.
"""
return self._data_size
@staticmethod
def from_serialized(
serialized: SerializedDescriptor,
) -> Descriptor:
"""
Deserializes a `SerializedDescriptor` into a `Descriptor` object.
Parameters
----------
serialized : SerializedDescriptor
The serialized descriptor to deserialize.
Returns
-------
Descriptor
The deserialized descriptor.
"""
if not isinstance(serialized, SerializedDescriptor):
raise TypeError("Argument `serialized` must be `SerializedDescriptor`.")
return serialized.to_descriptor()
def register_memory(
self,
connector: Connector,
) -> None:
"""
Registers the memory of the descriptor with NIXL.
"""
if not isinstance(connector, Connector):
raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
if self._data_ptr == 0:
raise ValueError("Cannot register memory with a null pointer.")
if not (self._nixl_hndl is None and self._connector is None):
return
# Register the memory with NIXL.
self._connector = connector
if isinstance(self._data_ref, torch.Tensor):
self._nixl_hndl = connector._nixl.register_memory(self._data_ref)
else:
mem_type = str(self._data_device.kind)
reg_list = [(self._data_ptr, self._data_size, self._data_device.id, mem_type)]
self._nixl_hndl = connector._nixl.register_memory(reg_list, mem_type)
logger.debug(f"Registered {self.__repr__()} with NIXL.")
def to_serialized(self) -> SerializedDescriptor:
"""
Serializes the descriptor into a `SerializedDescriptor` object.
"""
if self._serialized is None:
self._serialized = SerializedDescriptor(
device=f"{self._data_device}",
ptr=self._data_ptr,
size=self._data_size,
)
return self._serialized
class Device:
"""
Represents a device in the system.
"""
def __init__(
self,
metadata: str | tuple[DeviceKind, int],
) -> None:
if metadata is None:
raise ValueError("Argument `metadata` cannot be `None`.")
if isinstance(metadata, tuple) and len(metadata) == 2 and isinstance(metadata[0], DeviceKind) and isinstance(metadata[1], int):
kind, device_id = metadata
elif isinstance(metadata, str):
metadata = metadata.strip().lower()
if metadata.startswith("cuda") or metadata.startswith("gpu"):
kind = DeviceKind.CUDA
device_id = 0 if metadata.find(":") == -1 else int(metadata.split(":")[1])
elif metadata.startswith("cpu") or metadata.startswith("host"):
kind = DeviceKind.HOST
device_id = 0
else:
raise ValueError("Argument `metadata` must be in the format 'cuda:<device_id>' or 'cpu'.")
else:
raise TypeError("Argument `metadata` must be a `tuple[MemoryKind, int]` or a `str`.")
self._device_id = device_id
self._kind = kind
def __repr__(self) -> str:
return f"{self.__class__.__name__}(kind={self._kind}, id={self._device_id})"
def __str__(self) -> str:
return f"{self._kind}:{self._device_id}" if self._kind is DeviceKind.CUDA else f"{self._kind}"
@property
def id(self) -> int:
"""
Gets the device ID of the device.
"""
return self._device_id
@property
def kind(self) -> DeviceKind:
"""
Gets the memory kind of the device.
"""
return self._kind
class DeviceKind(IntEnum):
"""
Type of memory a descriptor has been allocated to.
"""
UNSPECIFIED = 0
HOST = 1
CUDA = 2
def __str__(self) -> str:
if self == DeviceKind.HOST:
return "cpu"
elif self == DeviceKind.CUDA:
return "cuda"
else:
return "<invalid>"
class OperationKind(IntEnum):
"""
Kind of an operation.
"""
UNSPECIFIED = 0
READ = 1
WRITE = 2
def __str__(self) -> str:
if self == OperationKind.READ:
return "READ"
elif self == OperationKind.WRITE:
return "WRITE"
else:
return "<invalid>"
class OperationStatus(IntEnum):
"""
Status of an operation.
"""
UNINTIALIZED = 0
INITIALIZED = 1
IN_PROGRESS = 2
COMPLETE = 3
CANCELLED = 4
ERRORED = 5
def __str__(self) -> str:
if self == OperationStatus.INITIALIZED:
return "INIT"
elif self == OperationStatus.IN_PROGRESS:
return "PROC"
elif self == OperationStatus.COMPLETE:
return "DONE"
elif self == OperationStatus.ERRORED:
return "ERR"
elif self == OperationStatus.CANCELLED:
return "STOP"
else:
return "<invalid>"
class PassiveOperation(AbstractOperation):
"""
Abstract class for common functionality of passive operations.
"""
def __init__(
self,
connector: Connector,
operation_kind: OperationKind,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
if operation_kind is not OperationKind.READ and operation_kind is not OperationKind.WRITE:
raise ValueError("Argument `operation_kind` must be either `READ` or `WRITE`.")
self._status = OperationStatus.UNINTIALIZED
super().__init__(connector, operation_kind, local_descriptors, None, None)
self._serialized_request: Optional[SerializedRequest] = None
self._status = OperationStatus.INITIALIZED
def __del__(self) -> None:
super().__del__()
def __enter__(self) -> AbstractOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return str(
f"{self.__class__.__name__}("
f"operation_kind={self._operation_kind}, "
f"local_descriptors={self._local_descriptors}, "
f"notification_key='{self._notification_key}', "
f"status='{self._status}'"
f")"
)
async def _wait_for_completion_(self) -> None:
# Loop until the operation is no longer in progress (or "initalized"),
# yielding control to the event loop to allow other operations to run.
while True:
match self.status:
# "in progress" or "initialized" means the operation is ongoing.
case OperationStatus.INITIALIZED:
await asyncio.sleep(0.1)
case OperationStatus.IN_PROGRESS:
await asyncio.sleep(0.1)
# Any other state indicates completion or error.
case _:
return
@property
def status(self) -> OperationStatus:
"""
Gets the status of the operation.
"""
# Early return if the operation is already complete, errored, or cancelled.
match self._status:
case OperationStatus.COMPLETE | OperationStatus.ERRORED | OperationStatus.CANCELLED:
return self._status
old_status = self._status
# Query NIXL for any notifications.
notifications = self._connector._nixl.update_notifs()
if isinstance(notifications, dict):
remote_state = OperationStatus.IN_PROGRESS
logger.debug(f"NIXL reported notifications: {len(notifications)}.")
for key, values in notifications.items():
if not isinstance(values, list):
raise TypeError(f"Expected `dict[str, list[bytes]]` from NIXL notification query; got {type(notifications)}.")
for value in values:
if not isinstance(value, bytes):
continue
notification_key = value.decode("utf-8")
# Once we've found the notification key, we know the operation is complete.
if notification_key == self._notification_key:
remote_state = OperationStatus.COMPLETE
break
if remote_state == OperationStatus.COMPLETE:
self._status = remote_state
logger.debug(f"{self.__class__.__name__} {{ remote: '{self._connector.name}' status: '{old_status}' => '{self._status}' }}.")
return self._status
def to_serialized(self) -> SerializedRequest:
"""
Gets the request descriptor for the operation.
"""
if self._serialized_request is None:
# When we've not yet cached the serialized request, we need to generate one before returning it.
# Handle both cases: multiple and single descriptors.
if isinstance(self._local_descriptors, list):
descriptors = [desc.to_serialized() for desc in self._local_descriptors]
else:
descriptors = [self._local_descriptors.to_serialized()]
original_len = len(self._connector.metadata)
nixl_metadata = self._connector.metadata
nixl_metadata = zlib.compress(nixl_metadata, level=6)
compressed_len = len(nixl_metadata)
logger.debug(f"Compressed NIXL metadata from {original_len} bytes to {compressed_len} bytes.")
if compressed_len > original_len:
logger.warning(f"Compressed NIXL metadata is larger than original ({compressed_len} > {original_len}).")
self._serialized_request = SerializedRequest(
descriptors=descriptors,
nixl_metadata=nixl_metadata.hex(),
notification_key=self._notification_key,
operation_kind=int(self._operation_kind),
)
return self._serialized_request
@abstractmethod
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
raise NotImplementedError("Abstract method not implemented by derived class.")
class ReadOperation(ActiveOperation):
"""
Operation that initiates an RDMA read operation to transfer data from a remote worker's `ReadableOperation`,
as described by `remote_request`, to local buffers.
"""
def __init__(
self,
connector: Connector,
remote_request: SerializedRequest,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
"""
Creates a new instance of `ReadOperation`, registers `local_descriptors` with NIXL,
and begins an RDMA read operation which will transfer data described by `remote_request`
to `local_descriptors`.
Parameters
----------
connector : Connector
Connector instance to use for the operation.
remote_request : SerializedRequest
Serialized request from the remote worker.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to to receive the data from the remote agent.
"""
if not isinstance(connector, Connector):
raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
if not isinstance(remote_request, SerializedRequest):
raise TypeError("Argument `remote_request` must be `dynamo.connect.RequestDescriptor`.")
if remote_request.operation_kind != OperationKind.READ.value:
raise ValueError("Argument `remote_request` must be of kind `READ`.")
remote = Remote(connector, remote_request.nixl_metadata)
remote_descriptors = remote_request.to_descriptors()
if not (
isinstance(local_descriptors, Descriptor)
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor`, `list[dynamo.connect.Descriptor]`.")
super().__init__(remote, OperationKind.READ, local_descriptors, remote_descriptors, remote_request.notification_key)
logger.debug(f"Created {self.__repr__()}")
def __del__(self) -> None:
super().__del__()
logger.debug(f"Deleted {self.__repr__()}")
def __enter__(self) -> ReadOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return super().__repr__()
def cancel(self) -> None:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or been cancelled.
"""
super()._cancel_()
def results(self) -> list[Descriptor]:
"""
Gets the results of the operation.
Returns a single descriptor if only one was requested, or a list of descriptors if multiple were requested.
"""
if self._status != OperationStatus.COMPLETE:
raise RuntimeError("Operation has not completed yet, cannot get results.")
return self._local_descriptors if isinstance(self._local_descriptors, list) else [self._local_descriptors]
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await super()._wait_for_completion_()
class ReadableOperation(PassiveOperation):
"""
Operation that can be awaited until a remote worker has completed a `ReadOperation`.
"""
def __init__(
self,
connector: Connector,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
super().__init__(connector, OperationKind.READ, local_descriptors)
logger.debug(f"Created {self.__repr__()}")
def __del__(self) -> None:
super().__del__()
logger.debug(f"Deleted {self.__repr__()}")
def __enter__(self) -> ReadableOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return super().__repr__()
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await super()._wait_for_completion_()
class Remote:
"""
Identifies a remote NIXL agent relative to a local NIXL agent.
"""
def __init__(
self,
connector: Connector,
nixl_metadata: bytes | str,
) -> None:
if not isinstance(connector, Connector):
raise TypeError("Argument `local` must be `dynamo.connect.Connector`.")
if not (isinstance(nixl_metadata, bytes) or isinstance(nixl_metadata, str)):
raise TypeError("Argument `nixl_metadata` must be `bytes` or `str`.")
if len(nixl_metadata) == 0:
raise ValueError("Argument `nixl_metadata` cannot be empty.")
self._connector = connector
# When `nixl_metadata` is a string, it is assumed to have come from a remote worker
# via a `SerializedRequest` object and therefore can assumed be a hex-encoded, compressed
# representation of the NIXL metadata.
if isinstance(nixl_metadata, str):
# Decode the hex-encoded string into bytes.
nixl_metadata = bytes.fromhex(nixl_metadata)
# Decompress the NIXL metadata.
nixl_metadata = zlib.decompress(nixl_metadata)
self._name = connector._nixl.add_remote_agent(nixl_metadata)
if isinstance(self._name, bytes):
self._name = self._name.decode("utf-8")
logger.debug(f"Created {self.__repr__()}.")
def __del__(self) -> None:
self._release()
def __enter__(self) -> Remote:
"""
Context manager entry method. Returns the current instance.
"""
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
"""
Context manager exit method. Cleans up the instance.
"""
self._release()
def __repr__(self) -> str:
return f"RemoteAgent(name={self._name}, connector={self._connector.name})"
def __str__(self) -> str:
return self._name
def _release(self) -> None:
"""
Private method for releasing NIXL resources. Not intended for public use.
"""
pass
@property
def connector(self) -> Connector:
"""
Gets the local connector associated with this remote agent.
"""
return self._connector
@property
def name(self) -> str:
"""
Gets the name of the remote agent.
"""
return self._name
class SerializedDescriptor(BaseModel):
"""
Pydantic serialization type for memory descriptors.
"""
model_config = ConfigDict(
extra="forbid",
frozen=True,
arbitrary_types_allowed=True,
)
device: str = "cpu"
ptr: int = 0
size: int = 0
def to_descriptor(self) -> Descriptor:
"""
Deserialize the serialized descriptor into a `Descriptor` object.
"""
return Descriptor(data=(self.ptr, self.size, self.device, None))
@field_validator("device")
def validate_memtype(cls, v: str) -> str:
if not isinstance(v, str):
raise TypeError("Argument `device` must be `str`.")
v = v.strip().lower()
if not (v.startswith("cuda") or v == "cpu"):
raise ValueError("Argument `device` must be one of 'cpu' or 'cuda:<device_id>'.")
return v
@field_validator("ptr")
def validate_ptr(cls, v: int) -> int:
if v == 0:
raise ValueError("Argument `ptr` cannot be zero (aka `null` or `None`).")
return v
@field_validator("size")
def validate_size(cls, v: int) -> int:
if v < 0:
raise ValueError("Argument `size` must be an integer greater than or equal to zero.")
return v
class SerializedRequest(BaseModel):
"""
Pydantic serialization type for describing the passive side of a transfer.
"""
model_config = ConfigDict(
extra="forbid",
frozen=True,
arbitrary_types_allowed=True,
)
descriptors: List[SerializedDescriptor] = []
nixl_metadata: str = ""
notification_key: str = ""
operation_kind: int = 0
def to_descriptors(self) -> Descriptor | list[Descriptor]:
"""
Deserializes the request descriptor into a `dynamo.connect.Descriptor` or list of `dynamo.connect.Descriptor` objects.
"""
if len(self.descriptors) == 0:
raise ValueError("Request descriptor must contain at least one serialized descriptor.")
if len(self.descriptors) == 1:
return self.descriptors[0].to_descriptor()
return [item.to_descriptor() for item in self.descriptors]
@field_validator("operation_kind")
def validate_operation_kind(cls, v: int) -> int:
if v < 1 or v > 3:
raise TypeError("Argument `operation_kind` must be an integer value of `dynamo.connect.OperationKind`.")
return v
class WritableOperation(PassiveOperation):
"""
Operation which can be awaited until written to by a `WriteOperation` from a remote worker.
"""
def __init__(
self,
connector: Connector,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
"""
Creates a new instance of `WritableOperation`, registers the operation and descriptors w/ NIXL,
and enables an RDMA write operation to occur.
Parameters
----------
connector : Connector
Connector instance to use for the operation.
local_descriptors : Descriptor | list[Descriptor]
Descriptors to receive data from a remote worker.
Raises
TypeError
When `local` is not a `dynamo.connect.Connector`.
TypeError
When `local_descriptors` is not a `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
super().__init__(connector, OperationKind.WRITE, local_descriptors)
logger.debug(f"Created {self.__repr__()}")
def __del__(self) -> None:
super().__del__()
logger.debug(f"Deleted {self.__repr__()}")
def __enter__(self) -> WritableOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return super().__repr__()
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await super()._wait_for_completion_()
class WriteOperation(ActiveOperation):
"""
Awaitable write operation which initiates an RDMA write operation to a remote worker
which provided a `SerializedRequest` object from a `WritableOperation`.
"""
def __init__(
self,
connector: Connector,
local_descriptors: Descriptor | list[Descriptor],
remote_request: SerializedRequest,
) -> None:
"""
Creates a new instance of `WriteOperation`, registers `local_descriptors` with NIXL,
and begins an RDMA write operation which will transfer from `local_descriptors` to
remote target(s) described by `remote_request`
Parameters
----------
connector : Connector
Connector instance to use for the operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to send from, to the remote agent.
remote_request : SerializedRequest
Serialized request from the remote worker that describes the target(s) to send to.
Raises
TypeError
When `connector` is not a `dynamo.connect.Connector`.
TypeError
When `remote_request` is not a `dynamo.connect.RequestDescriptor`.
ValueError
When `remote_request` is not of kind `WRITE`.
ValueError
When `remote_request.nixl_metadata` is not a non-empty `str`.
TypeError
When `local_descriptors` is not a `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
if not isinstance(connector, Connector):
raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
if not isinstance(remote_request, SerializedRequest):
raise TypeError("Argument `remote_request` must be `dynamo.connect.RequestDescriptor`.")
if remote_request.operation_kind != OperationKind.WRITE.value:
raise ValueError("Argument `remote_request` must be of kind `WRITE`.")
remote = Remote(connector, remote_request.nixl_metadata)
remote_descriptors = remote_request.to_descriptors()
super().__init__(remote, OperationKind.WRITE, local_descriptors, remote_descriptors, remote_request.notification_key)
logger.debug(f"Created {self.__repr__()}")
def __del__(self) -> None:
super().__del__()
logger.debug(f"Deleted {self.__repr__()}")
def __enter__(self) -> WriteOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return super().__repr__()
def cancel(self) -> None:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or has been cancelled.
"""
super()._cancel_()
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await super()._wait_for_completion_()
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -13,9 +14,9 @@ ...@@ -13,9 +14,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from components.encode_worker import EncodeWorker from components.decode_worker import VllmDecodeWorker
from components.encode_worker import VllmEncodeWorker
from components.frontend import Frontend from components.frontend import Frontend
from components.processor import Processor from components.processor import Processor
from components.worker import VllmWorker
Frontend.link(Processor).link(VllmWorker).link(EncodeWorker) Frontend.link(Processor).link(VllmDecodeWorker).link(VllmEncodeWorker)
...@@ -13,10 +13,12 @@ ...@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from components.encode_worker import EncodeWorker from components.decode_worker import VllmDecodeWorker
from components.encode_worker import VllmEncodeWorker
from components.frontend import Frontend from components.frontend import Frontend
from components.prefill_worker import PrefillWorker from components.prefill_worker import VllmPrefillWorker
from components.processor import Processor from components.processor import Processor
from components.worker import VllmWorker
Frontend.link(Processor).link(VllmWorker).link(PrefillWorker).link(EncodeWorker) Frontend.link(Processor).link(VllmDecodeWorker).link(VllmPrefillWorker).link(
VllmEncodeWorker
)
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import json import json
from typing import Any, List, Optional from typing import Any, List, Optional
import connect
import msgspec import msgspec
from pydantic import BaseModel, ConfigDict, field_validator from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core import core_schema from pydantic_core import core_schema
...@@ -111,12 +112,13 @@ class EncodeRequest(BaseModel): ...@@ -111,12 +112,13 @@ class EncodeRequest(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
image_url: str image_url: str
request_id: str
serialized_request: Optional[connect.SerializedRequest] = None
class EncodeResponse(BaseModel): class EncodeResponse(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
image_features: List[List[List[float]]]
class MyRequestOutput(BaseModel): class MyRequestOutput(BaseModel):
...@@ -129,7 +131,6 @@ class MyRequestOutput(BaseModel): ...@@ -129,7 +131,6 @@ class MyRequestOutput(BaseModel):
""" """
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str request_id: str
prompt: Optional[str] = None prompt: Optional[str] = None
prompt_token_ids: Optional[List[int]] = None prompt_token_ids: Optional[List[int]] = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment