Unverified Commit 96fe63fe authored by J Wyman's avatar J Wyman Committed by GitHub
Browse files

feat: nixl_connect: Improve Concurrency Support (#4433)


Signed-off-by: default avatarJ Wyman <jwyman@nvidia.com>
parent 0ce7280d
......@@ -159,7 +159,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
# Create descriptor for the multimodal data
descriptor = connect.Descriptor(precomputed_embeddings)
with self._connector.create_readable(descriptor) as readable:
with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
logger.debug(f"Request: {request.model_dump_json()}")
......@@ -184,6 +184,5 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Startup completed.")
......@@ -77,7 +77,6 @@ class EmbeddingsProcessor:
async def initialize(self):
"""Initialize the connector for embeddings processing"""
self._connector = connect.Connector()
await self._connector.initialize()
async def process_embeddings(self, request: SglangMultimodalRequest):
"""Process embeddings from serialized request"""
......@@ -103,7 +102,6 @@ class EmbeddingsProcessor:
"Connector is None - this should not happen after initialization"
)
self._connector = connect.Connector()
await self._connector.initialize()
read_op = await self._connector.begin_read(
request.serialized_request, descriptor
......
......@@ -241,7 +241,7 @@ class EncodeHelper:
# Create readable operation with main embeddings tensor (works for both formats)
descriptor = nixl_connect.Descriptor(encodings)
with connector.create_readable(descriptor) as readable_op:
with await connector.create_readable(descriptor) as readable_op:
# Get the metadata for the readable operation
op_metadata = readable_op.metadata()
......
......@@ -332,7 +332,6 @@ async def init(runtime: DistributedRuntime, config: Config):
connector = None
logging.info("Initializing NIXL Connect.")
connector = nixl_connect.Connector()
await connector.initialize()
dump_config(
config.dump_config_to, {"engine_args": engine_args, "dynamo_args": config}
......
......@@ -69,7 +69,6 @@ class EncodeWorkerHandler:
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Encode worker startup completed.")
async def generate(
......@@ -130,7 +129,7 @@ class EncodeWorkerHandler:
request.embeddings_shape = tuple(embeddings.shape)
descriptor = connect.Descriptor(embeddings_cpu)
with self._connector.create_readable(descriptor) as readable:
with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
# Clear the image URL as hint that the image is passed as embeddings.
request.multimodal_input.image_url = None
......
......@@ -52,7 +52,6 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
async def async_init(self, runtime: DistributedRuntime):
"""Async initialization - connector needs async setup"""
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Multimodal Decode Worker async initialization completed.")
async def generate(self, request: vLLMMultimodalRequest, context):
......@@ -138,7 +137,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
"""Async initialization for connector that requires async setup"""
# Initialize the connector asynchronously
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Multimodal PD Worker async initialization completed.")
async def generate(self, request: vLLMMultimodalRequest, context):
......
......@@ -47,7 +47,6 @@ The metadata contains required information (identifiers, keys, etc.) which enabl
@async_on_start
async def async_init(self):
self.connector = dynamo.nixl_connect.Connector()
await self.connector.initialize()
```
> [!Tip]
......@@ -109,7 +108,7 @@ Use [`.wait_for_completion()`](write_operation.md#wait_for_completion) to block
### `create_readable`
```python
def create_readable(
async def create_readable(
self,
local_descriptors: Descriptor | list[Descriptor],
) -> ReadableOperation:
......@@ -130,7 +129,7 @@ Use [`.wait_for_completion()`](readable_operation.md#wait_for_completion) to blo
### `create_writable`
```python
def create_writable(
async def create_writable(
self,
local_descriptors: Descriptor | list[Descriptor],
) -> WritableOperation:
......@@ -151,6 +150,15 @@ Use [`.wait_for_completion()`](writable_operation.md#wait_for_completion) to blo
## Properties
### `hostname`
```python
@property
def hostname(self) -> str:
```
Gets the name of the current worker's host.
### `is_cuda_available`
```python
......@@ -169,22 +177,6 @@ def name(self) -> str | None:
Gets the Dynamo component name used by the connector.
### `namespace`
```python
@property
def namespace(self) -> str:
```
Gets the Dynamo namespace used by the connector.
### `runtime`
```python
def runtime(self) -> dynamo.runtime.DistributedRuntime:
```
Gets the Dynamo distributed runtime instance associated with the connector.
## Related Classes
......
......@@ -38,7 +38,7 @@ therefore the operation should be awaited until completed unless cancellation is
) -> None:
descriptor = dynamo.nixl_connect.Descriptor(local_tensor)
with self.connector.begin_read(descriptor, remote_metadata) as read_op:
with await self.connector.begin_read(remote_metadata, descriptor) as read_op:
# Wait for the operation to complete writing data from the remote worker to local_tensor.
await read_op.wait_for_completion()
```
......
......@@ -37,7 +37,7 @@ therefore the operation should be awaited until completed unless cancellation is
) -> None:
descriptor = dynamo.nixl_connect.Descriptor(local_tensor)
with self.connector.create_readable(descriptor) as read_op:
with await self.connector.create_readable(descriptor) as read_op:
op_metadata = read_op.metadata()
# Send the metadata to the remote worker via sideband communication.
......
......@@ -38,7 +38,7 @@ Cancellation is handled asynchronously.
) -> None:
descriptor = dynamo.nixl_connect.Descriptor(local_tensor)
with self.connector.create_writable(descriptor) as write_op:
with await self.connector.create_writable(descriptor) as write_op:
op_metadata = write_op.metadata()
# Send the metadata to the remote worker via sideband communication.
......
......@@ -39,7 +39,7 @@ Cancellation is handled asynchronously.
) -> None:
descriptor = dynamo.nixl_connect.Descriptor(local_tensor)
with self.connector.begin_write(descriptor, remote_metadata) as write_op:
with await self.connector.begin_write(descriptor, remote_metadata) as write_op:
# Wait for the operation to complete writing local_tensor to the remote worker.
await write_op.wait_for_completion()
```
......
......@@ -222,7 +222,7 @@ NIXL is used only for embedding transfer:
```python
Encode Worker:
descriptor = connect.Descriptor(precomputed_embeddings)
with connector.create_readable(descriptor) as readable:
with await connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
# Send request with NIXL metadata
await pd_worker_client.round_robin(request)
......
......@@ -168,7 +168,7 @@ class VllmEncodeWorker:
with torch.no_grad():
audio_embeddings = self.get_audio_embeddings(audio_features)
descriptor = connect.Descriptor(audio_embeddings)
with self._connector.create_readable(descriptor) as readable:
with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
# Clear the audio URL as hint that the audio is passed as embeddings.
request.multimodal_input.audio_url = None
......
......@@ -125,7 +125,7 @@ class VllmEncodeWorker:
request.embeddings_shape = tuple(embeddings.shape)
descriptor = connect.Descriptor(embeddings)
with self._connector.create_readable(descriptor) as readable:
with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
# Clear the image URL as hint that the image is passed as embeddings.
request.multimodal_input.image_url = None
......@@ -158,7 +158,6 @@ class VllmEncodeWorker:
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Startup completed.")
......
......@@ -153,7 +153,7 @@ class VllmEncodeWorker:
request.embeddings_shape = tuple(tensor_for_descriptor.shape)
descriptor = connect.Descriptor(tensor_for_descriptor)
with self._connector.create_readable(descriptor) as readable:
with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
# Clear the image URL as hint that the image is passed as embeddings.
request.multimodal_input.video_url = None
......@@ -199,7 +199,6 @@ class VllmEncodeWorker:
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Startup completed.")
......
......@@ -69,15 +69,15 @@ class AbstractOperation(ABC):
def __init__(
self,
connector: Connector,
connection: Connection,
operation_kind: OperationKind,
local_descriptors: Descriptor | list[Descriptor],
remote_descriptors: Optional[Descriptor | list[Descriptor]],
notification_key: Optional[str],
) -> None:
if not isinstance(connector, Connector):
if not isinstance(connection, Connection):
raise TypeError(
"Argument `connector` must be `dynamo.nixl_connect.Connector`."
"Argument `connection` must be `dynamo.nixl_connect.Connection`."
)
if (
operation_kind is not OperationKind.READ
......@@ -126,7 +126,7 @@ class AbstractOperation(ABC):
self._notification_key: str = (
"" if notification_key is None else notification_key
)
self._connector: Connector = connector
self._connection: Connection = connection
self._operation_kind: OperationKind = operation_kind
self._local_desc_list: Descriptor | list[Descriptor] = local_descriptors
self._local_desc_tlist: Optional[list[tuple[int, int, int]]] = None
......@@ -141,9 +141,15 @@ class AbstractOperation(ABC):
# Note: Only local descriptors should be registered with NIXL,
if isinstance(local_descriptors, list):
for d in local_descriptors:
d.register_memory(self._connector)
d.register_with_connector(self._connection)
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Registered descriptor {d} with connector {self._connection}."
)
else:
local_descriptors.register_memory(self._connector)
local_descriptors.register_with_connector(self._connection)
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Registered descriptor {local_descriptors} with connector {self._connection}."
)
# Record local descriptors.
device_kind, desc_tlist = self._create_desc_tlist(local_descriptors)
......@@ -166,14 +172,32 @@ class AbstractOperation(ABC):
self._release()
def _release(self) -> None:
pass
"""
Private method to release resources.
"""
# Deregister local descriptors from NIXL, allowing them to reused by a future operation.
if isinstance(self._local_desc_list, list):
for d in self._local_desc_list:
if d.is_registered:
d.deregister_with_connector(self._connection)
else:
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Descriptor {d} was not registered, skipping deregistration."
)
else:
if self._local_desc_list.is_registered:
self._local_desc_list.deregister_with_connector(self._connection)
else:
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Descriptor {self._local_desc_list} was not registered, skipping deregistration."
)
@property
def connector(self) -> Connector:
def connection(self) -> Connection:
"""
Gets the local associated with this operation.
Gets the local connection associated with this operation.
"""
return self._connector
return self._connection
@property
def operation_kind(self) -> OperationKind:
......@@ -230,7 +254,7 @@ class ActiveOperation(AbstractOperation):
remote_descriptors: Descriptor | list[Descriptor],
notification_key: str,
) -> None:
if not isinstance(remote, Remote) or remote._connector is None:
if not isinstance(remote, Remote) or remote._connection is None:
raise TypeError(
"Argument `remote` must be valid `dynamo.nixl_connect.Remote`."
)
......@@ -303,7 +327,7 @@ class ActiveOperation(AbstractOperation):
self._status = OperationStatus.UNINITIALIZED
super().__init__(
remote.connector,
remote.connection,
operation_kind,
local_descriptors,
remote_descriptors,
......@@ -317,21 +341,21 @@ class ActiveOperation(AbstractOperation):
self._remote_xfer_descs: Optional[nixl_bindings.nixlXferDList] = None
self._xfer_hndl: Optional[nixl_api.nixl_xfer_handle] = None
self._local_xfer_descs = self._connector._nixl.get_xfer_descs(
self._local_xfer_descs = self._connection._nixl.get_xfer_descs(
descs=self._local_desc_tlist,
mem_type=str(self._local_device_kind),
)
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Created local NIXL transfer descriptors: {self._local_xfer_descs}"
)
self._remote_xfer_descs = self._connector._nixl.get_xfer_descs(
self._remote_xfer_descs = self._connection._nixl.get_xfer_descs(
descs=self._remote_desc_tlist,
mem_type=str(self._remote_device_kind),
)
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Created remote NIXL transfer descriptors: {self._remote_xfer_descs}"
)
self._xfer_hndl = self._connector._nixl.initialize_xfer(
self._xfer_hndl = self._connection._nixl.initialize_xfer(
operation=str(operation_kind),
local_descs=self._local_xfer_descs,
remote_descs=self._remote_xfer_descs,
......@@ -380,7 +404,7 @@ class ActiveOperation(AbstractOperation):
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: NIXL transfer handle {self._xfer_hndl} released."
)
self._connector._nixl.release_xfer_handle(self._xfer_hndl)
self._connection._nixl.release_xfer_handle(self._xfer_hndl)
except Exception as e:
logger.error(
f"dynamo.nixl_connect.{self.__class__.__name__}: Failed to release resources: {e}"
......@@ -413,7 +437,7 @@ class ActiveOperation(AbstractOperation):
)
# NIXL will cancel the transfer if it is in progress when the handle is released.
self._connector._nixl.release_xfer_handle(self._xfer_hndl)
self._connection._nixl.release_xfer_handle(self._xfer_hndl)
self._status = OperationStatus.CANCELLED
self._xfer_hndl = None
......@@ -467,7 +491,7 @@ class ActiveOperation(AbstractOperation):
old_status = self._status
if self._status == OperationStatus.UNINITIALIZED:
state = self._connector._nixl.transfer(
state = self._connection._nixl.transfer(
self._xfer_hndl,
self._notification_key.encode("utf-8"),
)
......@@ -481,7 +505,7 @@ class ActiveOperation(AbstractOperation):
else:
self._status = OperationStatus.INITIALIZED
else:
state = self._connector._nixl.check_xfer_state(self._xfer_hndl)
state = self._connection._nixl.check_xfer_state(self._xfer_hndl)
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: NIXL reported transfer state: {state}"
)
......@@ -500,6 +524,90 @@ class ActiveOperation(AbstractOperation):
return self._status
class Connection:
def __init__(self, connector: Connector, number: int):
"""
Creates a new Connection instance.
Parameters
----------
connector : Connector
The connector associated with this connection.
number : int
The connection number.
Used to create a unique name for the connection.
Raises
------
TypeError
When `connector` is provided and not of type `dynamo.nixl_connect.Connector`.
TypeError
When `number` is provided and not of type `int`.
ValueError
When `number` is provided and not greater than 0.
"""
if not isinstance(connector, Connector):
raise TypeError(
"Argument `connector` must be `dynamo.nixl_connect.Connector`."
)
if not isinstance(number, int):
raise TypeError("Argument `number` must be of type `int`.")
if number <= 0:
raise ValueError("Argument `number` must be greater than 0.")
self._connector: Connector = connector
self._is_initialized = False
self._name = f"{connector.name}-{number}"
self._nixl = nixl_api.nixl_agent(self._name)
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}."
)
def __repr__(self) -> str:
return str(
f"{self.__class__.__name__}("
f"is_initialized={self._is_initialized}, "
f"name='{self._name}'"
")"
)
def __str__(self) -> str:
return self._name
@property
def connector(self) -> Connector:
"""
Get the connector associated with this connection.
"""
return self._connector
@property
def metadata(self) -> bytes:
"""
Get the metadata of the connection.
"""
return self._nixl.get_agent_metadata()
@property
def name(self) -> str | None:
"""
Get the name of the connection.
"""
return self._name
async def initialize(self) -> None:
# Only initialize the connection once.
if self._is_initialized:
return
self._is_initialized = True
# This method is a no-op for now, in the future it may be used to initialize the connection.
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Initialized {{ name: '{self._name}' }} completed."
)
class Connector:
"""
Core class for managing the connection between workers in a distributed environment.
......@@ -529,28 +637,42 @@ class Connector:
if not isinstance(worker_id, str) or len(worker_id) == 0:
raise TypeError("Argument `worker_id` must be a non-empty `str` or `None`.")
self._connection_count: int = 0
self._worker_id = worker_id
self._is_initialized = False
self._nixl = nixl_api.nixl_agent(self._worker_id)
self._hostname = socket.gethostname()
self._agent_metadata: Optional[bytes] = None
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}."
)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Connector):
return False
return self._worker_id == other._worker_id
def __ne__(self, value: object) -> bool:
if not isinstance(value, Connector):
return True
return self._worker_id != value._worker_id
def __repr__(self) -> str:
return str(
f"{self.__class__.__name__}("
f"worker_id='{self._worker_id}', "
f"hostname={self._hostname}, "
f"metadata=<{0 if self._agent_metadata is None else len(self._agent_metadata)} bytes>"
f"hostname={self._hostname}"
")"
)
def __str__(self) -> str:
return self._worker_id
@property
def hostname(self) -> str:
"""
Get the name of the current worker's host.
"""
return self._hostname
@cached_property
def is_cuda_available(self) -> bool:
# Note: `cuda.is_available` initializes CUDA
......@@ -562,13 +684,6 @@ class Connector:
except CUDARuntimeError:
return False
@property
def metadata(self) -> bytes:
"""
Get the metadata of the worker.
"""
return self._nixl.get_agent_metadata()
@property
def name(self) -> str | None:
"""
......@@ -620,12 +735,8 @@ class Connector:
"Cannot create a `dynamo.nixl_connect.ReadOperation` to read from a remote `dynamo.nixl_connect.WritableOperation`."
)
if not self._is_initialized:
raise RuntimeError(
"Connector not initialized. Call `initialize()` before calling this method."
)
op = ReadOperation(self, remote_metadata, local_descriptors)
conn = await self._create_connection()
op = ReadOperation(conn, remote_metadata, local_descriptors)
return op
async def begin_write(
......@@ -655,22 +766,18 @@ class Connector:
raise TypeError(
"Argument `local_descriptors` must be `Descriptor` or `list[Descriptor]`."
)
if remote_metadata.operation_kind != OperationKind.WRITE:
if remote_metadata.operation_kind != OperationKind.WRITE.value:
raise RuntimeError(
"Cannot create a `WriteOperation` to write to a remote `ReadableOperation`."
)
if not isinstance(remote_metadata.nixl_metadata, str):
raise TypeError("Argument `remote_metadata.nixl_metadata` must be `str`.")
if not self._is_initialized:
raise RuntimeError(
"Connector not initialized. Call `initialize()` before calling this method."
)
op = WriteOperation(self, local_descriptors, remote_metadata)
conn = await self._create_connection()
op = WriteOperation(conn, local_descriptors, remote_metadata)
return op
def create_readable(
async def create_readable(
self,
local_descriptors: Descriptor | list[Descriptor],
) -> ReadableOperation:
......@@ -682,15 +789,11 @@ class Connector:
ReadableOperation
A readable operation that can be used to transfer data from a remote worker.
"""
if not self._is_initialized:
raise RuntimeError(
"Connector not initialized. Call `initialize()` before calling this method."
)
op = ReadableOperation(self, local_descriptors)
conn = await self._create_connection()
op = ReadableOperation(conn, local_descriptors)
return op
def create_writable(
async def create_writable(
self,
local_descriptors: Descriptor | list[Descriptor],
) -> WritableOperation:
......@@ -702,25 +805,27 @@ class Connector:
WritableOperation
A writable operation that can be used to transfer data to a remote worker.
"""
if not self._is_initialized:
raise RuntimeError(
"Connector not initialized. Call `initialize()` before calling this method."
)
op = WritableOperation(self, local_descriptors)
conn = await self._create_connection()
op = WritableOperation(conn, local_descriptors)
return op
async def initialize(self) -> None:
# Only initialize the connector once.
if self._is_initialized:
return
self._is_initialized = True
# This method is a no-op for now, in the future it may be used to initialize the connector.
"""
Deprecated method.
"""
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Initialized {{ name: '{self._worker_id}' }} completed."
f"dynamo.nixl_connect.{self.__class__.__name__}: Initialized {{ name: '{self._worker_id}' }} (This method is deprecated)."
)
async def _create_connection(self) -> Connection:
"""
Private method to create a new connection.
"""
self._connection_count += 1
conn = Connection(self, self._connection_count)
await conn.initialize()
return conn
class Descriptor:
"""
......@@ -784,7 +889,8 @@ class Descriptor:
# Member fields for managing NIXL memory registration.
# Note: ONLY local descriptors should be registered with NIXL,
# remote descriptors do not have a valid memory address and registration will fault.
self._connector: Optional[Connector] = None
self._connection: Optional[Connection] = None
self._nixl_hndl: Optional[nixl_bindings.nixlRegDList] = None
# Initially `None` cached serialized descriptor reference, populated when `get_metadata()` is called.
......@@ -865,10 +971,11 @@ class Descriptor:
raise TypeError(TYPE_ERROR_MESSAGE)
def __del__(self) -> None:
if self._nixl_hndl is not None and self._connector is not None:
# Unregister the memory with NIXL.
self._connector._nixl.deregister_memory(self._nixl_hndl)
if not (self._nixl_hndl is None or self._connection is None):
# Deregister the memory with NIXL.
self._connection._nixl.deregister_memory(self._nixl_hndl)
self._nixl_hndl = None
self._connection = None
if self._data_ref is not None:
# Release the reference to the data.
......@@ -891,6 +998,13 @@ class Descriptor:
"""
return self._data_device
@property
def is_registered(self) -> bool:
"""
Gets whether the descriptor is registered with NIXL.
"""
return self._connection is not None and self._nixl_hndl is not None
@property
def ptr(self) -> int:
"""
......@@ -927,6 +1041,7 @@ class Descriptor:
return serialized.to_descriptor()
@property
def metadata(self) -> SerializedDescriptor:
"""
Serializes the descriptor into a `SerializedDescriptor` object.
......@@ -936,37 +1051,75 @@ class Descriptor:
device=f"{self._data_device}",
ptr=self._data_ptr,
size=self._data_size,
)
) # type: ignore[operator]
return self._serialized
def register_memory(
def deregister_with_connector(self, connection: Connection) -> None:
"""
Deregisters the memory of the descriptor with NIXL.
"""
if not isinstance(connection, Connection):
raise TypeError(
"Argument `connection` must be `dynamo.nixl_connect.Connection`."
)
if connection != self._connection:
raise RuntimeError(
"Descriptor can only be deregistered from the connection it was registered with. "
f"Existing connection: {self._connection.name if self._connection is not None else None}, requested connection: {connection.name}."
)
return
if self._nixl_hndl is None:
logger.warning(
f"dynamo.nixl_connect.{self.__class__.__name__}: Request to deregister Descriptor {self.__repr__()} cannot be completed because the Descriptor is not registered."
)
return
connection._nixl.deregister_memory(self._nixl_hndl)
self._nixl_hndl = None
self._connection = None
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Deregistered {self.__repr__()} with NIXL."
)
def register_with_connector(
self,
connector: Connector,
connection: Connection,
) -> None:
"""
Registers the memory of the descriptor with NIXL.
"""
if not isinstance(connector, Connector):
if not isinstance(connection, Connection):
raise TypeError(
"Argument `connector` must be `dynamo.nixl_connect.Connector`."
"Argument `connection` must be `dynamo.nixl_connect.Connection`."
)
if self._data_ptr == 0:
raise ValueError("Cannot register memory with a null pointer.")
if self._connection is not None:
if self._connection != connection:
raise RuntimeError(
"Descriptor cannot be registered with more than one connection. "
f"Existing connection: {self._connection.name}, new connection: {connection.name}."
)
# Descriptor is already registered with this connection.
return
if not (self._nixl_hndl is None and self._connector is None):
# When the descriptor is already registered with NIXL, just return.
if self._nixl_hndl is not None:
return
# Register the memory with NIXL.
self._connector = connector
self._connection = connection
if isinstance(self._data_ref, torch.Tensor):
self._nixl_hndl = connector._nixl.register_memory(self._data_ref)
self._nixl_hndl = connection._nixl.register_memory(self._data_ref)
else:
mem_type = str(self._data_device.kind)
reg_list = [
(self._data_ptr, self._data_size, self._data_device.id, mem_type)
]
self._nixl_hndl = connector._nixl.register_memory(reg_list, mem_type)
self._nixl_hndl = connection._nixl.register_memory(reg_list, mem_type)
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Registered {self.__repr__()} with NIXL."
......@@ -1173,7 +1326,7 @@ class PassiveOperation(AbstractOperation):
def __init__(
self,
connector: Connector,
connection: Connection,
operation_kind: OperationKind,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
......@@ -1188,7 +1341,7 @@ class PassiveOperation(AbstractOperation):
self._status = OperationStatus.UNINITIALIZED
super().__init__(
connector,
connection,
operation_kind,
local_descriptors,
None,
......@@ -1240,12 +1393,12 @@ class PassiveOperation(AbstractOperation):
# When we've not yet cached the serialized request, we need to generate one before returning it.
# Handle both cases: multiple and single descriptors.
if isinstance(self._local_desc_list, list):
descriptors = [desc.metadata() for desc in self._local_desc_list]
descriptors = [desc.metadata for desc in self._local_desc_list]
else:
descriptors = [self._local_desc_list.metadata()]
descriptors = [self._local_desc_list.metadata]
original_len = len(self._connector.metadata)
nixl_metadata = self._connector.metadata
original_len = len(self._connection.metadata)
nixl_metadata = self._connection.metadata
nixl_metadata = zlib.compress(nixl_metadata, level=6)
compressed_len = len(nixl_metadata)
logger.debug(
......@@ -1283,7 +1436,7 @@ class PassiveOperation(AbstractOperation):
old_status = self._status
# Query NIXL for any notifications.
notifications = self._connector._nixl.update_notifs()
notifications = self._connection._nixl.update_notifs()
if isinstance(notifications, dict):
remote_state = OperationStatus.IN_PROGRESS
......@@ -1309,7 +1462,7 @@ class PassiveOperation(AbstractOperation):
if remote_state == OperationStatus.COMPLETE:
self._status = remote_state
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: {{ remote: '{self._connector.name}' status: '{old_status}' => '{self._status}' }}."
f"dynamo.nixl_connect.{self.__class__.__name__}: {{ remote: '{self._connection.name}' status: '{old_status}' => '{self._status}' }}."
)
return self._status
......@@ -1330,7 +1483,7 @@ class ReadOperation(ActiveOperation):
def __init__(
self,
connector: Connector,
connection: Connection,
remote_metadata: RdmaMetadata,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
......@@ -1341,16 +1494,16 @@ class ReadOperation(ActiveOperation):
Parameters
----------
connector : Connector
Connector instance to use for the operation.
connection : Connection
Connection instance to use for the operation.
remote_metadata : RdmaMetadata
Serialized request from the remote worker.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to to receive the data from the remote worker.
"""
if not isinstance(connector, Connector):
if not isinstance(connection, Connection):
raise TypeError(
"Argument `connector` must be `dynamo.nixl_connect.Connector`."
"Argument `connection` must be `dynamo.nixl_connect.Connection`."
)
if not isinstance(remote_metadata, RdmaMetadata):
raise TypeError(
......@@ -1359,7 +1512,7 @@ class ReadOperation(ActiveOperation):
if remote_metadata.operation_kind != OperationKind.READ.value:
raise ValueError("Argument `remote_metadata` must be of kind `READ`.")
remote = Remote(connector, remote_metadata.nixl_metadata)
remote = Remote(connection, remote_metadata.nixl_metadata)
remote_descriptors = remote_metadata.to_descriptors()
if not (
......@@ -1435,10 +1588,10 @@ class ReadableOperation(PassiveOperation):
def __init__(
self,
connector: Connector,
connection: Connection,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
super().__init__(connector, OperationKind.READ, local_descriptors)
super().__init__(connection, OperationKind.READ, local_descriptors)
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}"
)
......@@ -1510,17 +1663,19 @@ class Remote:
def __init__(
self,
connector: Connector,
connection: Connection,
nixl_metadata: bytes | str,
) -> None:
if not isinstance(connector, Connector):
raise TypeError("Argument `local` must be `dynamo.nixl_connect.Connector`.")
if not isinstance(connection, Connection):
raise TypeError(
"Argument `connection` must be `dynamo.nixl_connect.Connection`."
)
if not (isinstance(nixl_metadata, bytes) or isinstance(nixl_metadata, str)):
raise TypeError("Argument `nixl_metadata` must be `bytes` or `str`.")
if len(nixl_metadata) == 0:
raise ValueError("Argument `nixl_metadata` cannot be empty.")
self._connector = connector
self._connection = connection
# When `nixl_metadata` is a string, it is assumed to have come from a remote worker
# via a `RdmaMetadata` object and therefore can assumed be a b64-encoded, compressed
......@@ -1535,7 +1690,7 @@ class Remote:
# Decompress the NIXL metadata.
nixl_metadata = zlib.decompress(nixl_metadata)
self._name = connector._nixl.add_remote_agent(nixl_metadata)
self._name = connection._nixl.add_remote_agent(nixl_metadata)
if isinstance(self._name, bytes):
self._name = self._name.decode("utf-8")
......@@ -1559,7 +1714,7 @@ class Remote:
self._release()
def __repr__(self) -> str:
return f"Remote(name={self._name}, connector={self._connector.name})"
return f"Remote(name={self._name}, connection={self._connection.name})"
def __str__(self) -> str:
return self._name
......@@ -1568,19 +1723,19 @@ class Remote:
"""
Private method for releasing NIXL resources. Not intended for public use.
"""
# We have to unregister the remote agent from NIXL because we cannot know if the remote worker has updated its descriptors or not, and
# We have to deregister the remote agent from NIXL because we cannot know if the remote worker has updated its descriptors or not, and
# NIXL will return an error if we attempt to register a remote agent with the same name but different descriptors (aka conn_info).
self._connector._nixl.remove_remote_agent(self._name)
self._connection._nixl.remove_remote_agent(self._name)
logger.debug(
f'dynamo.nixl_connect.{self.__class__.__name__}: Unregistered NIXL remote {{ name: "{self._name}" }}.'
f'dynamo.nixl_connect.{self.__class__.__name__}: Deregistered NIXL remote {{ name: "{self._name}" }}.'
)
@property
def connector(self) -> Connector:
def connection(self) -> Connection:
"""
Gets the local connector associated with this remote worker.
Gets the local connection associated with this remote worker.
"""
return self._connector
return self._connection
@property
def name(self) -> str:
......@@ -1647,7 +1802,7 @@ class WritableOperation(PassiveOperation):
def __init__(
self,
connector: Connector,
connection: Connection,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
"""
......@@ -1656,18 +1811,18 @@ class WritableOperation(PassiveOperation):
Parameters
----------
connector : Connector
Connector instance to use for the operation.
connection : Connection
Connection instance to use for the operation.
local_descriptors : Descriptor | list[Descriptor]
Descriptors to receive data from a remote worker.
Raises
TypeError
When `local` is not a `dynamo.nixl_connect.Connector`.
When `connection` is not a `dynamo.nixl_connect.Connection`.
TypeError
When `local_descriptors` is not a `dynamo.nixl_connect.Descriptor` or `list[dynamo.nixl_connect.Descriptor]`.
"""
super().__init__(connector, OperationKind.WRITE, local_descriptors)
super().__init__(connection, OperationKind.WRITE, local_descriptors)
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}"
)
......@@ -1703,7 +1858,7 @@ class WriteOperation(ActiveOperation):
def __init__(
self,
connector: Connector,
connection: Connection,
local_descriptors: Descriptor | list[Descriptor],
remote_metadata: RdmaMetadata,
) -> None:
......@@ -1714,8 +1869,8 @@ class WriteOperation(ActiveOperation):
Parameters
----------
connector : Connector
Connector instance to use for the operation.
connection : Connection
Connection instance to use for the operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to send from, to the remote worker.
remote_metadata : RdmaMetadata
......@@ -1733,9 +1888,9 @@ class WriteOperation(ActiveOperation):
TypeError
When `local_descriptors` is not a `dynamo.nixl_connect.Descriptor` or `list[dynamo.nixl_connect.Descriptor]`.
"""
if not isinstance(connector, Connector):
if not isinstance(connection, Connection):
raise TypeError(
"Argument `connector` must be `dynamo.nixl_connect.Connector`."
"Argument `connection` must be `dynamo.nixl_connect.Connection`."
)
if not isinstance(remote_metadata, RdmaMetadata):
raise TypeError(
......@@ -1744,7 +1899,7 @@ class WriteOperation(ActiveOperation):
if remote_metadata.operation_kind != OperationKind.WRITE.value:
raise ValueError("Argument `remote_metadata` must be of kind `WRITE`.")
remote = Remote(connector, remote_metadata.nixl_metadata)
remote = Remote(connection, remote_metadata.nixl_metadata)
remote_descriptors = remote_metadata.to_descriptors()
super().__init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment