Unverified Commit 96fe63fe authored by J Wyman's avatar J Wyman Committed by GitHub
Browse files

feat: nixl_connect: Improve Concurrency Support (#4433)


Signed-off-by: default avatarJ Wyman <jwyman@nvidia.com>
parent 0ce7280d
......@@ -159,7 +159,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
# Create descriptor for the multimodal data
descriptor = connect.Descriptor(precomputed_embeddings)
with self._connector.create_readable(descriptor) as readable:
with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
logger.debug(f"Request: {request.model_dump_json()}")
......@@ -184,6 +184,5 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Startup completed.")
......@@ -77,7 +77,6 @@ class EmbeddingsProcessor:
async def initialize(self):
"""Initialize the connector for embeddings processing"""
self._connector = connect.Connector()
await self._connector.initialize()
async def process_embeddings(self, request: SglangMultimodalRequest):
"""Process embeddings from serialized request"""
......@@ -103,7 +102,6 @@ class EmbeddingsProcessor:
"Connector is None - this should not happen after initialization"
)
self._connector = connect.Connector()
await self._connector.initialize()
read_op = await self._connector.begin_read(
request.serialized_request, descriptor
......
......@@ -241,7 +241,7 @@ class EncodeHelper:
# Create readable operation with main embeddings tensor (works for both formats)
descriptor = nixl_connect.Descriptor(encodings)
with connector.create_readable(descriptor) as readable_op:
with await connector.create_readable(descriptor) as readable_op:
# Get the metadata for the readable operation
op_metadata = readable_op.metadata()
......
......@@ -332,7 +332,6 @@ async def init(runtime: DistributedRuntime, config: Config):
connector = None
logging.info("Initializing NIXL Connect.")
connector = nixl_connect.Connector()
await connector.initialize()
dump_config(
config.dump_config_to, {"engine_args": engine_args, "dynamo_args": config}
......
......@@ -69,7 +69,6 @@ class EncodeWorkerHandler:
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Encode worker startup completed.")
async def generate(
......@@ -130,7 +129,7 @@ class EncodeWorkerHandler:
request.embeddings_shape = tuple(embeddings.shape)
descriptor = connect.Descriptor(embeddings_cpu)
with self._connector.create_readable(descriptor) as readable:
with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
# Clear the image URL as hint that the image is passed as embeddings.
request.multimodal_input.image_url = None
......
......@@ -52,7 +52,6 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
async def async_init(self, runtime: DistributedRuntime):
"""Async initialization - connector needs async setup"""
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Multimodal Decode Worker async initialization completed.")
async def generate(self, request: vLLMMultimodalRequest, context):
......@@ -138,7 +137,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
"""Async initialization for connector that requires async setup"""
# Initialize the connector asynchronously
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Multimodal PD Worker async initialization completed.")
async def generate(self, request: vLLMMultimodalRequest, context):
......
......@@ -47,7 +47,6 @@ The metadata contains required information (identifiers, keys, etc.) which enabl
@async_on_start
async def async_init(self):
self.connector = dynamo.nixl_connect.Connector()
await self.connector.initialize()
```
> [!Tip]
......@@ -109,7 +108,7 @@ Use [`.wait_for_completion()`](write_operation.md#wait_for_completion) to block
### `create_readable`
```python
def create_readable(
async def create_readable(
self,
local_descriptors: Descriptor | list[Descriptor],
) -> ReadableOperation:
......@@ -130,7 +129,7 @@ Use [`.wait_for_completion()`](readable_operation.md#wait_for_completion) to blo
### `create_writable`
```python
def create_writable(
async def create_writable(
self,
local_descriptors: Descriptor | list[Descriptor],
) -> WritableOperation:
......@@ -151,6 +150,15 @@ Use [`.wait_for_completion()`](writable_operation.md#wait_for_completion) to blo
## Properties
### `hostname`
```python
@property
def hostname(self) -> str:
```
Gets the name of the current worker's host.
### `is_cuda_available`
```python
......@@ -169,22 +177,6 @@ def name(self) -> str | None:
Gets the Dynamo component name used by the connector.
### `namespace`
```python
@property
def namespace(self) -> str:
```
Gets the Dynamo namespace used by the connector.
### `runtime`
```python
def runtime(self) -> dynamo.runtime.DistributedRuntime:
```
Gets the Dynamo distributed runtime instance associated with the connector.
## Related Classes
......
......@@ -38,7 +38,7 @@ therefore the operation should be awaited until completed unless cancellation is
) -> None:
descriptor = dynamo.nixl_connect.Descriptor(local_tensor)
with self.connector.begin_read(descriptor, remote_metadata) as read_op:
with await self.connector.begin_read(remote_metadata, descriptor) as read_op:
# Wait for the operation to complete writing data from the remote worker to local_tensor.
await read_op.wait_for_completion()
```
......
......@@ -37,7 +37,7 @@ therefore the operation should be awaited until completed unless cancellation is
) -> None:
descriptor = dynamo.nixl_connect.Descriptor(local_tensor)
with self.connector.create_readable(descriptor) as read_op:
with await self.connector.create_readable(descriptor) as read_op:
op_metadata = read_op.metadata()
# Send the metadata to the remote worker via sideband communication.
......
......@@ -38,7 +38,7 @@ Cancellation is handled asynchronously.
) -> None:
descriptor = dynamo.nixl_connect.Descriptor(local_tensor)
with self.connector.create_writable(descriptor) as write_op:
with await self.connector.create_writable(descriptor) as write_op:
op_metadata = write_op.metadata()
# Send the metadata to the remote worker via sideband communication.
......
......@@ -39,7 +39,7 @@ Cancellation is handled asynchronously.
) -> None:
descriptor = dynamo.nixl_connect.Descriptor(local_tensor)
with self.connector.begin_write(descriptor, remote_metadata) as write_op:
with await self.connector.begin_write(descriptor, remote_metadata) as write_op:
# Wait for the operation to complete writing local_tensor to the remote worker.
await write_op.wait_for_completion()
```
......
......@@ -222,7 +222,7 @@ NIXL is used only for embedding transfer:
```python
Encode Worker:
descriptor = connect.Descriptor(precomputed_embeddings)
with connector.create_readable(descriptor) as readable:
with await connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
# Send request with NIXL metadata
await pd_worker_client.round_robin(request)
......
......@@ -168,7 +168,7 @@ class VllmEncodeWorker:
with torch.no_grad():
audio_embeddings = self.get_audio_embeddings(audio_features)
descriptor = connect.Descriptor(audio_embeddings)
with self._connector.create_readable(descriptor) as readable:
with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
# Clear the audio URL as hint that the audio is passed as embeddings.
request.multimodal_input.audio_url = None
......
......@@ -125,7 +125,7 @@ class VllmEncodeWorker:
request.embeddings_shape = tuple(embeddings.shape)
descriptor = connect.Descriptor(embeddings)
with self._connector.create_readable(descriptor) as readable:
with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
# Clear the image URL as hint that the image is passed as embeddings.
request.multimodal_input.image_url = None
......@@ -158,7 +158,6 @@ class VllmEncodeWorker:
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Startup completed.")
......
......@@ -153,7 +153,7 @@ class VllmEncodeWorker:
request.embeddings_shape = tuple(tensor_for_descriptor.shape)
descriptor = connect.Descriptor(tensor_for_descriptor)
with self._connector.create_readable(descriptor) as readable:
with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
# Clear the image URL as hint that the image is passed as embeddings.
request.multimodal_input.video_url = None
......@@ -199,7 +199,6 @@ class VllmEncodeWorker:
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Startup completed.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment