Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
0d1a855d
Unverified
Commit
0d1a855d
authored
Jun 09, 2025
by
Chenggang Zhao
Committed by
GitHub
Jun 09, 2025
Browse files
Add low-latency kernel PCIe usage flag (#195)
* Add low-latency kernel usage flag
* Update comments
parent
564e3752
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
57 additions
and
13 deletions
+57
-13
csrc/deep_ep.cpp
csrc/deep_ep.cpp
+15
-2
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+6
-0
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+4
-3
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+18
-8
deep_ep/buffer.py
deep_ep/buffer.py
+13
-0
tests/test_low_latency.py
tests/test_low_latency.py
+1
-0
No files found.
csrc/deep_ep.cpp
View file @
0d1a855d
...
...
@@ -76,6 +76,13 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
CUDA_CHECK
(
cudaHostGetDevicePointer
(
&
moe_recv_rdma_counter_mapped
,
const_cast
<
int
*>
(
moe_recv_rdma_counter
),
0
));
*
moe_recv_rdma_counter
=
-
1
;
}
// Low-latency kernels' usage flag
if
(
low_latency_mode
)
{
CUDA_CHECK
(
cudaMallocHost
(
&
low_latency_usage_flag
,
sizeof
(
int
),
cudaHostAllocMapped
));
CUDA_CHECK
(
cudaHostGetDevicePointer
(
&
low_latency_usage_flag_mapped
,
const_cast
<
int
*>
(
low_latency_usage_flag
),
0
));
*
low_latency_usage_flag
=
0
;
}
}
Buffer
::~
Buffer
()
noexcept
(
false
)
{
...
...
@@ -997,6 +1004,11 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
return
{
combined_x
,
combined_topk_weights
,
event
};
}
uint64_t
Buffer
::
get_low_latency_usage_flag
()
const
{
EP_HOST_ASSERT
(
low_latency_usage_flag
!=
nullptr
);
return
reinterpret_cast
<
uint64_t
>
(
low_latency_usage_flag
);
}
void
Buffer
::
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
...
...
@@ -1078,7 +1090,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
workspace
,
launch_stream
,
phases
);
workspace
,
low_latency_usage_flag_mapped
,
launch_stream
,
phases
);
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
...
@@ -1165,7 +1177,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
workspace
,
launch_stream
,
workspace
,
low_latency_usage_flag_mapped
,
launch_stream
,
phases
,
zero_copy
);
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
...
@@ -1238,6 +1250,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"intranode_combine"
,
&
deep_ep
::
Buffer
::
intranode_combine
)
.
def
(
"internode_dispatch"
,
&
deep_ep
::
Buffer
::
internode_dispatch
)
.
def
(
"internode_combine"
,
&
deep_ep
::
Buffer
::
internode_combine
)
.
def
(
"get_low_latency_usage_flag"
,
&
deep_ep
::
Buffer
::
get_low_latency_usage_flag
)
.
def
(
"clean_low_latency_buffer"
,
&
deep_ep
::
Buffer
::
clean_low_latency_buffer
)
.
def
(
"low_latency_dispatch"
,
&
deep_ep
::
Buffer
::
low_latency_dispatch
)
.
def
(
"low_latency_combine"
,
&
deep_ep
::
Buffer
::
low_latency_combine
)
...
...
csrc/deep_ep.hpp
View file @
0d1a855d
...
...
@@ -71,6 +71,10 @@ private:
volatile
int
*
moe_recv_rdma_counter
=
nullptr
;
int
*
moe_recv_rdma_counter_mapped
=
nullptr
;
// Host-side usage flag for the low-latency kernels, and its device-mapped pointer
volatile
int
*
low_latency_usage_flag
=
nullptr
;
int
*
low_latency_usage_flag_mapped
=
nullptr
;
private:
void
move_fifo_slots
(
int
num_slots
=
1
);
...
...
@@ -132,6 +136,8 @@ public:
const
torch
::
Tensor
&
combined_rdma_head
,
const
torch
::
Tensor
&
combined_nvl_head
,
const
Config
&
config
,
std
::
optional
<
EventHandle
>&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
);
uint64_t
get_low_latency_usage_flag
()
const
;
void
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
...
...
csrc/kernels/api.cuh
View file @
0d1a855d
...
...
@@ -138,7 +138,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
void
*
workspace
,
cudaStream_t
stream
,
int
phases
);
void
*
workspace
,
int
*
usage_flag
,
cudaStream_t
stream
,
int
phases
);
void
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
int
*
rdma_recv_flag
,
void
*
rdma_send_x
,
...
...
@@ -147,8 +148,8 @@ void combine(void* combined_x,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
void
*
workspace
,
cudaStream_t
stream
,
int
phases
,
bool
zero_copy
);
void
*
workspace
,
int
*
usage_flag
,
cudaStream_t
stream
,
int
phases
,
bool
zero_copy
);
}
// namespace internode_ll
...
...
csrc/kernels/internode_ll.cu
View file @
0d1a855d
...
...
@@ -47,7 +47,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
phases
)
{
int
*
usage_flag
,
int
phases
)
{
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
warp_id
=
thread_id
/
32
,
lane_id
=
get_lane_id
();
...
...
@@ -180,6 +180,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
#pragma unroll
for
(
int
i
=
lane_id
;
i
<
num_experts
;
i
+=
32
)
atomic_add_release_global
(
atomic_finish_counter_per_expert
+
i
,
FINISHED_SUM_TAG
);
}
else
if
(
sm_id
==
1
)
{
// The second SM is also responsible for notifying PCIe usage
if
(
lane_id
==
0
)
atomicAdd_system
(
usage_flag
,
1
);
}
// This SM should be responsible for some destination experts, read `topk_idx` for them
...
...
@@ -311,7 +315,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
void
*
workspace
,
cudaStream_t
stream
,
int
phases
)
{
void
*
workspace
,
int
*
usage_flag
,
cudaStream_t
stream
,
int
phases
)
{
constexpr
int
kNumMaxTopK
=
9
;
constexpr
int
kNumWarpsPerGroup
=
10
;
constexpr
int
kNumWarpGroups
=
3
;
...
...
@@ -338,7 +343,8 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, phases); } break
num_topk, num_experts, rank, num_ranks, \
usage_flag, phases); } break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
32
,
stream
);
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
...
...
@@ -356,7 +362,7 @@ combine(void* combined_x,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
phases
,
bool
zero_copy
)
{
int
*
usage_flag
,
int
phases
,
bool
zero_copy
)
{
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
...
...
@@ -451,11 +457,14 @@ combine(void* combined_x,
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
return
;
// Wait all ranks to arrive
// Wait for all ranks to arrive
and notify usage
if
(
responsible_expert_idx
<
num_experts
)
{
EP_STATIC_ASSERT
(
kNumWarpsPerGroup
>
1
,
"Invalid number of warps per group"
);
if
(
sub_warp_id
==
0
and
lane_id
==
0
)
if
(
sub_warp_id
==
0
and
lane_id
==
0
)
{
while
(
ld_acquire_sys_global
(
rdma_recv_flag
+
responsible_expert_idx
)
==
0
);
}
else
if
(
sm_id
==
0
and
sub_warp_id
==
1
and
lane_id
==
0
)
{
atomicAdd_system
(
usage_flag
,
1
);
}
}
cg
::
this_grid
().
sync
();
...
...
@@ -506,8 +515,8 @@ void combine(void* combined_x,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
void
*
workspace
,
cudaStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
void
*
workspace
,
int
*
usage_flag
,
cudaStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
constexpr
int
kNumWarpsPerGroup
=
10
;
constexpr
int
kNumWarpGroups
=
3
;
constexpr
int
kNumMaxTopk
=
9
;
...
...
@@ -531,6 +540,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \
num_combined_tokens, hidden, num_topk, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
usage_flag, \
phases, zero_copy); } break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
32
,
stream
);
...
...
deep_ep/buffer.py
View file @
0d1a855d
...
...
@@ -443,6 +443,19 @@ class Buffer:
async_finish
,
allocate_on_comm_stream
)
return
combined_x
,
combined_topk_weights
,
EventOverlap
(
event
)
def
get_low_latency_usage_flag
(
self
):
"""
Return the address of a host-side integer flag that indicates the stage of the low-latency kernels.
The flag is initially 0; the low-latency dispatch adds 1 before communication, and the low-latency combine
adds 1 after communication.
This is useful when there is no two-batch overlap, and you want to overlap H2D/D2H transfer with attention layers.
Returns:
flag: the address of the host-side flag, returned as a `uint64_t` integer. The flag itself is an `int`, so
`reinterpret_cast` the returned value to `int*` before reading it.
"""
return
self
.
runtime
.
get_low_latency_usage_flag
()
def
clean_low_latency_buffer
(
self
,
num_max_dispatch_tokens_per_rank
:
int
,
hidden
:
int
,
num_experts
:
int
)
->
None
:
"""
As low-latency kernels require part of the buffer to be zero-initialized, so it is vital to clean the buffer
...
...
tests/test_low_latency.py
View file @
0d1a855d
...
...
@@ -155,6 +155,7 @@ def test_loop(local_rank: int, num_local_ranks: int):
print
(
f
'Allocating buffer size:
{
num_rdma_bytes
/
1e6
}
MB ...'
,
flush
=
True
)
buffer
=
deep_ep
.
Buffer
(
group
,
num_rdma_bytes
=
num_rdma_bytes
,
low_latency_mode
=
True
,
num_qps_per_rank
=
num_experts
//
num_ranks
)
buffer
.
get_low_latency_usage_flag
()
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
1
)
do_pressure_test
=
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment