Unverified Commit 96fe63fe authored by J Wyman's avatar J Wyman Committed by GitHub
Browse files

feat: nixl_connect: Improve Concurrency Support (#4433)


Signed-off-by: J Wyman <jwyman@nvidia.com>
parent 0ce7280d
...@@ -159,7 +159,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -159,7 +159,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
# Create descriptor for the multimodal data # Create descriptor for the multimodal data
descriptor = connect.Descriptor(precomputed_embeddings) descriptor = connect.Descriptor(precomputed_embeddings)
with self._connector.create_readable(descriptor) as readable: with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata() request.serialized_request = readable.metadata()
logger.debug(f"Request: {request.model_dump_json()}") logger.debug(f"Request: {request.model_dump_json()}")
...@@ -184,6 +184,5 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -184,6 +184,5 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
# Create and initialize a dynamo connector for this worker. # Create and initialize a dynamo connector for this worker.
# We'll need this to move data between this worker and remote workers efficiently. # We'll need this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Startup completed.") logger.info("Startup completed.")
...@@ -77,7 +77,6 @@ class EmbeddingsProcessor: ...@@ -77,7 +77,6 @@ class EmbeddingsProcessor:
async def initialize(self): async def initialize(self):
"""Initialize the connector for embeddings processing""" """Initialize the connector for embeddings processing"""
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize()
async def process_embeddings(self, request: SglangMultimodalRequest): async def process_embeddings(self, request: SglangMultimodalRequest):
"""Process embeddings from serialized request""" """Process embeddings from serialized request"""
...@@ -103,7 +102,6 @@ class EmbeddingsProcessor: ...@@ -103,7 +102,6 @@ class EmbeddingsProcessor:
"Connector is None - this should not happen after initialization" "Connector is None - this should not happen after initialization"
) )
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize()
read_op = await self._connector.begin_read( read_op = await self._connector.begin_read(
request.serialized_request, descriptor request.serialized_request, descriptor
......
...@@ -241,7 +241,7 @@ class EncodeHelper: ...@@ -241,7 +241,7 @@ class EncodeHelper:
# Create readable operation with main embeddings tensor (works for both formats) # Create readable operation with main embeddings tensor (works for both formats)
descriptor = nixl_connect.Descriptor(encodings) descriptor = nixl_connect.Descriptor(encodings)
with connector.create_readable(descriptor) as readable_op: with await connector.create_readable(descriptor) as readable_op:
# Get the metadata for the readable operation # Get the metadata for the readable operation
op_metadata = readable_op.metadata() op_metadata = readable_op.metadata()
......
...@@ -332,7 +332,6 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -332,7 +332,6 @@ async def init(runtime: DistributedRuntime, config: Config):
connector = None connector = None
logging.info("Initializing NIXL Connect.") logging.info("Initializing NIXL Connect.")
connector = nixl_connect.Connector() connector = nixl_connect.Connector()
await connector.initialize()
dump_config( dump_config(
config.dump_config_to, {"engine_args": engine_args, "dynamo_args": config} config.dump_config_to, {"engine_args": engine_args, "dynamo_args": config}
......
...@@ -69,7 +69,6 @@ class EncodeWorkerHandler: ...@@ -69,7 +69,6 @@ class EncodeWorkerHandler:
# Create and initialize a dynamo connector for this worker. # Create and initialize a dynamo connector for this worker.
# We'll need this to move data between this worker and remote workers efficiently. # We'll need this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Encode worker startup completed.") logger.info("Encode worker startup completed.")
async def generate( async def generate(
...@@ -130,7 +129,7 @@ class EncodeWorkerHandler: ...@@ -130,7 +129,7 @@ class EncodeWorkerHandler:
request.embeddings_shape = tuple(embeddings.shape) request.embeddings_shape = tuple(embeddings.shape)
descriptor = connect.Descriptor(embeddings_cpu) descriptor = connect.Descriptor(embeddings_cpu)
with self._connector.create_readable(descriptor) as readable: with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata() request.serialized_request = readable.metadata()
# Clear the image URL as hint that the image is passed as embeddings. # Clear the image URL as hint that the image is passed as embeddings.
request.multimodal_input.image_url = None request.multimodal_input.image_url = None
......
...@@ -52,7 +52,6 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -52,7 +52,6 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
async def async_init(self, runtime: DistributedRuntime): async def async_init(self, runtime: DistributedRuntime):
"""Async initialization - connector needs async setup""" """Async initialization - connector needs async setup"""
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Multimodal Decode Worker async initialization completed.") logger.info("Multimodal Decode Worker async initialization completed.")
async def generate(self, request: vLLMMultimodalRequest, context): async def generate(self, request: vLLMMultimodalRequest, context):
...@@ -138,7 +137,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -138,7 +137,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
"""Async initialization for connector that requires async setup""" """Async initialization for connector that requires async setup"""
# Initialize the connector asynchronously # Initialize the connector asynchronously
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Multimodal PD Worker async initialization completed.") logger.info("Multimodal PD Worker async initialization completed.")
async def generate(self, request: vLLMMultimodalRequest, context): async def generate(self, request: vLLMMultimodalRequest, context):
......
...@@ -47,7 +47,6 @@ The metadata contains required information (identifiers, keys, etc.) which enabl ...@@ -47,7 +47,6 @@ The metadata contains required information (identifiers, keys, etc.) which enabl
@async_on_start @async_on_start
async def async_init(self): async def async_init(self):
self.connector = dynamo.nixl_connect.Connector() self.connector = dynamo.nixl_connect.Connector()
await self.connector.initialize()
``` ```
> [!Tip] > [!Tip]
...@@ -109,7 +108,7 @@ Use [`.wait_for_completion()`](write_operation.md#wait_for_completion) to block ...@@ -109,7 +108,7 @@ Use [`.wait_for_completion()`](write_operation.md#wait_for_completion) to block
### `create_readable` ### `create_readable`
```python ```python
def create_readable( async def create_readable(
self, self,
local_descriptors: Descriptor | list[Descriptor], local_descriptors: Descriptor | list[Descriptor],
) -> ReadableOperation: ) -> ReadableOperation:
...@@ -130,7 +129,7 @@ Use [`.wait_for_completion()`](readable_operation.md#wait_for_completion) to blo ...@@ -130,7 +129,7 @@ Use [`.wait_for_completion()`](readable_operation.md#wait_for_completion) to blo
### `create_writable` ### `create_writable`
```python ```python
def create_writable( async def create_writable(
self, self,
local_descriptors: Descriptor | list[Descriptor], local_descriptors: Descriptor | list[Descriptor],
) -> WritableOperation: ) -> WritableOperation:
...@@ -151,6 +150,15 @@ Use [`.wait_for_completion()`](writable_operation.md#wait_for_completion) to blo ...@@ -151,6 +150,15 @@ Use [`.wait_for_completion()`](writable_operation.md#wait_for_completion) to blo
## Properties ## Properties
### `hostname`
```python
@property
def hostname(self) -> str:
```
Gets the name of the current worker's host.
### `is_cuda_available` ### `is_cuda_available`
```python ```python
...@@ -169,22 +177,6 @@ def name(self) -> str | None: ...@@ -169,22 +177,6 @@ def name(self) -> str | None:
Gets the Dynamo component name used by the connector. Gets the Dynamo component name used by the connector.
### `namespace`
```python
@property
def namespace(self) -> str:
```
Gets the Dynamo namespace used by the connector.
### `runtime`
```python
def runtime(self) -> dynamo.runtime.DistributedRuntime:
```
Gets the Dynamo distributed runtime instance associated with the connector.
## Related Classes ## Related Classes
......
...@@ -38,7 +38,7 @@ therefore the operation should be awaited until completed unless cancellation is ...@@ -38,7 +38,7 @@ therefore the operation should be awaited until completed unless cancellation is
) -> None: ) -> None:
descriptor = dynamo.nixl_connect.Descriptor(local_tensor) descriptor = dynamo.nixl_connect.Descriptor(local_tensor)
with self.connector.begin_read(descriptor, remote_metadata) as read_op: with await self.connector.begin_read(remote_metadata, descriptor) as read_op:
# Wait for the operation to complete writing data from the remote worker to local_tensor. # Wait for the operation to complete writing data from the remote worker to local_tensor.
await read_op.wait_for_completion() await read_op.wait_for_completion()
``` ```
......
...@@ -37,7 +37,7 @@ therefore the operation should be awaited until completed unless cancellation is ...@@ -37,7 +37,7 @@ therefore the operation should be awaited until completed unless cancellation is
) -> None: ) -> None:
descriptor = dynamo.nixl_connect.Descriptor(local_tensor) descriptor = dynamo.nixl_connect.Descriptor(local_tensor)
with self.connector.create_readable(descriptor) as read_op: with await self.connector.create_readable(descriptor) as read_op:
op_metadata = read_op.metadata() op_metadata = read_op.metadata()
# Send the metadata to the remote worker via sideband communication. # Send the metadata to the remote worker via sideband communication.
......
...@@ -38,7 +38,7 @@ Cancellation is handled asynchronously. ...@@ -38,7 +38,7 @@ Cancellation is handled asynchronously.
) -> None: ) -> None:
descriptor = dynamo.nixl_connect.Descriptor(local_tensor) descriptor = dynamo.nixl_connect.Descriptor(local_tensor)
with self.connector.create_writable(descriptor) as write_op: with await self.connector.create_writable(descriptor) as write_op:
op_metadata = write_op.metadata() op_metadata = write_op.metadata()
# Send the metadata to the remote worker via sideband communication. # Send the metadata to the remote worker via sideband communication.
......
...@@ -39,7 +39,7 @@ Cancellation is handled asynchronously. ...@@ -39,7 +39,7 @@ Cancellation is handled asynchronously.
) -> None: ) -> None:
descriptor = dynamo.nixl_connect.Descriptor(local_tensor) descriptor = dynamo.nixl_connect.Descriptor(local_tensor)
with self.connector.begin_write(descriptor, remote_metadata) as write_op: with await self.connector.begin_write(descriptor, remote_metadata) as write_op:
# Wait for the operation to complete writing local_tensor to the remote worker. # Wait for the operation to complete writing local_tensor to the remote worker.
await write_op.wait_for_completion() await write_op.wait_for_completion()
``` ```
......
...@@ -222,7 +222,7 @@ NIXL is used only for embedding transfer: ...@@ -222,7 +222,7 @@ NIXL is used only for embedding transfer:
```python ```python
Encode Worker: Encode Worker:
descriptor = connect.Descriptor(precomputed_embeddings) descriptor = connect.Descriptor(precomputed_embeddings)
with connector.create_readable(descriptor) as readable: with await connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata() request.serialized_request = readable.metadata()
# Send request with NIXL metadata # Send request with NIXL metadata
await pd_worker_client.round_robin(request) await pd_worker_client.round_robin(request)
......
...@@ -168,7 +168,7 @@ class VllmEncodeWorker: ...@@ -168,7 +168,7 @@ class VllmEncodeWorker:
with torch.no_grad(): with torch.no_grad():
audio_embeddings = self.get_audio_embeddings(audio_features) audio_embeddings = self.get_audio_embeddings(audio_features)
descriptor = connect.Descriptor(audio_embeddings) descriptor = connect.Descriptor(audio_embeddings)
with self._connector.create_readable(descriptor) as readable: with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata() request.serialized_request = readable.metadata()
# Clear the audio URL as hint that the audio is passed as embeddings. # Clear the audio URL as hint that the audio is passed as embeddings.
request.multimodal_input.audio_url = None request.multimodal_input.audio_url = None
......
...@@ -125,7 +125,7 @@ class VllmEncodeWorker: ...@@ -125,7 +125,7 @@ class VllmEncodeWorker:
request.embeddings_shape = tuple(embeddings.shape) request.embeddings_shape = tuple(embeddings.shape)
descriptor = connect.Descriptor(embeddings) descriptor = connect.Descriptor(embeddings)
with self._connector.create_readable(descriptor) as readable: with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata() request.serialized_request = readable.metadata()
# Clear the image URL as hint that the image is passed as embeddings. # Clear the image URL as hint that the image is passed as embeddings.
request.multimodal_input.image_url = None request.multimodal_input.image_url = None
...@@ -158,7 +158,6 @@ class VllmEncodeWorker: ...@@ -158,7 +158,6 @@ class VllmEncodeWorker:
# Create and initialize a dynamo connector for this worker. # Create and initialize a dynamo connector for this worker.
# We'll need this to move data between this worker and remote workers efficiently. # We'll need this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Startup completed.") logger.info("Startup completed.")
......
...@@ -153,7 +153,7 @@ class VllmEncodeWorker: ...@@ -153,7 +153,7 @@ class VllmEncodeWorker:
request.embeddings_shape = tuple(tensor_for_descriptor.shape) request.embeddings_shape = tuple(tensor_for_descriptor.shape)
descriptor = connect.Descriptor(tensor_for_descriptor) descriptor = connect.Descriptor(tensor_for_descriptor)
with self._connector.create_readable(descriptor) as readable: with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata() request.serialized_request = readable.metadata()
# Clear the video URL as a hint that the video is passed as embeddings. # Clear the video URL as a hint that the video is passed as embeddings.
request.multimodal_input.video_url = None request.multimodal_input.video_url = None
...@@ -199,7 +199,6 @@ class VllmEncodeWorker: ...@@ -199,7 +199,6 @@ class VllmEncodeWorker:
# Create and initialize a dynamo connector for this worker. # Create and initialize a dynamo connector for this worker.
# We'll need this to move data between this worker and remote workers efficiently. # We'll need this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Startup completed.") logger.info("Startup completed.")
......
...@@ -69,15 +69,15 @@ class AbstractOperation(ABC): ...@@ -69,15 +69,15 @@ class AbstractOperation(ABC):
def __init__( def __init__(
self, self,
connector: Connector, connection: Connection,
operation_kind: OperationKind, operation_kind: OperationKind,
local_descriptors: Descriptor | list[Descriptor], local_descriptors: Descriptor | list[Descriptor],
remote_descriptors: Optional[Descriptor | list[Descriptor]], remote_descriptors: Optional[Descriptor | list[Descriptor]],
notification_key: Optional[str], notification_key: Optional[str],
) -> None: ) -> None:
if not isinstance(connector, Connector): if not isinstance(connection, Connection):
raise TypeError( raise TypeError(
"Argument `connector` must be `dynamo.nixl_connect.Connector`." "Argument `connection` must be `dynamo.nixl_connect.Connection`."
) )
if ( if (
operation_kind is not OperationKind.READ operation_kind is not OperationKind.READ
...@@ -126,7 +126,7 @@ class AbstractOperation(ABC): ...@@ -126,7 +126,7 @@ class AbstractOperation(ABC):
self._notification_key: str = ( self._notification_key: str = (
"" if notification_key is None else notification_key "" if notification_key is None else notification_key
) )
self._connector: Connector = connector self._connection: Connection = connection
self._operation_kind: OperationKind = operation_kind self._operation_kind: OperationKind = operation_kind
self._local_desc_list: Descriptor | list[Descriptor] = local_descriptors self._local_desc_list: Descriptor | list[Descriptor] = local_descriptors
self._local_desc_tlist: Optional[list[tuple[int, int, int]]] = None self._local_desc_tlist: Optional[list[tuple[int, int, int]]] = None
...@@ -141,9 +141,15 @@ class AbstractOperation(ABC): ...@@ -141,9 +141,15 @@ class AbstractOperation(ABC):
# Note: Only local descriptors should be registered with NIXL, # Note: Only local descriptors should be registered with NIXL,
if isinstance(local_descriptors, list): if isinstance(local_descriptors, list):
for d in local_descriptors: for d in local_descriptors:
d.register_memory(self._connector) d.register_with_connector(self._connection)
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Registered descriptor {d} with connector {self._connection}."
)
else: else:
local_descriptors.register_memory(self._connector) local_descriptors.register_with_connector(self._connection)
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Registered descriptor {local_descriptors} with connector {self._connection}."
)
# Record local descriptors. # Record local descriptors.
device_kind, desc_tlist = self._create_desc_tlist(local_descriptors) device_kind, desc_tlist = self._create_desc_tlist(local_descriptors)
...@@ -166,14 +172,32 @@ class AbstractOperation(ABC): ...@@ -166,14 +172,32 @@ class AbstractOperation(ABC):
self._release() self._release()
def _release(self) -> None: def _release(self) -> None:
pass """
Private method to release resources.
"""
# Deregister local descriptors from NIXL, allowing them to be reused by a future operation.
if isinstance(self._local_desc_list, list):
for d in self._local_desc_list:
if d.is_registered:
d.deregister_with_connector(self._connection)
else:
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Descriptor {d} was not registered, skipping deregistration."
)
else:
if self._local_desc_list.is_registered:
self._local_desc_list.deregister_with_connector(self._connection)
else:
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Descriptor {self._local_desc_list} was not registered, skipping deregistration."
)
@property @property
def connector(self) -> Connector: def connection(self) -> Connection:
""" """
Gets the local associated with this operation. Gets the local connection associated with this operation.
""" """
return self._connector return self._connection
@property @property
def operation_kind(self) -> OperationKind: def operation_kind(self) -> OperationKind:
...@@ -230,7 +254,7 @@ class ActiveOperation(AbstractOperation): ...@@ -230,7 +254,7 @@ class ActiveOperation(AbstractOperation):
remote_descriptors: Descriptor | list[Descriptor], remote_descriptors: Descriptor | list[Descriptor],
notification_key: str, notification_key: str,
) -> None: ) -> None:
if not isinstance(remote, Remote) or remote._connector is None: if not isinstance(remote, Remote) or remote._connection is None:
raise TypeError( raise TypeError(
"Argument `remote` must be valid `dynamo.nixl_connect.Remote`." "Argument `remote` must be valid `dynamo.nixl_connect.Remote`."
) )
...@@ -303,7 +327,7 @@ class ActiveOperation(AbstractOperation): ...@@ -303,7 +327,7 @@ class ActiveOperation(AbstractOperation):
self._status = OperationStatus.UNINITIALIZED self._status = OperationStatus.UNINITIALIZED
super().__init__( super().__init__(
remote.connector, remote.connection,
operation_kind, operation_kind,
local_descriptors, local_descriptors,
remote_descriptors, remote_descriptors,
...@@ -317,21 +341,21 @@ class ActiveOperation(AbstractOperation): ...@@ -317,21 +341,21 @@ class ActiveOperation(AbstractOperation):
self._remote_xfer_descs: Optional[nixl_bindings.nixlXferDList] = None self._remote_xfer_descs: Optional[nixl_bindings.nixlXferDList] = None
self._xfer_hndl: Optional[nixl_api.nixl_xfer_handle] = None self._xfer_hndl: Optional[nixl_api.nixl_xfer_handle] = None
self._local_xfer_descs = self._connector._nixl.get_xfer_descs( self._local_xfer_descs = self._connection._nixl.get_xfer_descs(
descs=self._local_desc_tlist, descs=self._local_desc_tlist,
mem_type=str(self._local_device_kind), mem_type=str(self._local_device_kind),
) )
logger.debug( logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Created local NIXL transfer descriptors: {self._local_xfer_descs}" f"dynamo.nixl_connect.{self.__class__.__name__}: Created local NIXL transfer descriptors: {self._local_xfer_descs}"
) )
self._remote_xfer_descs = self._connector._nixl.get_xfer_descs( self._remote_xfer_descs = self._connection._nixl.get_xfer_descs(
descs=self._remote_desc_tlist, descs=self._remote_desc_tlist,
mem_type=str(self._remote_device_kind), mem_type=str(self._remote_device_kind),
) )
logger.debug( logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Created remote NIXL transfer descriptors: {self._remote_xfer_descs}" f"dynamo.nixl_connect.{self.__class__.__name__}: Created remote NIXL transfer descriptors: {self._remote_xfer_descs}"
) )
self._xfer_hndl = self._connector._nixl.initialize_xfer( self._xfer_hndl = self._connection._nixl.initialize_xfer(
operation=str(operation_kind), operation=str(operation_kind),
local_descs=self._local_xfer_descs, local_descs=self._local_xfer_descs,
remote_descs=self._remote_xfer_descs, remote_descs=self._remote_xfer_descs,
...@@ -380,7 +404,7 @@ class ActiveOperation(AbstractOperation): ...@@ -380,7 +404,7 @@ class ActiveOperation(AbstractOperation):
logger.debug( logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: NIXL transfer handle {self._xfer_hndl} released." f"dynamo.nixl_connect.{self.__class__.__name__}: NIXL transfer handle {self._xfer_hndl} released."
) )
self._connector._nixl.release_xfer_handle(self._xfer_hndl) self._connection._nixl.release_xfer_handle(self._xfer_hndl)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"dynamo.nixl_connect.{self.__class__.__name__}: Failed to release resources: {e}" f"dynamo.nixl_connect.{self.__class__.__name__}: Failed to release resources: {e}"
...@@ -413,7 +437,7 @@ class ActiveOperation(AbstractOperation): ...@@ -413,7 +437,7 @@ class ActiveOperation(AbstractOperation):
) )
# NIXL will cancel the transfer if it is in progress when the handle is released. # NIXL will cancel the transfer if it is in progress when the handle is released.
self._connector._nixl.release_xfer_handle(self._xfer_hndl) self._connection._nixl.release_xfer_handle(self._xfer_hndl)
self._status = OperationStatus.CANCELLED self._status = OperationStatus.CANCELLED
self._xfer_hndl = None self._xfer_hndl = None
...@@ -467,7 +491,7 @@ class ActiveOperation(AbstractOperation): ...@@ -467,7 +491,7 @@ class ActiveOperation(AbstractOperation):
old_status = self._status old_status = self._status
if self._status == OperationStatus.UNINITIALIZED: if self._status == OperationStatus.UNINITIALIZED:
state = self._connector._nixl.transfer( state = self._connection._nixl.transfer(
self._xfer_hndl, self._xfer_hndl,
self._notification_key.encode("utf-8"), self._notification_key.encode("utf-8"),
) )
...@@ -481,7 +505,7 @@ class ActiveOperation(AbstractOperation): ...@@ -481,7 +505,7 @@ class ActiveOperation(AbstractOperation):
else: else:
self._status = OperationStatus.INITIALIZED self._status = OperationStatus.INITIALIZED
else: else:
state = self._connector._nixl.check_xfer_state(self._xfer_hndl) state = self._connection._nixl.check_xfer_state(self._xfer_hndl)
logger.debug( logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: NIXL reported transfer state: {state}" f"dynamo.nixl_connect.{self.__class__.__name__}: NIXL reported transfer state: {state}"
) )
...@@ -500,6 +524,90 @@ class ActiveOperation(AbstractOperation): ...@@ -500,6 +524,90 @@ class ActiveOperation(AbstractOperation):
return self._status return self._status
class Connection:
    """
    A single, uniquely named connection owned by a `Connector`.

    Every connection creates its own NIXL agent, named after the owning
    connector plus the connection number.
    """

    def __init__(self, connector: Connector, number: int):
        """
        Creates a new Connection instance.

        Parameters
        ----------
        connector : Connector
            The connector this connection belongs to.
        number : int
            The connection number.
            Combined with the connector's name to form the connection's name.

        Raises
        ------
        TypeError
            When `connector` is not of type `dynamo.nixl_connect.Connector`.
        TypeError
            When `number` is not of type `int`.
        ValueError
            When `number` is not greater than 0.
        """
        # Validate in the same order as documented: connector type, number type, number range.
        if not isinstance(connector, Connector):
            raise TypeError(
                "Argument `connector` must be `dynamo.nixl_connect.Connector`."
            )
        if not isinstance(number, int):
            raise TypeError("Argument `number` must be of type `int`.")
        if number <= 0:
            raise ValueError("Argument `number` must be greater than 0.")

        connection_name = f"{connector.name}-{number}"

        self._connector: Connector = connector
        self._is_initialized = False
        self._name = connection_name
        # Each connection is backed by its own NIXL agent, keyed by the connection name.
        self._nixl = nixl_api.nixl_agent(connection_name)

        logger.debug(
            f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self!r}."
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        details = f"is_initialized={self._is_initialized}, name='{self._name}'"
        return f"{class_name}({details})"

    def __str__(self) -> str:
        return self._name

    @property
    def connector(self) -> Connector:
        """
        Get the connector that owns this connection.
        """
        return self._connector

    @property
    def metadata(self) -> bytes:
        """
        Get the agent metadata reported by this connection's NIXL agent.
        """
        return self._nixl.get_agent_metadata()

    @property
    def name(self) -> str | None:
        """
        Get the name of the connection.
        """
        return self._name

    async def initialize(self) -> None:
        """
        Marks the connection as initialized; repeated calls do nothing.
        """
        if not self._is_initialized:
            self._is_initialized = True

            # No real work happens here yet; this is a placeholder for future initialization.
            logger.debug(
                f"dynamo.nixl_connect.{self.__class__.__name__}: Initialized {{ name: '{self._name}' }} completed."
            )
class Connector: class Connector:
""" """
Core class for managing the connection between workers in a distributed environment. Core class for managing the connection between workers in a distributed environment.
...@@ -529,28 +637,42 @@ class Connector: ...@@ -529,28 +637,42 @@ class Connector:
if not isinstance(worker_id, str) or len(worker_id) == 0: if not isinstance(worker_id, str) or len(worker_id) == 0:
raise TypeError("Argument `worker_id` must be a non-empty `str` or `None`.") raise TypeError("Argument `worker_id` must be a non-empty `str` or `None`.")
self._connection_count: int = 0
self._worker_id = worker_id self._worker_id = worker_id
self._is_initialized = False
self._nixl = nixl_api.nixl_agent(self._worker_id)
self._hostname = socket.gethostname() self._hostname = socket.gethostname()
self._agent_metadata: Optional[bytes] = None
logger.debug( logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}." f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}."
) )
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Connector):
return False
return self._worker_id == other._worker_id
def __ne__(self, value: object) -> bool:
if not isinstance(value, Connector):
return True
return self._worker_id != value._worker_id
def __repr__(self) -> str: def __repr__(self) -> str:
return str( return str(
f"{self.__class__.__name__}(" f"{self.__class__.__name__}("
f"worker_id='{self._worker_id}', " f"worker_id='{self._worker_id}', "
f"hostname={self._hostname}, " f"hostname={self._hostname}"
f"metadata=<{0 if self._agent_metadata is None else len(self._agent_metadata)} bytes>"
")" ")"
) )
def __str__(self) -> str: def __str__(self) -> str:
return self._worker_id return self._worker_id
@property
def hostname(self) -> str:
"""
Get the name of the current worker's host.
"""
return self._hostname
@cached_property @cached_property
def is_cuda_available(self) -> bool: def is_cuda_available(self) -> bool:
# Note: `cuda.is_available` initializes CUDA # Note: `cuda.is_available` initializes CUDA
...@@ -562,13 +684,6 @@ class Connector: ...@@ -562,13 +684,6 @@ class Connector:
except CUDARuntimeError: except CUDARuntimeError:
return False return False
@property
def metadata(self) -> bytes:
"""
Get the metadata of the worker.
"""
return self._nixl.get_agent_metadata()
@property @property
def name(self) -> str | None: def name(self) -> str | None:
""" """
...@@ -620,12 +735,8 @@ class Connector: ...@@ -620,12 +735,8 @@ class Connector:
"Cannot create a `dynamo.nixl_connect.ReadOperation` to read from a remote `dynamo.nixl_connect.WritableOperation`." "Cannot create a `dynamo.nixl_connect.ReadOperation` to read from a remote `dynamo.nixl_connect.WritableOperation`."
) )
if not self._is_initialized: conn = await self._create_connection()
raise RuntimeError( op = ReadOperation(conn, remote_metadata, local_descriptors)
"Connector not initialized. Call `initialize()` before calling this method."
)
op = ReadOperation(self, remote_metadata, local_descriptors)
return op return op
async def begin_write( async def begin_write(
...@@ -655,22 +766,18 @@ class Connector: ...@@ -655,22 +766,18 @@ class Connector:
raise TypeError( raise TypeError(
"Argument `local_descriptors` must be `Descriptor` or `list[Descriptor]`." "Argument `local_descriptors` must be `Descriptor` or `list[Descriptor]`."
) )
if remote_metadata.operation_kind != OperationKind.WRITE: if remote_metadata.operation_kind != OperationKind.WRITE.value:
raise RuntimeError( raise RuntimeError(
"Cannot create a `WriteOperation` to write to a remote `ReadableOperation`." "Cannot create a `WriteOperation` to write to a remote `ReadableOperation`."
) )
if not isinstance(remote_metadata.nixl_metadata, str): if not isinstance(remote_metadata.nixl_metadata, str):
raise TypeError("Argument `remote_metadata.nixl_metadata` must be `str`.") raise TypeError("Argument `remote_metadata.nixl_metadata` must be `str`.")
if not self._is_initialized: conn = await self._create_connection()
raise RuntimeError( op = WriteOperation(conn, local_descriptors, remote_metadata)
"Connector not initialized. Call `initialize()` before calling this method."
)
op = WriteOperation(self, local_descriptors, remote_metadata)
return op return op
def create_readable( async def create_readable(
self, self,
local_descriptors: Descriptor | list[Descriptor], local_descriptors: Descriptor | list[Descriptor],
) -> ReadableOperation: ) -> ReadableOperation:
...@@ -682,15 +789,11 @@ class Connector: ...@@ -682,15 +789,11 @@ class Connector:
ReadableOperation ReadableOperation
A readable operation that can be used to transfer data from a remote worker. A readable operation that can be used to transfer data from a remote worker.
""" """
if not self._is_initialized: conn = await self._create_connection()
raise RuntimeError( op = ReadableOperation(conn, local_descriptors)
"Connector not initialized. Call `initialize()` before calling this method."
)
op = ReadableOperation(self, local_descriptors)
return op return op
def create_writable( async def create_writable(
self, self,
local_descriptors: Descriptor | list[Descriptor], local_descriptors: Descriptor | list[Descriptor],
) -> WritableOperation: ) -> WritableOperation:
...@@ -702,25 +805,27 @@ class Connector: ...@@ -702,25 +805,27 @@ class Connector:
WritableOperation WritableOperation
A writable operation that can be used to transfer data to a remote worker. A writable operation that can be used to transfer data to a remote worker.
""" """
if not self._is_initialized: conn = await self._create_connection()
raise RuntimeError( op = WritableOperation(conn, local_descriptors)
"Connector not initialized. Call `initialize()` before calling this method."
)
op = WritableOperation(self, local_descriptors)
return op return op
async def initialize(self) -> None: async def initialize(self) -> None:
# Only initialize the connector once. """
if self._is_initialized: Deprecated method.
return """
self._is_initialized = True
# This method is a no-op for now, in the future it may be used to initialize the connector.
logger.debug( logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Initialized {{ name: '{self._worker_id}' }} completed." f"dynamo.nixl_connect.{self.__class__.__name__}: Initialized {{ name: '{self._worker_id}' }} (This method is deprecated)."
) )
async def _create_connection(self) -> Connection:
"""
Private method to create a new connection.
"""
self._connection_count += 1
conn = Connection(self, self._connection_count)
await conn.initialize()
return conn
class Descriptor: class Descriptor:
""" """
...@@ -784,7 +889,8 @@ class Descriptor: ...@@ -784,7 +889,8 @@ class Descriptor:
# Member fields for managing NIXL memory registration. # Member fields for managing NIXL memory registration.
# Note: ONLY local descriptors should be registered with NIXL, # Note: ONLY local descriptors should be registered with NIXL,
# remote descriptors do not have a valid memory address and registration will fault. # remote descriptors do not have a valid memory address and registration will fault.
self._connector: Optional[Connector] = None
self._connection: Optional[Connection] = None
self._nixl_hndl: Optional[nixl_bindings.nixlRegDList] = None self._nixl_hndl: Optional[nixl_bindings.nixlRegDList] = None
# Initially `None` cached serialized descriptor reference, populated when `get_metadata()` is called. # Initially `None` cached serialized descriptor reference, populated when `get_metadata()` is called.
...@@ -865,10 +971,11 @@ class Descriptor: ...@@ -865,10 +971,11 @@ class Descriptor:
raise TypeError(TYPE_ERROR_MESSAGE) raise TypeError(TYPE_ERROR_MESSAGE)
def __del__(self) -> None: def __del__(self) -> None:
if self._nixl_hndl is not None and self._connector is not None: if not (self._nixl_hndl is None or self._connection is None):
# Unregister the memory with NIXL. # Deregister the memory with NIXL.
self._connector._nixl.deregister_memory(self._nixl_hndl) self._connection._nixl.deregister_memory(self._nixl_hndl)
self._nixl_hndl = None self._nixl_hndl = None
self._connection = None
if self._data_ref is not None: if self._data_ref is not None:
# Release the reference to the data. # Release the reference to the data.
...@@ -891,6 +998,13 @@ class Descriptor: ...@@ -891,6 +998,13 @@ class Descriptor:
""" """
return self._data_device return self._data_device
@property
def is_registered(self) -> bool:
"""
Gets whether the descriptor is registered with NIXL.
"""
return self._connection is not None and self._nixl_hndl is not None
@property @property
def ptr(self) -> int: def ptr(self) -> int:
""" """
...@@ -927,6 +1041,7 @@ class Descriptor: ...@@ -927,6 +1041,7 @@ class Descriptor:
return serialized.to_descriptor() return serialized.to_descriptor()
@property
def metadata(self) -> SerializedDescriptor: def metadata(self) -> SerializedDescriptor:
""" """
Serializes the descriptor into a `SerializedDescriptor` object. Serializes the descriptor into a `SerializedDescriptor` object.
...@@ -936,37 +1051,75 @@ class Descriptor: ...@@ -936,37 +1051,75 @@ class Descriptor:
device=f"{self._data_device}", device=f"{self._data_device}",
ptr=self._data_ptr, ptr=self._data_ptr,
size=self._data_size, size=self._data_size,
) ) # type: ignore[operator]
return self._serialized return self._serialized
def register_memory( def deregister_with_connector(self, connection: Connection) -> None:
"""
Deregisters the memory of the descriptor with NIXL.
"""
if not isinstance(connection, Connection):
raise TypeError(
"Argument `connection` must be `dynamo.nixl_connect.Connection`."
)
if connection != self._connection:
raise RuntimeError(
"Descriptor can only be deregistered from the connection it was registered with. "
f"Existing connection: {self._connection.name if self._connection is not None else None}, requested connection: {connection.name}."
)
return
if self._nixl_hndl is None:
logger.warning(
f"dynamo.nixl_connect.{self.__class__.__name__}: Request to deregister Descriptor {self.__repr__()} cannot be completed because the Descriptor is not registered."
)
return
connection._nixl.deregister_memory(self._nixl_hndl)
self._nixl_hndl = None
self._connection = None
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Deregistered {self.__repr__()} with NIXL."
)
def register_with_connector(
self, self,
connector: Connector, connection: Connection,
) -> None: ) -> None:
""" """
Registers the memory of the descriptor with NIXL. Registers the memory of the descriptor with NIXL.
""" """
if not isinstance(connector, Connector): if not isinstance(connection, Connection):
raise TypeError( raise TypeError(
"Argument `connector` must be `dynamo.nixl_connect.Connector`." "Argument `connection` must be `dynamo.nixl_connect.Connection`."
) )
if self._data_ptr == 0: if self._data_ptr == 0:
raise ValueError("Cannot register memory with a null pointer.") raise ValueError("Cannot register memory with a null pointer.")
if self._connection is not None:
if self._connection != connection:
raise RuntimeError(
"Descriptor cannot be registered with more than one connection. "
f"Existing connection: {self._connection.name}, new connection: {connection.name}."
)
# Descriptor is already registered with this connection.
return
if not (self._nixl_hndl is None and self._connector is None): # When the descriptor is already registered with NIXL, just return.
if self._nixl_hndl is not None:
return return
# Register the memory with NIXL. # Register the memory with NIXL.
self._connector = connector self._connection = connection
if isinstance(self._data_ref, torch.Tensor): if isinstance(self._data_ref, torch.Tensor):
self._nixl_hndl = connector._nixl.register_memory(self._data_ref) self._nixl_hndl = connection._nixl.register_memory(self._data_ref)
else: else:
mem_type = str(self._data_device.kind) mem_type = str(self._data_device.kind)
reg_list = [ reg_list = [
(self._data_ptr, self._data_size, self._data_device.id, mem_type) (self._data_ptr, self._data_size, self._data_device.id, mem_type)
] ]
self._nixl_hndl = connector._nixl.register_memory(reg_list, mem_type) self._nixl_hndl = connection._nixl.register_memory(reg_list, mem_type)
logger.debug( logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Registered {self.__repr__()} with NIXL." f"dynamo.nixl_connect.{self.__class__.__name__}: Registered {self.__repr__()} with NIXL."
...@@ -1173,7 +1326,7 @@ class PassiveOperation(AbstractOperation): ...@@ -1173,7 +1326,7 @@ class PassiveOperation(AbstractOperation):
def __init__( def __init__(
self, self,
connector: Connector, connection: Connection,
operation_kind: OperationKind, operation_kind: OperationKind,
local_descriptors: Descriptor | list[Descriptor], local_descriptors: Descriptor | list[Descriptor],
) -> None: ) -> None:
...@@ -1188,7 +1341,7 @@ class PassiveOperation(AbstractOperation): ...@@ -1188,7 +1341,7 @@ class PassiveOperation(AbstractOperation):
self._status = OperationStatus.UNINITIALIZED self._status = OperationStatus.UNINITIALIZED
super().__init__( super().__init__(
connector, connection,
operation_kind, operation_kind,
local_descriptors, local_descriptors,
None, None,
...@@ -1240,12 +1393,12 @@ class PassiveOperation(AbstractOperation): ...@@ -1240,12 +1393,12 @@ class PassiveOperation(AbstractOperation):
# When we've not yet cached the serialized request, we need to generate one before returning it. # When we've not yet cached the serialized request, we need to generate one before returning it.
# Handle both cases: multiple and single descriptors. # Handle both cases: multiple and single descriptors.
if isinstance(self._local_desc_list, list): if isinstance(self._local_desc_list, list):
descriptors = [desc.metadata() for desc in self._local_desc_list] descriptors = [desc.metadata for desc in self._local_desc_list]
else: else:
descriptors = [self._local_desc_list.metadata()] descriptors = [self._local_desc_list.metadata]
original_len = len(self._connector.metadata) original_len = len(self._connection.metadata)
nixl_metadata = self._connector.metadata nixl_metadata = self._connection.metadata
nixl_metadata = zlib.compress(nixl_metadata, level=6) nixl_metadata = zlib.compress(nixl_metadata, level=6)
compressed_len = len(nixl_metadata) compressed_len = len(nixl_metadata)
logger.debug( logger.debug(
...@@ -1283,7 +1436,7 @@ class PassiveOperation(AbstractOperation): ...@@ -1283,7 +1436,7 @@ class PassiveOperation(AbstractOperation):
old_status = self._status old_status = self._status
# Query NIXL for any notifications. # Query NIXL for any notifications.
notifications = self._connector._nixl.update_notifs() notifications = self._connection._nixl.update_notifs()
if isinstance(notifications, dict): if isinstance(notifications, dict):
remote_state = OperationStatus.IN_PROGRESS remote_state = OperationStatus.IN_PROGRESS
...@@ -1309,7 +1462,7 @@ class PassiveOperation(AbstractOperation): ...@@ -1309,7 +1462,7 @@ class PassiveOperation(AbstractOperation):
if remote_state == OperationStatus.COMPLETE: if remote_state == OperationStatus.COMPLETE:
self._status = remote_state self._status = remote_state
logger.debug( logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: {{ remote: '{self._connector.name}' status: '{old_status}' => '{self._status}' }}." f"dynamo.nixl_connect.{self.__class__.__name__}: {{ remote: '{self._connection.name}' status: '{old_status}' => '{self._status}' }}."
) )
return self._status return self._status
...@@ -1330,7 +1483,7 @@ class ReadOperation(ActiveOperation): ...@@ -1330,7 +1483,7 @@ class ReadOperation(ActiveOperation):
def __init__( def __init__(
self, self,
connector: Connector, connection: Connection,
remote_metadata: RdmaMetadata, remote_metadata: RdmaMetadata,
local_descriptors: Descriptor | list[Descriptor], local_descriptors: Descriptor | list[Descriptor],
) -> None: ) -> None:
...@@ -1341,16 +1494,16 @@ class ReadOperation(ActiveOperation): ...@@ -1341,16 +1494,16 @@ class ReadOperation(ActiveOperation):
Parameters Parameters
---------- ----------
connector : Connector connection : Connection
Connector instance to use for the operation. Connection instance to use for the operation.
remote_metadata : RdmaMetadata remote_metadata : RdmaMetadata
Serialized request from the remote worker. Serialized request from the remote worker.
local_descriptors : Descriptor | list[Descriptor] local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to to receive the data from the remote worker. Local descriptor(s) to to receive the data from the remote worker.
""" """
if not isinstance(connector, Connector): if not isinstance(connection, Connection):
raise TypeError( raise TypeError(
"Argument `connector` must be `dynamo.nixl_connect.Connector`." "Argument `connection` must be `dynamo.nixl_connect.Connection`."
) )
if not isinstance(remote_metadata, RdmaMetadata): if not isinstance(remote_metadata, RdmaMetadata):
raise TypeError( raise TypeError(
...@@ -1359,7 +1512,7 @@ class ReadOperation(ActiveOperation): ...@@ -1359,7 +1512,7 @@ class ReadOperation(ActiveOperation):
if remote_metadata.operation_kind != OperationKind.READ.value: if remote_metadata.operation_kind != OperationKind.READ.value:
raise ValueError("Argument `remote_metadata` must be of kind `READ`.") raise ValueError("Argument `remote_metadata` must be of kind `READ`.")
remote = Remote(connector, remote_metadata.nixl_metadata) remote = Remote(connection, remote_metadata.nixl_metadata)
remote_descriptors = remote_metadata.to_descriptors() remote_descriptors = remote_metadata.to_descriptors()
if not ( if not (
...@@ -1435,10 +1588,10 @@ class ReadableOperation(PassiveOperation): ...@@ -1435,10 +1588,10 @@ class ReadableOperation(PassiveOperation):
def __init__( def __init__(
self, self,
connector: Connector, connection: Connection,
local_descriptors: Descriptor | list[Descriptor], local_descriptors: Descriptor | list[Descriptor],
) -> None: ) -> None:
super().__init__(connector, OperationKind.READ, local_descriptors) super().__init__(connection, OperationKind.READ, local_descriptors)
logger.debug( logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}" f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}"
) )
...@@ -1510,17 +1663,19 @@ class Remote: ...@@ -1510,17 +1663,19 @@ class Remote:
def __init__( def __init__(
self, self,
connector: Connector, connection: Connection,
nixl_metadata: bytes | str, nixl_metadata: bytes | str,
) -> None: ) -> None:
if not isinstance(connector, Connector): if not isinstance(connection, Connection):
raise TypeError("Argument `local` must be `dynamo.nixl_connect.Connector`.") raise TypeError(
"Argument `connection` must be `dynamo.nixl_connect.Connection`."
)
if not (isinstance(nixl_metadata, bytes) or isinstance(nixl_metadata, str)): if not (isinstance(nixl_metadata, bytes) or isinstance(nixl_metadata, str)):
raise TypeError("Argument `nixl_metadata` must be `bytes` or `str`.") raise TypeError("Argument `nixl_metadata` must be `bytes` or `str`.")
if len(nixl_metadata) == 0: if len(nixl_metadata) == 0:
raise ValueError("Argument `nixl_metadata` cannot be empty.") raise ValueError("Argument `nixl_metadata` cannot be empty.")
self._connector = connector self._connection = connection
# When `nixl_metadata` is a string, it is assumed to have come from a remote worker # When `nixl_metadata` is a string, it is assumed to have come from a remote worker
# via a `RdmaMetadata` object and therefore can assumed be a b64-encoded, compressed # via a `RdmaMetadata` object and therefore can assumed be a b64-encoded, compressed
...@@ -1535,7 +1690,7 @@ class Remote: ...@@ -1535,7 +1690,7 @@ class Remote:
# Decompress the NIXL metadata. # Decompress the NIXL metadata.
nixl_metadata = zlib.decompress(nixl_metadata) nixl_metadata = zlib.decompress(nixl_metadata)
self._name = connector._nixl.add_remote_agent(nixl_metadata) self._name = connection._nixl.add_remote_agent(nixl_metadata)
if isinstance(self._name, bytes): if isinstance(self._name, bytes):
self._name = self._name.decode("utf-8") self._name = self._name.decode("utf-8")
...@@ -1559,7 +1714,7 @@ class Remote: ...@@ -1559,7 +1714,7 @@ class Remote:
self._release() self._release()
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Remote(name={self._name}, connector={self._connector.name})" return f"Remote(name={self._name}, connection={self._connection.name})"
def __str__(self) -> str: def __str__(self) -> str:
return self._name return self._name
...@@ -1568,19 +1723,19 @@ class Remote: ...@@ -1568,19 +1723,19 @@ class Remote:
""" """
Private method for releasing NIXL resources. Not intended for public use. Private method for releasing NIXL resources. Not intended for public use.
""" """
# We have to unregister the remote agent from NIXL because we cannot know if the remote worker has updated its descriptors or not, and # We have to deregister the remote agent from NIXL because we cannot know if the remote worker has updated its descriptors or not, and
# NIXL will return an error if we attempt to register a remote agent with the same name but different descriptors (aka conn_info). # NIXL will return an error if we attempt to register a remote agent with the same name but different descriptors (aka conn_info).
self._connector._nixl.remove_remote_agent(self._name) self._connection._nixl.remove_remote_agent(self._name)
logger.debug( logger.debug(
f'dynamo.nixl_connect.{self.__class__.__name__}: Unregistered NIXL remote {{ name: "{self._name}" }}.' f'dynamo.nixl_connect.{self.__class__.__name__}: Deregistered NIXL remote {{ name: "{self._name}" }}.'
) )
@property @property
def connector(self) -> Connector: def connection(self) -> Connection:
""" """
Gets the local connector associated with this remote worker. Gets the local connection associated with this remote worker.
""" """
return self._connector return self._connection
@property @property
def name(self) -> str: def name(self) -> str:
...@@ -1647,7 +1802,7 @@ class WritableOperation(PassiveOperation): ...@@ -1647,7 +1802,7 @@ class WritableOperation(PassiveOperation):
def __init__( def __init__(
self, self,
connector: Connector, connection: Connection,
local_descriptors: Descriptor | list[Descriptor], local_descriptors: Descriptor | list[Descriptor],
) -> None: ) -> None:
""" """
...@@ -1656,18 +1811,18 @@ class WritableOperation(PassiveOperation): ...@@ -1656,18 +1811,18 @@ class WritableOperation(PassiveOperation):
Parameters Parameters
---------- ----------
connector : Connector connection : Connection
Connector instance to use for the operation. Connection instance to use for the operation.
local_descriptors : Descriptor | list[Descriptor] local_descriptors : Descriptor | list[Descriptor]
Descriptors to receive data from a remote worker. Descriptors to receive data from a remote worker.
Raises Raises
TypeError TypeError
When `local` is not a `dynamo.nixl_connect.Connector`. When `connection` is not a `dynamo.nixl_connect.Connection`.
TypeError TypeError
When `local_descriptors` is not a `dynamo.nixl_connect.Descriptor` or `list[dynamo.nixl_connect.Descriptor]`. When `local_descriptors` is not a `dynamo.nixl_connect.Descriptor` or `list[dynamo.nixl_connect.Descriptor]`.
""" """
super().__init__(connector, OperationKind.WRITE, local_descriptors) super().__init__(connection, OperationKind.WRITE, local_descriptors)
logger.debug( logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}" f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}"
) )
...@@ -1703,7 +1858,7 @@ class WriteOperation(ActiveOperation): ...@@ -1703,7 +1858,7 @@ class WriteOperation(ActiveOperation):
def __init__( def __init__(
self, self,
connector: Connector, connection: Connection,
local_descriptors: Descriptor | list[Descriptor], local_descriptors: Descriptor | list[Descriptor],
remote_metadata: RdmaMetadata, remote_metadata: RdmaMetadata,
) -> None: ) -> None:
...@@ -1714,8 +1869,8 @@ class WriteOperation(ActiveOperation): ...@@ -1714,8 +1869,8 @@ class WriteOperation(ActiveOperation):
Parameters Parameters
---------- ----------
connector : Connector connection : Connection
Connector instance to use for the operation. Connection instance to use for the operation.
local_descriptors : Descriptor | list[Descriptor] local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to send from, to the remote worker. Local descriptor(s) to send from, to the remote worker.
remote_metadata : RdmaMetadata remote_metadata : RdmaMetadata
...@@ -1733,9 +1888,9 @@ class WriteOperation(ActiveOperation): ...@@ -1733,9 +1888,9 @@ class WriteOperation(ActiveOperation):
TypeError TypeError
When `local_descriptors` is not a `dynamo.nixl_connect.Descriptor` or `list[dynamo.nixl_connect.Descriptor]`. When `local_descriptors` is not a `dynamo.nixl_connect.Descriptor` or `list[dynamo.nixl_connect.Descriptor]`.
""" """
if not isinstance(connector, Connector): if not isinstance(connection, Connection):
raise TypeError( raise TypeError(
"Argument `connector` must be `dynamo.nixl_connect.Connector`." "Argument `connection` must be `dynamo.nixl_connect.Connection`."
) )
if not isinstance(remote_metadata, RdmaMetadata): if not isinstance(remote_metadata, RdmaMetadata):
raise TypeError( raise TypeError(
...@@ -1744,7 +1899,7 @@ class WriteOperation(ActiveOperation): ...@@ -1744,7 +1899,7 @@ class WriteOperation(ActiveOperation):
if remote_metadata.operation_kind != OperationKind.WRITE.value: if remote_metadata.operation_kind != OperationKind.WRITE.value:
raise ValueError("Argument `remote_metadata` must be of kind `WRITE`.") raise ValueError("Argument `remote_metadata` must be of kind `WRITE`.")
remote = Remote(connector, remote_metadata.nixl_metadata) remote = Remote(connection, remote_metadata.nixl_metadata)
remote_descriptors = remote_metadata.to_descriptors() remote_descriptors = remote_metadata.to_descriptors()
super().__init__( super().__init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment