Unverified commit 96fe63fe, authored by J Wyman, committed by GitHub
Browse files

feat: nixl_connect: Improve Concurrency Support (#4433)


Signed-off-by: J Wyman <jwyman@nvidia.com>
parent 0ce7280d
...@@ -159,7 +159,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -159,7 +159,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
# Create descriptor for the multimodal data # Create descriptor for the multimodal data
descriptor = connect.Descriptor(precomputed_embeddings) descriptor = connect.Descriptor(precomputed_embeddings)
with self._connector.create_readable(descriptor) as readable: with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata() request.serialized_request = readable.metadata()
logger.debug(f"Request: {request.model_dump_json()}") logger.debug(f"Request: {request.model_dump_json()}")
...@@ -184,6 +184,5 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -184,6 +184,5 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
# Create and initialize a dynamo connector for this worker. # Create and initialize a dynamo connector for this worker.
# We'll need this to move data between this worker and remote workers efficiently. # We'll need this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Startup completed.") logger.info("Startup completed.")
...@@ -77,7 +77,6 @@ class EmbeddingsProcessor: ...@@ -77,7 +77,6 @@ class EmbeddingsProcessor:
async def initialize(self): async def initialize(self):
"""Initialize the connector for embeddings processing""" """Initialize the connector for embeddings processing"""
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize()
async def process_embeddings(self, request: SglangMultimodalRequest): async def process_embeddings(self, request: SglangMultimodalRequest):
"""Process embeddings from serialized request""" """Process embeddings from serialized request"""
...@@ -103,7 +102,6 @@ class EmbeddingsProcessor: ...@@ -103,7 +102,6 @@ class EmbeddingsProcessor:
"Connector is None - this should not happen after initialization" "Connector is None - this should not happen after initialization"
) )
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize()
read_op = await self._connector.begin_read( read_op = await self._connector.begin_read(
request.serialized_request, descriptor request.serialized_request, descriptor
......
...@@ -241,7 +241,7 @@ class EncodeHelper: ...@@ -241,7 +241,7 @@ class EncodeHelper:
# Create readable operation with main embeddings tensor (works for both formats) # Create readable operation with main embeddings tensor (works for both formats)
descriptor = nixl_connect.Descriptor(encodings) descriptor = nixl_connect.Descriptor(encodings)
with connector.create_readable(descriptor) as readable_op: with await connector.create_readable(descriptor) as readable_op:
# Get the metadata for the readable operation # Get the metadata for the readable operation
op_metadata = readable_op.metadata() op_metadata = readable_op.metadata()
......
...@@ -332,7 +332,6 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -332,7 +332,6 @@ async def init(runtime: DistributedRuntime, config: Config):
connector = None connector = None
logging.info("Initializing NIXL Connect.") logging.info("Initializing NIXL Connect.")
connector = nixl_connect.Connector() connector = nixl_connect.Connector()
await connector.initialize()
dump_config( dump_config(
config.dump_config_to, {"engine_args": engine_args, "dynamo_args": config} config.dump_config_to, {"engine_args": engine_args, "dynamo_args": config}
......
...@@ -69,7 +69,6 @@ class EncodeWorkerHandler: ...@@ -69,7 +69,6 @@ class EncodeWorkerHandler:
# Create and initialize a dynamo connector for this worker. # Create and initialize a dynamo connector for this worker.
# We'll need this to move data between this worker and remote workers efficiently. # We'll need this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Encode worker startup completed.") logger.info("Encode worker startup completed.")
async def generate( async def generate(
...@@ -130,7 +129,7 @@ class EncodeWorkerHandler: ...@@ -130,7 +129,7 @@ class EncodeWorkerHandler:
request.embeddings_shape = tuple(embeddings.shape) request.embeddings_shape = tuple(embeddings.shape)
descriptor = connect.Descriptor(embeddings_cpu) descriptor = connect.Descriptor(embeddings_cpu)
with self._connector.create_readable(descriptor) as readable: with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata() request.serialized_request = readable.metadata()
# Clear the image URL as hint that the image is passed as embeddings. # Clear the image URL as hint that the image is passed as embeddings.
request.multimodal_input.image_url = None request.multimodal_input.image_url = None
......
...@@ -52,7 +52,6 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -52,7 +52,6 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
async def async_init(self, runtime: DistributedRuntime): async def async_init(self, runtime: DistributedRuntime):
"""Async initialization - connector needs async setup""" """Async initialization - connector needs async setup"""
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Multimodal Decode Worker async initialization completed.") logger.info("Multimodal Decode Worker async initialization completed.")
async def generate(self, request: vLLMMultimodalRequest, context): async def generate(self, request: vLLMMultimodalRequest, context):
...@@ -138,7 +137,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -138,7 +137,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
"""Async initialization for connector that requires async setup""" """Async initialization for connector that requires async setup"""
# Initialize the connector asynchronously # Initialize the connector asynchronously
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Multimodal PD Worker async initialization completed.") logger.info("Multimodal PD Worker async initialization completed.")
async def generate(self, request: vLLMMultimodalRequest, context): async def generate(self, request: vLLMMultimodalRequest, context):
......
...@@ -47,7 +47,6 @@ The metadata contains required information (identifiers, keys, etc.) which enabl ...@@ -47,7 +47,6 @@ The metadata contains required information (identifiers, keys, etc.) which enabl
@async_on_start @async_on_start
async def async_init(self): async def async_init(self):
self.connector = dynamo.nixl_connect.Connector() self.connector = dynamo.nixl_connect.Connector()
await self.connector.initialize()
``` ```
> [!Tip] > [!Tip]
...@@ -109,7 +108,7 @@ Use [`.wait_for_completion()`](write_operation.md#wait_for_completion) to block ...@@ -109,7 +108,7 @@ Use [`.wait_for_completion()`](write_operation.md#wait_for_completion) to block
### `create_readable` ### `create_readable`
```python ```python
def create_readable( async def create_readable(
self, self,
local_descriptors: Descriptor | list[Descriptor], local_descriptors: Descriptor | list[Descriptor],
) -> ReadableOperation: ) -> ReadableOperation:
...@@ -130,7 +129,7 @@ Use [`.wait_for_completion()`](readable_operation.md#wait_for_completion) to blo ...@@ -130,7 +129,7 @@ Use [`.wait_for_completion()`](readable_operation.md#wait_for_completion) to blo
### `create_writable` ### `create_writable`
```python ```python
def create_writable( async def create_writable(
self, self,
local_descriptors: Descriptor | list[Descriptor], local_descriptors: Descriptor | list[Descriptor],
) -> WritableOperation: ) -> WritableOperation:
...@@ -151,6 +150,15 @@ Use [`.wait_for_completion()`](writable_operation.md#wait_for_completion) to blo ...@@ -151,6 +150,15 @@ Use [`.wait_for_completion()`](writable_operation.md#wait_for_completion) to blo
## Properties ## Properties
### `hostname`
```python
@property
def hostname(self) -> str:
```
Gets the name of the current worker's host.
### `is_cuda_available` ### `is_cuda_available`
```python ```python
...@@ -169,22 +177,6 @@ def name(self) -> str | None: ...@@ -169,22 +177,6 @@ def name(self) -> str | None:
Gets the Dynamo component name used by the connector. Gets the Dynamo component name used by the connector.
### `namespace`
```python
@property
def namespace(self) -> str:
```
Gets the Dynamo namespace used by the connector.
### `runtime`
```python
def runtime(self) -> dynamo.runtime.DistributedRuntime:
```
Gets the Dynamo distributed runtime instance associated with the connector.
## Related Classes ## Related Classes
......
...@@ -38,7 +38,7 @@ therefore the operation should be awaited until completed unless cancellation is ...@@ -38,7 +38,7 @@ therefore the operation should be awaited until completed unless cancellation is
) -> None: ) -> None:
descriptor = dynamo.nixl_connect.Descriptor(local_tensor) descriptor = dynamo.nixl_connect.Descriptor(local_tensor)
with self.connector.begin_read(descriptor, remote_metadata) as read_op: with await self.connector.begin_read(remote_metadata, descriptor) as read_op:
# Wait for the operation to complete writing data from the remote worker to local_tensor. # Wait for the operation to complete writing data from the remote worker to local_tensor.
await read_op.wait_for_completion() await read_op.wait_for_completion()
``` ```
......
...@@ -37,7 +37,7 @@ therefore the operation should be awaited until completed unless cancellation is ...@@ -37,7 +37,7 @@ therefore the operation should be awaited until completed unless cancellation is
) -> None: ) -> None:
descriptor = dynamo.nixl_connect.Descriptor(local_tensor) descriptor = dynamo.nixl_connect.Descriptor(local_tensor)
with self.connector.create_readable(descriptor) as read_op: with await self.connector.create_readable(descriptor) as read_op:
op_metadata = read_op.metadata() op_metadata = read_op.metadata()
# Send the metadata to the remote worker via sideband communication. # Send the metadata to the remote worker via sideband communication.
......
...@@ -38,7 +38,7 @@ Cancellation is handled asynchronously. ...@@ -38,7 +38,7 @@ Cancellation is handled asynchronously.
) -> None: ) -> None:
descriptor = dynamo.nixl_connect.Descriptor(local_tensor) descriptor = dynamo.nixl_connect.Descriptor(local_tensor)
with self.connector.create_writable(descriptor) as write_op: with await self.connector.create_writable(descriptor) as write_op:
op_metadata = write_op.metadata() op_metadata = write_op.metadata()
# Send the metadata to the remote worker via sideband communication. # Send the metadata to the remote worker via sideband communication.
......
...@@ -39,7 +39,7 @@ Cancellation is handled asynchronously. ...@@ -39,7 +39,7 @@ Cancellation is handled asynchronously.
) -> None: ) -> None:
descriptor = dynamo.nixl_connect.Descriptor(local_tensor) descriptor = dynamo.nixl_connect.Descriptor(local_tensor)
with self.connector.begin_write(descriptor, remote_metadata) as write_op: with await self.connector.begin_write(descriptor, remote_metadata) as write_op:
# Wait for the operation to complete writing local_tensor to the remote worker. # Wait for the operation to complete writing local_tensor to the remote worker.
await write_op.wait_for_completion() await write_op.wait_for_completion()
``` ```
......
...@@ -222,7 +222,7 @@ NIXL is used only for embedding transfer: ...@@ -222,7 +222,7 @@ NIXL is used only for embedding transfer:
```python ```python
Encode Worker: Encode Worker:
descriptor = connect.Descriptor(precomputed_embeddings) descriptor = connect.Descriptor(precomputed_embeddings)
with connector.create_readable(descriptor) as readable: with await connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata() request.serialized_request = readable.metadata()
# Send request with NIXL metadata # Send request with NIXL metadata
await pd_worker_client.round_robin(request) await pd_worker_client.round_robin(request)
......
...@@ -168,7 +168,7 @@ class VllmEncodeWorker: ...@@ -168,7 +168,7 @@ class VllmEncodeWorker:
with torch.no_grad(): with torch.no_grad():
audio_embeddings = self.get_audio_embeddings(audio_features) audio_embeddings = self.get_audio_embeddings(audio_features)
descriptor = connect.Descriptor(audio_embeddings) descriptor = connect.Descriptor(audio_embeddings)
with self._connector.create_readable(descriptor) as readable: with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata() request.serialized_request = readable.metadata()
# Clear the audio URL as hint that the audio is passed as embeddings. # Clear the audio URL as hint that the audio is passed as embeddings.
request.multimodal_input.audio_url = None request.multimodal_input.audio_url = None
......
...@@ -125,7 +125,7 @@ class VllmEncodeWorker: ...@@ -125,7 +125,7 @@ class VllmEncodeWorker:
request.embeddings_shape = tuple(embeddings.shape) request.embeddings_shape = tuple(embeddings.shape)
descriptor = connect.Descriptor(embeddings) descriptor = connect.Descriptor(embeddings)
with self._connector.create_readable(descriptor) as readable: with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata() request.serialized_request = readable.metadata()
# Clear the image URL as hint that the image is passed as embeddings. # Clear the image URL as hint that the image is passed as embeddings.
request.multimodal_input.image_url = None request.multimodal_input.image_url = None
...@@ -158,7 +158,6 @@ class VllmEncodeWorker: ...@@ -158,7 +158,6 @@ class VllmEncodeWorker:
# Create and initialize a dynamo connector for this worker. # Create and initialize a dynamo connector for this worker.
# We'll need this to move data between this worker and remote workers efficiently. # We'll need this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Startup completed.") logger.info("Startup completed.")
......
...@@ -153,7 +153,7 @@ class VllmEncodeWorker: ...@@ -153,7 +153,7 @@ class VllmEncodeWorker:
request.embeddings_shape = tuple(tensor_for_descriptor.shape) request.embeddings_shape = tuple(tensor_for_descriptor.shape)
descriptor = connect.Descriptor(tensor_for_descriptor) descriptor = connect.Descriptor(tensor_for_descriptor)
with self._connector.create_readable(descriptor) as readable: with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata() request.serialized_request = readable.metadata()
# Clear the video URL as hint that the video is passed as embeddings. # Clear the video URL as hint that the video is passed as embeddings.
request.multimodal_input.video_url = None request.multimodal_input.video_url = None
...@@ -199,7 +199,6 @@ class VllmEncodeWorker: ...@@ -199,7 +199,6 @@ class VllmEncodeWorker:
# Create and initialize a dynamo connector for this worker. # Create and initialize a dynamo connector for this worker.
# We'll need this to move data between this worker and remote workers efficiently. # We'll need this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Startup completed.") logger.info("Startup completed.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment