Unverified commit 387100c8, authored by Daniel Socek, committed by GitHub
Browse files

feat: replaces PersistentConnector monkey-patch with proper nixl_conn… (#6913)


Signed-off-by: Daniel Socek <daniel.socek@intel.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 2adf8a2d
......@@ -765,46 +765,15 @@ class NixlWriteEmbeddingReceiver(AbstractEmbeddingReceiver):
self.ring_buffer.release_buffer(buffer_id)
class PersistentConnector(nixl_connect.Connector):
    """Connector that hands out one cached connection for all send/receive operations."""

    def __init__(self):
        super().__init__()
        # Populated lazily by the first `_create_connection` call and reused afterwards.
        self._connection = None

    async def _create_connection(self) -> nixl_connect.Connection:
        """
        Private method returning the cached connection.

        The connection is constructed and initialized only on the first call;
        every later call returns the same object.
        """
        cached = self._connection
        if cached is not None:
            return cached

        # The attribute is assigned before awaiting `initialize()`.
        self._connection = nixl_connect.Connection(self, 1)
        await self._connection.initialize()
        return self._connection
# Replace the remote release method with a no-op so the remote agent is not
# deregistered on each release. With a persistent connection, all operations
# are initiated on the same agent pair; without avoiding the deregistration,
# the in-flight operations would be terminated.
def remote_release_overwrite(self) -> None:
    """Do nothing; installed in place of `Remote._release` below."""
    return None


nixl_connect.Remote._release = remote_release_overwrite  # type: ignore[method-assign]
class NixlReadEmbeddingSender(AbstractEmbeddingSender):
"""
Initial implementation of NIXL READ based transfer. This implementation uses
a monkey-patched version of 'nixl_connect' wrapper library to persist
connection (agent registration) and descriptors across multiple send operations
to avoid the overhead of repeated connection setup and teardown.
NOTE This implementation or the use of 'nixl_connect' needs to be revisited as
the benchmarking result is unexpectedly slow. Keeping it now for completeness,
i.e. provide NIXL WRITE based and READ based transfer classes.
"""NIXL READ based embedding transfer sender.
Uses nixl_connect.Connector which now natively provides a shared singleton
Connection (NIXL agent) and reference-counted Remote agent lifecycle.
"""
def __init__(self):
self.connector = PersistentConnector()
self.connector = nixl_connect.Connector()
@_nvtx.annotate("mm:nixl:send_embeddings", color="magenta")
async def send_embeddings(
......@@ -838,16 +807,10 @@ class NixlReadEmbeddingSender(AbstractEmbeddingSender):
class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver):
"""
Counterpart of 'NixlReadEmbeddingSender', see 'NixlReadEmbeddingSender' for details.
Initial implementation of another usage of NIXL connect library that persists
connection (agent registration) and descriptors (memory registration) across multiple send operations
to avoid the overhead of repeated connection setup and teardown.
[gluo FIXME] This implementation requires more memory allocation and is somewhat rigid; we should move away
from connect library so we can have single descriptor and chunk for transfer on demand, similarly to
KV cache transfer. We may worry less on memory fragmentation as the memory can be released for next
transfer as soon as the embedding has passed to the framework (NEED TO VERIFY: framework will copy) and
can simply loop around the large buffer.
"""NIXL READ based embedding transfer receiver.
Uses nixl_connect.Connector which now natively provides a shared singleton
Connection (NIXL agent) and reference-counted Remote agent lifecycle.
"""
def __init__(
......@@ -857,7 +820,7 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver):
max_items: int = 1024,
) -> None:
super().__init__()
self.connector = PersistentConnector()
self.connector = nixl_connect.Connector()
self.tensor_id_counter = 0
self.aggregated_op_create_time = 0
self.aggregated_op_wait_time = 0
......
......@@ -20,6 +20,7 @@ import base64
import ctypes
import logging
import socket
import threading
import uuid
import zlib
from abc import ABC, abstractmethod
......@@ -562,6 +563,9 @@ class Connection:
self._name = f"{connector.name}-{number}"
self._nixl = nixl_api.nixl_agent(self._name)
self._remote_refs: dict[str, int] = {} # ref-count remote agents
self._remote_refs_lock = threading.Lock()
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}."
)
......@@ -598,6 +602,22 @@ class Connection:
"""
return self._name
def acquire_remote_ref(self, name: str) -> None:
    """Add one reference for the remote agent `name`.

    A name that is not tracked yet starts from zero, so its first
    acquisition leaves a count of 1. The update runs while holding
    `_remote_refs_lock`.
    """
    with self._remote_refs_lock:
        current = self._remote_refs.get(name, 0)
        self._remote_refs[name] = current + 1
def release_remote_ref(self, name: str) -> bool:
    """Drop one reference for the remote agent `name`.

    Returns True only when this call released the last reference, in which
    case the entry is removed from the table. Returns False when `name` is
    not tracked or when other references remain. The update runs while
    holding `_remote_refs_lock`.
    """
    with self._remote_refs_lock:
        remaining = self._remote_refs.get(name)
        if remaining is None:
            # Unknown name: nothing to release.
            return False
        if remaining != 1:
            self._remote_refs[name] = remaining - 1
            return False
        # Last reference: forget the entry entirely.
        del self._remote_refs[name]
        return True
async def initialize(self) -> None:
# Only initialize the connection once.
if self._is_initialized:
......@@ -643,6 +663,9 @@ class Connector:
self._worker_id = worker_id
self._hostname = socket.gethostname()
self._shared_connection: Optional[Connection] = None
self._shared_connection_lock = asyncio.Lock()
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}."
)
......@@ -820,13 +843,18 @@ class Connector:
)
async def _create_connection(self) -> Connection:
"""
Private method to create a new connection.
"""
self._connection_count += 1
conn = Connection(self, self._connection_count)
await conn.initialize()
return conn
"""Create and return a single shared Connection (NIXL agent)."""
async with self._shared_connection_lock:
if self._shared_connection is not None:
return self._shared_connection
self._connection_count += 1
conn = Connection(self, self._connection_count)
await conn.initialize()
self._shared_connection = conn
logger.info(
f"dynamo.nixl_connect.Connector: Created shared connection '{conn.name}'."
)
return conn
class Descriptor:
......@@ -1698,6 +1726,9 @@ class Remote:
if isinstance(self._name, bytes):
self._name = self._name.decode("utf-8")
connection.acquire_remote_ref(self._name)
self._released = False
logger.debug(
f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}."
)
......@@ -1724,15 +1755,11 @@ class Remote:
return self._name
def _release(self) -> None:
"""
Private method for releasing NIXL resources. Not intended for public use.
"""
# We have to deregister the remote agent from NIXL because we cannot know if the remote worker has updated its descriptors or not, and
# NIXL will return an error if we attempt to register a remote agent with the same name but different descriptors (aka conn_info).
self._connection._nixl.remove_remote_agent(self._name)
logger.debug(
f'dynamo.nixl_connect.{self.__class__.__name__}: Deregistered NIXL remote {{ name: "{self._name}" }}.'
)
if self._released:
return
self._released = True
if self._connection.release_remote_ref(self._name):
self._connection._nixl.remove_remote_agent(self._name)
@property
def connection(self) -> Connection:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment