Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
4f828c59
"examples/python-examples/input.prmtop" did not exist on "534fd40416011ebbbdf49deebecd32bcddaaffae"
Commit
4f828c59
authored
Dec 23, 2025
by
lishen
Browse files
支持combine_wait_recv_cost记录
parent
f4b3020e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
37 additions
and
3 deletions
+37
-3
csrc/deep_ep.cu
csrc/deep_ep.cu
+9
-0
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+1
-0
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+1
-0
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+20
-2
deep_ep/buffer.py
deep_ep/buffer.py
+6
-1
No files found.
csrc/deep_ep.cu
View file @
4f828c59
...
@@ -1397,6 +1397,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1397,6 +1397,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
Buffer
::
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
std
::
optional
<
torch
::
Tensor
>&
combine_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
const
std
::
optional
<
torch
::
Tensor
>&
out
)
{
const
std
::
optional
<
torch
::
Tensor
>&
out
)
{
...
@@ -1418,6 +1419,13 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
...
@@ -1418,6 +1419,13 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
EP_HOST_ASSERT
(
layout_range
.
dim
()
==
2
and
layout_range
.
is_contiguous
());
EP_HOST_ASSERT
(
layout_range
.
dim
()
==
2
and
layout_range
.
is_contiguous
());
EP_HOST_ASSERT
(
layout_range
.
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
layout_range
.
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
layout_range
.
size
(
0
)
==
num_experts
/
num_ranks
and
layout_range
.
size
(
1
)
==
num_ranks
);
EP_HOST_ASSERT
(
layout_range
.
size
(
0
)
==
num_experts
/
num_ranks
and
layout_range
.
size
(
1
)
==
num_ranks
);
if
(
combine_wait_recv_cost_stats
.
has_value
())
{
EP_HOST_ASSERT
(
combine_wait_recv_cost_stats
->
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
combine_wait_recv_cost_stats
->
dim
()
==
1
and
combine_wait_recv_cost_stats
->
is_contiguous
());
EP_HOST_ASSERT
(
combine_wait_recv_cost_stats
->
size
(
0
)
==
num_ranks
);
}
auto
hidden
=
static_cast
<
int
>
(
x
.
size
(
2
));
auto
hidden
=
static_cast
<
int
>
(
x
.
size
(
2
));
auto
num_local_experts
=
num_experts
/
num_ranks
,
num_topk
=
static_cast
<
int
>
(
topk_weights
.
size
(
1
));
auto
num_local_experts
=
num_experts
/
num_ranks
,
num_topk
=
static_cast
<
int
>
(
topk_weights
.
size
(
1
));
auto
num_combined_tokens
=
static_cast
<
int
>
(
topk_weights
.
size
(
0
));
auto
num_combined_tokens
=
static_cast
<
int
>
(
topk_weights
.
size
(
0
));
...
@@ -1456,6 +1464,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
...
@@ -1456,6 +1464,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
topk_weights
.
data_ptr
<
float
>
(),
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
topk_weights
.
data_ptr
<
float
>
(),
src_info
.
data_ptr
<
int
>
(),
layout_range
.
data_ptr
<
int64_t
>
(),
src_info
.
data_ptr
<
int
>
(),
layout_range
.
data_ptr
<
int64_t
>
(),
global_atomic_counter
.
data_ptr
<
int
>
(),
global_atomic_counter
.
data_ptr
<
int
>
(),
combine_wait_recv_cost_stats
.
has_value
()
?
combine_wait_recv_cost_stats
->
data_ptr
<
int64_t
>
()
:
nullptr
,
next_clean_meta
.
first
,
next_clean_meta
.
second
,
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
...
...
csrc/deep_ep.hpp
View file @
4f828c59
...
@@ -183,6 +183,7 @@ public:
...
@@ -183,6 +183,7 @@ public:
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
std
::
optional
<
torch
::
Tensor
>&
combine_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
const
std
::
optional
<
torch
::
Tensor
>&
out
=
std
::
nullopt
);
const
std
::
optional
<
torch
::
Tensor
>&
out
=
std
::
nullopt
);
...
...
csrc/kernels/api.cuh
View file @
4f828c59
...
@@ -155,6 +155,7 @@ void combine(void* combined_x,
...
@@ -155,6 +155,7 @@ void combine(void* combined_x,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
int
*
global_atomic_counter
,
int
*
global_atomic_counter
,
int64_t
*
combine_wait_recv_cost_stats
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
...
...
csrc/kernels/internode_ll.cu
View file @
4f828c59
...
@@ -549,6 +549,7 @@ combine(void* combined_x,
...
@@ -549,6 +549,7 @@ combine(void* combined_x,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
int
*
global_atomic_counter
,
int
*
global_atomic_counter
,
int64_t
*
combine_wait_recv_cost_stats
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
*
atomic_clean_flag
,
int
*
atomic_clean_flag
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
...
@@ -724,8 +725,23 @@ LOW_LATENCY_COMBINE_RECV:
...
@@ -724,8 +725,23 @@ LOW_LATENCY_COMBINE_RECV:
// Wait all ranks to arrive and notify PCIe usage
// Wait all ranks to arrive and notify PCIe usage
if
(
responsible_expert_idx
<
num_experts
)
{
if
(
responsible_expert_idx
<
num_experts
)
{
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
);
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
);
if
(
sub_warp_id
==
0
and
lane_id
==
0
){
if
(
sub_warp_id
==
0
and
lane_id
==
0
)
{
while
(
ld_acquire_global
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
responsible_expert_idx
))
==
0
);
const
auto
src_rank
=
responsible_expert_idx
/
num_local_experts
;
auto
start_time
=
wall_clock64
();
uint64_t
wait_recv_cost
=
0
;
while
(
ld_acquire_global
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
responsible_expert_idx
))
==
0
// recv not ready
&&
(
wait_recv_cost
=
wall_clock64
()
-
start_time
)
<=
NUM_TIMEOUT_CYCLES
// not timeout
);
// Mask rank if timeout
if
(
wait_recv_cost
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"Warning: DeepEP timeout for combine receive, rank %d, local_expert_idx %d, src_rank %d
\n
"
,
rank
,
responsible_expert_idx
%
num_local_experts
,
src_rank
);
}
if
(
combine_wait_recv_cost_stats
!=
nullptr
)
{
atomicAdd
(
reinterpret_cast
<
unsigned
long
long
*>
(
combine_wait_recv_cost_stats
+
src_rank
),
wait_recv_cost
);
}
}
}
}
}
grid_barrier
(
global_atomic_counter
,
num_sms
);
grid_barrier
(
global_atomic_counter
,
num_sms
);
...
@@ -776,6 +792,7 @@ void combine(void* combined_x,
...
@@ -776,6 +792,7 @@ void combine(void* combined_x,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
int
*
global_atomic_counter
,
int
*
global_atomic_counter
,
int64_t
*
combine_wait_recv_cost_stats
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
...
@@ -803,6 +820,7 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
...
@@ -803,6 +820,7 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
x, topk_idx, topk_weights, src_info, layout_range, \
global_atomic_counter, \
global_atomic_counter, \
combine_wait_recv_cost_stats, \
next_clean, num_next_clean_int, \
next_clean, num_next_clean_int, \
atomic_clean_flag, \
atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, \
num_combined_tokens, hidden, num_topk, \
...
...
deep_ep/buffer.py
View file @
4f828c59
...
@@ -901,7 +901,8 @@ class Buffer:
...
@@ -901,7 +901,8 @@ class Buffer:
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
handle
:
tuple
,
zero_copy
:
bool
=
False
,
async_finish
:
bool
=
False
,
handle
:
tuple
,
zero_copy
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
,
out
:
Optional
[
torch
.
Tensor
]
=
None
)
->
\
return_recv_hook
:
bool
=
False
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
combine_wait_recv_cost_stats
:
Optional
[
torch
.
Tensor
]
=
None
)
->
\
Tuple
[
torch
.
Tensor
,
EventOverlap
,
Callable
]:
Tuple
[
torch
.
Tensor
,
EventOverlap
,
Callable
]:
"""
"""
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
...
@@ -927,6 +928,9 @@ class Buffer:
...
@@ -927,6 +928,9 @@ class Buffer:
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you do not set this flag, the kernel will ensure the data's arrival.
If you do not set this flag, the kernel will ensure the data's arrival.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
combine_wait_recv_cost_stats: an optional tensor that accumulates, per source rank, the time spent waiting to receive tokens,
which should have shape `[num_ranks]` and be typed as `torch.int64`.
This is useful for detecting and precisely localizing slow anomalies.
Returns:
Returns:
combined_x: the reduced token tensor, with shape `[num_combined_tokens, num_topk]` and type `torch.bfloat16`.
combined_x: the reduced token tensor, with shape `[num_combined_tokens, num_topk]` and type `torch.bfloat16`.
...
@@ -935,6 +939,7 @@ class Buffer:
...
@@ -935,6 +939,7 @@ class Buffer:
"""
"""
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
=
handle
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
=
handle
combined_x
,
event
,
hook
=
self
.
runtime
.
low_latency_combine
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
,
event
,
hook
=
self
.
runtime
.
low_latency_combine
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combine_wait_recv_cost_stats
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
zero_copy
,
async_finish
,
return_recv_hook
,
out
)
zero_copy
,
async_finish
,
return_recv_hook
,
out
)
tensors_to_record
=
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
)
tensors_to_record
=
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment