Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
5a2e37fa
"platforms/vscode:/vscode.git/clone" did not exist on "ba66e90e878e3ee83b4e2969a31afe9233104bf5"
Unverified
Commit
5a2e37fa
authored
Jun 09, 2025
by
Chenggang Zhao
Committed by
GitHub
Jun 09, 2025
Browse files
Support statistics tensor for low-latency kernels (#196)
parent
0d1a855d
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
27 additions
and
3 deletions
+27
-3
csrc/deep_ep.cpp
csrc/deep_ep.cpp
+7
-0
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+1
-0
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+1
-0
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+5
-0
deep_ep/buffer.py
deep_ep/buffer.py
+9
-3
tests/test_low_latency.py
tests/test_low_latency.py
+4
-0
No files found.
csrc/deep_ep.cpp
View file @
5a2e37fa
...
...
@@ -1030,6 +1030,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
std
::
optional
<
torch
::
Tensor
>&
cumulative_local_expert_recv_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
async
,
bool
return_recv_hook
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
...
...
@@ -1042,6 +1043,11 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
EP_HOST_ASSERT
(
x
.
size
(
0
)
==
topk_idx
.
size
(
0
)
and
x
.
size
(
0
)
<=
num_max_dispatch_tokens_per_rank
);
EP_HOST_ASSERT
(
topk_idx
.
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
num_experts
%
num_ranks
==
0
);
if
(
cumulative_local_expert_recv_stats
.
has_value
())
{
EP_HOST_ASSERT
(
cumulative_local_expert_recv_stats
->
scalar_type
()
==
torch
::
kInt
);
EP_HOST_ASSERT
(
cumulative_local_expert_recv_stats
->
dim
()
==
1
and
cumulative_local_expert_recv_stats
->
is_contiguous
());
EP_HOST_ASSERT
(
cumulative_local_expert_recv_stats
->
size
(
0
)
==
num_experts
/
num_ranks
);
}
auto
num_tokens
=
static_cast
<
int
>
(
x
.
size
(
0
)),
hidden
=
static_cast
<
int
>
(
x
.
size
(
1
));
auto
num_scales
=
hidden
/
128
,
num_topk
=
static_cast
<
int
>
(
topk_idx
.
size
(
1
));
...
...
@@ -1084,6 +1090,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
internode_ll
::
dispatch
(
packed_recv_x
.
data_ptr
(),
packed_recv_x_scales_ptr
,
packed_recv_src_info
.
data_ptr
<
int
>
(),
packed_recv_layout_range
.
data_ptr
<
int64_t
>
(),
packed_recv_count
.
data_ptr
<
int
>
(),
cumulative_local_expert_recv_stats
.
has_value
()
?
cumulative_local_expert_recv_stats
->
data_ptr
<
int
>
()
:
nullptr
,
buffer
.
dispatch_rdma_recv_data_buffer
,
buffer
.
dispatch_rdma_recv_count_buffer
,
buffer
.
dispatch_rdma_send_buffer
,
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
...
...
csrc/deep_ep.hpp
View file @
5a2e37fa
...
...
@@ -142,6 +142,7 @@ public:
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
std
::
optional
<
torch
::
Tensor
>&
cumulative_local_expert_recv_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
async
,
bool
return_recv_hook
);
...
...
csrc/kernels/api.cuh
View file @
5a2e37fa
...
...
@@ -133,6 +133,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
cumulative_local_expert_recv_stats
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
next_clean
,
int
num_next_clean_int
,
...
...
csrc/kernels/internode_ll.cu
View file @
5a2e37fa
...
...
@@ -41,6 +41,7 @@ __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
cumulative_local_expert_recv_stats
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
atomic_counter_per_expert
,
int
*
atomic_finish_counter_per_expert
,
...
...
@@ -273,6 +274,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
shared_num_recv_tokens
[
warp_group_id
]
=
num_recv_tokens
;
shared_recv_token_begin_idx
[
warp_group_id
]
=
recv_token_begin_idx
;
recv_range
[
src_rank
]
=
pack2
<
int
,
int64_t
>
(
num_recv_tokens
,
recv_token_begin_idx
);
if
(
cumulative_local_expert_recv_stats
!=
nullptr
)
atomicAdd
(
cumulative_local_expert_recv_stats
+
local_expert_idx
,
num_recv_tokens
);
}
asm
volatile
(
"bar.sync %0, %1;"
::
"r"
(
warp_group_id
+
2
),
"r"
(
kNumWarpsPerGroup
*
32
));
num_recv_tokens
=
shared_num_recv_tokens
[
warp_group_id
];
...
...
@@ -310,6 +313,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
cumulative_local_expert_recv_stats
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
next_clean
,
int
num_next_clean_int
,
...
...
@@ -338,6 +342,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_count, \
cumulative_local_expert_recv_stats, \
rdma_recv_x, rdma_recv_count, rdma_x, \
x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
...
...
deep_ep/buffer.py
View file @
5a2e37fa
...
...
@@ -473,6 +473,7 @@ class Buffer:
# noinspection PyTypeChecker
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
cumulative_local_expert_recv_stats
:
Optional
[
torch
.
Tensor
]
=
None
,
use_fp8
:
bool
=
True
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
"""
...
...
@@ -481,7 +482,7 @@ class Buffer:
(specifically, IBGDA must be enabled).
Even for ranks in the same node, NVLink is fully disabled for simplicity.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
low-latency kernels' result tensor at a single moment.
low-latency kernels' result tensors at a single moment.
Arguments:
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
...
...
@@ -490,6 +491,9 @@ class Buffer:
are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts.
cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape
`[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
monitoring.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
...
...
@@ -508,19 +512,21 @@ class Buffer:
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (syncing would also be incompatible with CUDA graph).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
expert receive. As mentioned before, not all tokens are valid in `recv_x`.
expert receives. As mentioned before, not all tokens are valid in `recv_x`.
handle: the communication handle to be used in the `low_latency_combine` function.
event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
"""
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
cumulative_local_expert_recv_stats
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
use_fp8
,
async_finish
,
return_recv_hook
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
tensors_to_record
=
(
x
,
topk_idx
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
)
packed_recv_src_info
,
packed_recv_layout_range
,
cumulative_local_expert_recv_stats
)
return
(
packed_recv_x
,
packed_recv_x_scales
)
if
use_fp8
else
packed_recv_x
,
packed_recv_count
,
handle
,
\
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
...
...
tests/test_low_latency.py
View file @
5a2e37fa
...
...
@@ -36,8 +36,10 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
for
dispatch_use_fp8
in
(
False
,
True
):
num_times
+=
1
for
i
in
range
((
num_times
%
2
)
+
1
):
cumulative_local_expert_recv_stats
=
torch
.
zeros
((
num_local_experts
,
),
dtype
=
torch
.
int
,
device
=
'cuda'
)
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_fp8
,
cumulative_local_expert_recv_stats
=
cumulative_local_expert_recv_stats
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
...
...
@@ -53,6 +55,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# Check expert indices
int_mask
=
(
2
**
32
)
-
1
num_valid_tokens
=
recv_count
.
item
()
assert
cumulative_local_expert_recv_stats
[
i
].
item
()
==
num_valid_tokens
,
f
'
{
cumulative_local_expert_recv_stats
[
i
].
item
()
}
!=
{
num_valid_tokens
}
'
assert
num_valid_tokens
==
(
recv_layout_range
&
int_mask
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
recv_layout_range
&
int_mask
}
.sum().item()'
assert
num_valid_tokens
==
(
all_topk_idx
==
expert_id
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
(
all_topk_idx
==
expert_id
).
sum
().
item
()
}
'
...
...
@@ -108,6 +111,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
def
test_func
(
zero_copy
:
bool
,
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
cumulative_local_expert_recv_stats
=
cumulative_local_expert_recv_stats
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
if
zero_copy
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment