Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
f0d34aab
Unverified
Commit
f0d34aab
authored
Aug 19, 2025
by
Tailing Yuan
Committed by
GitHub
Aug 19, 2025
Browse files
Init buffer with mpi4py.MPI.Comm (#365)
Signed-off-by:
Tailing Yuan
<
yuantailing@gmail.com
>
parent
e3908bf5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
11 deletions
+23
-11
deep_ep/buffer.py
deep_ep/buffer.py
+23
-11
No files found.
deep_ep/buffer.py
View file @
f0d34aab
...
...
@@ -29,12 +29,13 @@ class Buffer:
num_sms
:
int
=
20
def
__init__
(
self
,
group
:
dist
.
ProcessGroup
,
def
__init__
(
self
,
group
:
Optional
[
dist
.
ProcessGroup
]
,
num_nvl_bytes
:
int
=
0
,
num_rdma_bytes
:
int
=
0
,
low_latency_mode
:
bool
=
False
,
num_qps_per_rank
:
int
=
24
,
allow_nvlink_for_low_latency_mode
:
bool
=
True
,
allow_mnnvl
:
bool
=
False
,
explicitly_destroy
:
bool
=
False
)
->
None
:
explicitly_destroy
:
bool
=
False
,
comm
:
Optional
[
"mpi4py.MPI.Comm"
]
=
None
)
->
None
:
"""
Initialize the communication buffer.
...
...
@@ -53,13 +54,27 @@ class Buffer:
explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources;
otherwise, the resources will be released by the destructor.
Note: Releasing resources in the destructor may cause Python's exception handling process to hang.
comm: the mpi4py.MPI.Comm communicator to use in case the group parameter is absent.
"""
check_nvlink_connections
(
group
)
# Initialize the CPP runtime
if
group
is
not
None
:
self
.
rank
=
group
.
rank
()
self
.
group_size
=
group
.
size
()
self
.
group
=
group
def
all_gather_object
(
obj
):
object_list
=
[
None
]
*
self
.
group_size
dist
.
all_gather_object
(
object_list
,
obj
,
group
)
return
object_list
elif
comm
is
not
None
:
self
.
rank
=
comm
.
Get_rank
()
self
.
group_size
=
comm
.
Get_size
()
def
all_gather_object
(
obj
):
return
comm
.
allgather
(
obj
)
else
:
raise
ValueError
(
"Either 'group' or 'comm' must be provided."
)
self
.
num_nvl_bytes
=
num_nvl_bytes
self
.
num_rdma_bytes
=
num_rdma_bytes
self
.
low_latency_mode
=
low_latency_mode
...
...
@@ -67,14 +82,12 @@ class Buffer:
self
.
runtime
=
deep_ep_cpp
.
Buffer
(
self
.
rank
,
self
.
group_size
,
num_nvl_bytes
,
num_rdma_bytes
,
low_latency_mode
,
explicitly_destroy
)
# Synchronize device IDs
device_ids
=
[
None
,
]
*
self
.
group_size
local_device_id
=
self
.
runtime
.
get_local_device_id
()
dist
.
all_gather_object
(
device_ids
,
local_device_id
,
group
)
device_ids
=
all_gather_object
(
local_device_id
)
# Synchronize IPC handles
ipc_handles
=
[
None
,
]
*
self
.
group_size
local_ipc_handle
=
self
.
runtime
.
get_local_ipc_handle
()
dist
.
all_gather_object
(
ipc_handles
,
local_ipc_handle
,
group
)
ipc_handles
=
all_gather_object
(
local_ipc_handle
)
# Synchronize NVSHMEM unique IDs
root_unique_id
=
None
...
...
@@ -100,10 +113,9 @@ class Buffer:
os
.
environ
[
'NVSHMEM_DISABLE_MNNVL'
]
=
'1'
# Synchronize using the root ID
nvshmem_unique_ids
=
[
None
,
]
*
self
.
group_size
if
(
low_latency_mode
and
self
.
rank
==
0
)
or
(
not
low_latency_mode
and
self
.
runtime
.
get_rdma_rank
()
==
0
):
root_unique_id
=
self
.
runtime
.
get_local_nvshmem_unique_id
()
dist
.
all_gather_object
(
nvshmem_unique_ids
,
root_unique_id
,
group
)
nvshmem_unique_ids
=
all_gather_object
(
root_unique_id
)
root_unique_id
=
nvshmem_unique_ids
[
0
if
low_latency_mode
else
self
.
runtime
.
get_root_rdma_rank
(
True
)]
# Make CPP runtime available
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment