"vscode:/vscode.git/clone" did not exist on "9bf9709a655bb061d3a4ded14aee5ebda670fb9c"
Unverified Commit 387100c8 authored by Daniel Socek's avatar Daniel Socek Committed by GitHub
Browse files

feat: replaces PersistentConnector monkey-patch with proper nixl_conn… (#6913)


Signed-off-by: Daniel Socek <daniel.socek@intel.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 2adf8a2d
...@@ -765,46 +765,15 @@ class NixlWriteEmbeddingReceiver(AbstractEmbeddingReceiver): ...@@ -765,46 +765,15 @@ class NixlWriteEmbeddingReceiver(AbstractEmbeddingReceiver):
self.ring_buffer.release_buffer(buffer_id) self.ring_buffer.release_buffer(buffer_id)
class PersistentConnector(nixl_connect.Connector):
    """Connector that keeps a single NIXL connection and reuses it for every send/receive."""

    def __init__(self):
        super().__init__()
        # Filled in lazily by the first _create_connection() call.
        self._connection = None

    async def _create_connection(self) -> nixl_connect.Connection:
        """Return the cached connection, creating and initializing it on first use."""
        if self._connection is not None:
            return self._connection

        # NOTE(review): the attribute is assigned before initialize() is awaited,
        # so a caller arriving during that await gets the same object back,
        # possibly before initialization has finished — confirm callers tolerate this.
        self._connection = nixl_connect.Connection(self, 1)
        await self._connection.initialize()
        return self._connection
# Override the remote release method so the remote agent is not deregistered on each release.
# With a persistent connection every operation is initiated on the same agent pair; if the
# deregistration were not avoided, in-flight operations would be terminated.
def remote_release_overwrite(self) -> None:
    """No-op replacement installed as ``nixl_connect.Remote._release``."""
    return None


nixl_connect.Remote._release = remote_release_overwrite  # type: ignore[method-assign]
class NixlReadEmbeddingSender(AbstractEmbeddingSender): class NixlReadEmbeddingSender(AbstractEmbeddingSender):
""" """NIXL READ based embedding transfer sender.
Initial implementation of NIXL READ based transfer. This implementation uses
a monkey-patched version of 'nixl_connect' wrapper library to persist Uses nixl_connect.Connector which now natively provides a shared singleton
connection (agent registration) and descriptors across multiple send operations Connection (NIXL agent) and reference-counted Remote agent lifecycle.
to avoid the overhead of repeated connection setup and teardown.
NOTE This implementation or the use of 'nixl_connect' needs to be revisited as
the benchmarking result is unexpectedly slow. Keeping it now for completeness,
i.e. provide NIXL WRITE based and READ based transfer classes.
""" """
def __init__(self): def __init__(self):
self.connector = PersistentConnector() self.connector = nixl_connect.Connector()
@_nvtx.annotate("mm:nixl:send_embeddings", color="magenta") @_nvtx.annotate("mm:nixl:send_embeddings", color="magenta")
async def send_embeddings( async def send_embeddings(
...@@ -838,16 +807,10 @@ class NixlReadEmbeddingSender(AbstractEmbeddingSender): ...@@ -838,16 +807,10 @@ class NixlReadEmbeddingSender(AbstractEmbeddingSender):
class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver): class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver):
""" """NIXL READ based embedding transfer receiver.
Counter part of 'NixlReadEmbeddingSender', see 'NixlReadEmbeddingSender' for details.
Initial implementation of another usage of NIXL connect library that persists Uses nixl_connect.Connector which now natively provides a shared singleton
connection (agent registration) and descriptors (memory registration) across multiple send operations Connection (NIXL agent) and reference-counted Remote agent lifecycle.
to avoid the overhead of repeated connection setup and teardown.
[gluo FIXME] This implementation requires more memory allocation and somewhat rigid, should move away
from connect library so we can have single descriptor and chunk for transfer on demand, similarly to
KV cache transfer. We may worry less on memory fragmentation as the memory can be released for next
transfer as soon as the embedding has passed to the framework (NEED TO VERIFY: framework will copy) and
can simply loop around the large buffer.
""" """
def __init__( def __init__(
...@@ -857,7 +820,7 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver): ...@@ -857,7 +820,7 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver):
max_items: int = 1024, max_items: int = 1024,
) -> None: ) -> None:
super().__init__() super().__init__()
self.connector = PersistentConnector() self.connector = nixl_connect.Connector()
self.tensor_id_counter = 0 self.tensor_id_counter = 0
self.aggregated_op_create_time = 0 self.aggregated_op_create_time = 0
self.aggregated_op_wait_time = 0 self.aggregated_op_wait_time = 0
......
...@@ -20,6 +20,7 @@ import base64 ...@@ -20,6 +20,7 @@ import base64
import ctypes import ctypes
import logging import logging
import socket import socket
import threading
import uuid import uuid
import zlib import zlib
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -562,6 +563,9 @@ class Connection: ...@@ -562,6 +563,9 @@ class Connection:
self._name = f"{connector.name}-{number}" self._name = f"{connector.name}-{number}"
self._nixl = nixl_api.nixl_agent(self._name) self._nixl = nixl_api.nixl_agent(self._name)
self._remote_refs: dict[str, int] = {} # ref-count remote agents
self._remote_refs_lock = threading.Lock()
logger.debug( logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}." f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}."
) )
...@@ -598,6 +602,22 @@ class Connection: ...@@ -598,6 +602,22 @@ class Connection:
""" """
return self._name return self._name
def acquire_remote_ref(self, name: str) -> None:
    """Record one additional reference to the remote agent called ``name``.

    A name that is not tracked yet starts at a count of one.
    """
    with self._remote_refs_lock:
        held = self._remote_refs.get(name, 0)
        self._remote_refs[name] = held + 1
def release_remote_ref(self, name: str) -> bool:
    """Drop one reference to the remote agent called ``name``.

    Returns:
        True only when this call released the last outstanding reference
        (the entry is then removed from the table). False when references
        remain, or when ``name`` is not tracked at all.
    """
    with self._remote_refs_lock:
        remaining = self._remote_refs.get(name)
        if remaining is None:
            # Unknown name: nothing to release.
            return False
        if remaining != 1:
            self._remote_refs[name] = remaining - 1
            return False
        # Last reference: forget the name entirely.
        del self._remote_refs[name]
        return True
async def initialize(self) -> None: async def initialize(self) -> None:
# Only initialize the connection once. # Only initialize the connection once.
if self._is_initialized: if self._is_initialized:
...@@ -643,6 +663,9 @@ class Connector: ...@@ -643,6 +663,9 @@ class Connector:
self._worker_id = worker_id self._worker_id = worker_id
self._hostname = socket.gethostname() self._hostname = socket.gethostname()
self._shared_connection: Optional[Connection] = None
self._shared_connection_lock = asyncio.Lock()
logger.debug( logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}." f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}."
) )
...@@ -820,13 +843,18 @@ class Connector: ...@@ -820,13 +843,18 @@ class Connector:
) )
async def _create_connection(self) -> Connection:
    """Create and return a single shared Connection (NIXL agent).

    The first call constructs and initializes the connection and caches it
    on the connector; every later call returns that cached instance.

    Returns:
        The connector's shared ``Connection``.
    """
    # Hold the lock across the check and the construction so concurrent
    # callers cannot each build their own connection.
    async with self._shared_connection_lock:
        if self._shared_connection is not None:
            return self._shared_connection

        self._connection_count += 1
        conn = Connection(self, self._connection_count)
        await conn.initialize()
        # Assigned only after initialize() returns, so an attempt that raised
        # is never cached.
        self._shared_connection = conn
        logger.info(
            f"dynamo.nixl_connect.Connector: Created shared connection '{conn.name}'."
        )
        return conn
class Descriptor: class Descriptor:
...@@ -1698,6 +1726,9 @@ class Remote: ...@@ -1698,6 +1726,9 @@ class Remote:
if isinstance(self._name, bytes): if isinstance(self._name, bytes):
self._name = self._name.decode("utf-8") self._name = self._name.decode("utf-8")
connection.acquire_remote_ref(self._name)
self._released = False
logger.debug( logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}." f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}."
) )
...@@ -1724,15 +1755,11 @@ class Remote: ...@@ -1724,15 +1755,11 @@ class Remote:
return self._name return self._name
def _release(self) -> None:
    """Private method for releasing NIXL resources. Not intended for public use.

    Safe to call more than once: only the first call has any effect. The
    remote agent is removed from NIXL only when the connection reports that
    this was the last outstanding reference to it.
    """
    if self._released:
        return
    # Mark first so a repeated call never decrements the reference count twice.
    self._released = True
    if self._connection.release_remote_ref(self._name):
        self._connection._nixl.remove_remote_agent(self._name)
@property @property
def connection(self) -> Connection: def connection(self) -> Connection:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment