Unverified Commit f0d34aab authored by Tailing Yuan's avatar Tailing Yuan Committed by GitHub
Browse files

Init buffer with mpi4py.MPI.Comm (#365)


Signed-off-by: default avatarTailing Yuan <yuantailing@gmail.com>
parent e3908bf5
......@@ -29,12 +29,13 @@ class Buffer:
num_sms: int = 20
def __init__(self, group: dist.ProcessGroup,
def __init__(self, group: Optional[dist.ProcessGroup],
num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
low_latency_mode: bool = False, num_qps_per_rank: int = 24,
allow_nvlink_for_low_latency_mode: bool = True,
allow_mnnvl: bool = False,
explicitly_destroy: bool = False) -> None:
explicitly_destroy: bool = False,
comm: Optional["mpi4py.MPI.Comm"] = None) -> None:
"""
Initialize the communication buffer.
......@@ -53,13 +54,27 @@ class Buffer:
explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources;
otherwise, the resources will be released by the destructor.
Note: Releasing resources in the destructor may cause Python's exception handling process to hang.
comm: the mpi4py.MPI.Comm communicator to use in case the group parameter is absent.
"""
check_nvlink_connections(group)
# Initialize the CPP runtime
self.rank = group.rank()
self.group_size = group.size()
self.group = group
if group is not None:
self.rank = group.rank()
self.group_size = group.size()
def all_gather_object(obj):
object_list = [None] * self.group_size
dist.all_gather_object(object_list, obj, group)
return object_list
elif comm is not None:
self.rank = comm.Get_rank()
self.group_size = comm.Get_size()
def all_gather_object(obj):
return comm.allgather(obj)
else:
raise ValueError("Either 'group' or 'comm' must be provided.")
self.num_nvl_bytes = num_nvl_bytes
self.num_rdma_bytes = num_rdma_bytes
self.low_latency_mode = low_latency_mode
......@@ -67,14 +82,12 @@ class Buffer:
self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode, explicitly_destroy)
# Synchronize device IDs
device_ids = [None, ] * self.group_size
local_device_id = self.runtime.get_local_device_id()
dist.all_gather_object(device_ids, local_device_id, group)
device_ids = all_gather_object(local_device_id)
# Synchronize IPC handles
ipc_handles = [None, ] * self.group_size
local_ipc_handle = self.runtime.get_local_ipc_handle()
dist.all_gather_object(ipc_handles, local_ipc_handle, group)
ipc_handles = all_gather_object(local_ipc_handle)
# Synchronize NVSHMEM unique IDs
root_unique_id = None
......@@ -100,10 +113,9 @@ class Buffer:
os.environ['NVSHMEM_DISABLE_MNNVL'] = '1'
# Synchronize using the root ID
nvshmem_unique_ids = [None, ] * self.group_size
if (low_latency_mode and self.rank == 0) or (not low_latency_mode and self.runtime.get_rdma_rank() == 0):
root_unique_id = self.runtime.get_local_nvshmem_unique_id()
dist.all_gather_object(nvshmem_unique_ids, root_unique_id, group)
nvshmem_unique_ids = all_gather_object(root_unique_id)
root_unique_id = nvshmem_unique_ids[0 if low_latency_mode else self.runtime.get_root_rdma_rank(True)]
# Make CPP runtime available
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment