Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
b3b61ef5
Commit
b3b61ef5
authored
Mar 10, 2025
by
Dmytro Dzhulgakov
Browse files
Allow passing output tensor in low_latency_combine
parent
ed7487c1
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
6 deletions
+17
-6
csrc/deep_ep.cpp
csrc/deep_ep.cpp
+10
-2
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+1
-1
deep_ep/buffer.py
deep_ep/buffer.py
+4
-2
tests/test_low_latency.py
tests/test_low_latency.py
+2
-1
No files found.
csrc/deep_ep.cpp
View file @
b3b61ef5
...
@@ -1100,7 +1100,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
...
@@ -1100,7 +1100,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
Buffer
::
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
Buffer
::
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
async
,
bool
return_recv_hook
)
{
bool
async
,
bool
return_recv_hook
,
std
::
optional
<
torch
::
Tensor
>
out
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
EP_HOST_ASSERT
(
low_latency_mode
);
// Tensor checks
// Tensor checks
...
@@ -1138,7 +1138,15 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
...
@@ -1138,7 +1138,15 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
stream_wait
(
launch_stream
,
compute_stream
);
stream_wait
(
launch_stream
,
compute_stream
);
// Allocate output tensor
// Allocate output tensor
auto
combined_x
=
torch
::
empty
({
num_combined_tokens
,
hidden
},
x
.
options
());
torch
::
Tensor
combined_x
;
if
(
out
.
has_value
())
{
EP_HOST_ASSERT
(
out
->
dim
()
==
2
and
out
->
is_contiguous
());
EP_HOST_ASSERT
(
out
->
size
(
0
)
==
num_combined_tokens
and
out
->
size
(
1
)
==
hidden
);
EP_HOST_ASSERT
(
out
->
scalar_type
()
==
x
.
scalar_type
());
combined_x
=
out
.
value
();
}
else
{
combined_x
=
torch
::
empty
({
num_combined_tokens
,
hidden
},
x
.
options
());
}
// Kernel launch
// Kernel launch
auto
next_clean_meta
=
next_buffer
.
clean_meta
();
auto
next_clean_meta
=
next_buffer
.
clean_meta
();
...
...
csrc/deep_ep.hpp
View file @
b3b61ef5
...
@@ -143,7 +143,7 @@ public:
...
@@ -143,7 +143,7 @@ public:
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
async
,
bool
return_recv_hook
);
bool
async
,
bool
return_recv_hook
,
std
::
optional
<
torch
::
Tensor
>
out
=
std
::
nullopt
);
};
};
}
// namespace deep_ep
}
// namespace deep_ep
deep_ep/buffer.py
View file @
b3b61ef5
...
@@ -497,7 +497,8 @@ class Buffer:
...
@@ -497,7 +497,8 @@ class Buffer:
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
handle
:
tuple
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
handle
:
tuple
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
,
out
:
torch
.
Tensor
|
None
=
None
)
->
\
Tuple
[
torch
.
Tensor
,
EventOverlap
,
Callable
]:
Tuple
[
torch
.
Tensor
,
EventOverlap
,
Callable
]:
"""
"""
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
...
@@ -520,6 +521,7 @@ class Buffer:
...
@@ -520,6 +521,7 @@ class Buffer:
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
Returns:
Returns:
combined_x: the reduced token tensor, with shape `[num_combined_tokens, num_topk]` and type `torch.bfloat16`.
combined_x: the reduced token tensor, with shape `[num_combined_tokens, num_topk]` and type `torch.bfloat16`.
...
@@ -529,6 +531,6 @@ class Buffer:
...
@@ -529,6 +531,6 @@ class Buffer:
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
num_experts
=
handle
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
num_experts
=
handle
combined_x
,
event
,
hook
=
self
.
runtime
.
low_latency_combine
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
,
event
,
hook
=
self
.
runtime
.
low_latency_combine
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
async_finish
,
return_recv_hook
)
async_finish
,
return_recv_hook
,
out
)
tensors_to_record
=
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
)
tensors_to_record
=
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
)
return
combined_x
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
return
combined_x
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
tests/test_low_latency.py
View file @
b3b61ef5
...
@@ -73,8 +73,9 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -73,8 +73,9 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
# Check combine correctness
# Check combine correctness
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
,
out
=
out
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
if
do_check
:
if
do_check
:
diff
=
calc_diff
(
x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
diff
=
calc_diff
(
x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment