Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
b80e55e2
Unverified
Commit
b80e55e2
authored
Jun 25, 2025
by
Shangyan Zhou
Committed by
GitHub
Jun 25, 2025
Browse files
Add `get_comm_stream`. (#256)
* Add `get_comm_stream`. * Fix style.
parent
a15faa9f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
0 deletions
+17
-0
csrc/deep_ep.cpp
csrc/deep_ep.cpp
+5
-0
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+2
-0
deep_ep/buffer.py
deep_ep/buffer.py
+10
-0
No files found.
csrc/deep_ep.cpp
View file @
b80e55e2
...
...
@@ -163,6 +163,10 @@ torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int
return
torch
::
from_blob
(
base_ptr
,
num_bytes
/
element_bytes
,
torch
::
TensorOptions
().
dtype
(
casted_dtype
).
device
(
at
::
kCUDA
));
}
torch
::
Stream
Buffer
::
get_comm_stream
()
const
{
return
comm_stream
;
}
void
Buffer
::
sync
(
const
std
::
vector
<
int
>
&
device_ids
,
const
std
::
vector
<
std
::
optional
<
pybind11
::
bytearray
>>
&
all_gathered_handles
,
const
std
::
optional
<
pybind11
::
bytearray
>&
root_unique_id_opt
)
{
...
...
@@ -1303,6 +1307,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"get_local_ipc_handle"
,
&
deep_ep
::
Buffer
::
get_local_ipc_handle
)
.
def
(
"get_local_nvshmem_unique_id"
,
&
deep_ep
::
Buffer
::
get_local_nvshmem_unique_id
)
.
def
(
"get_local_buffer_tensor"
,
&
deep_ep
::
Buffer
::
get_local_buffer_tensor
)
.
def
(
"get_comm_stream"
,
&
deep_ep
::
Buffer
::
get_comm_stream
)
.
def
(
"sync"
,
&
deep_ep
::
Buffer
::
sync
)
.
def
(
"get_dispatch_layout"
,
&
deep_ep
::
Buffer
::
get_dispatch_layout
)
.
def
(
"intranode_dispatch"
,
&
deep_ep
::
Buffer
::
intranode_dispatch
)
...
...
csrc/deep_ep.hpp
View file @
b80e55e2
...
...
@@ -94,6 +94,8 @@ public:
torch
::
Tensor
get_local_buffer_tensor
(
const
pybind11
::
object
&
dtype
,
int64_t
offset
,
bool
use_rdma_buffer
)
const
;
torch
::
Stream
get_comm_stream
()
const
;
void
sync
(
const
std
::
vector
<
int
>&
device_ids
,
const
std
::
vector
<
std
::
optional
<
pybind11
::
bytearray
>>&
all_gathered_handles
,
const
std
::
optional
<
pybind11
::
bytearray
>&
root_unique_id_opt
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>>
...
...
deep_ep/buffer.py
View file @
b80e55e2
...
...
@@ -147,6 +147,16 @@ class Buffer:
size: the RDMA buffer size recommended.
"""
return
deep_ep_cpp
.
get_low_latency_rdma_size_hint
(
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
)
def
get_comm_stream
(
self
)
->
torch
.
Stream
:
"""
Get the communication stream.
Returns:
stream: the communication stream.
"""
ts
:
torch
.
Stream
=
self
.
runtime
.
get_comm_stream
()
return
torch
.
cuda
.
Stream
(
stream_id
=
ts
.
stream_id
,
device_index
=
ts
.
device_index
,
device_type
=
ts
.
device_type
)
def
get_local_buffer_tensor
(
self
,
dtype
:
torch
.
dtype
,
size
:
Optional
[
torch
.
Size
]
=
None
,
offset
:
int
=
0
,
use_rdma_buffer
:
bool
=
False
)
->
torch
.
Tensor
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment