Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
1b00b9d8
Commit
1b00b9d8
authored
Dec 26, 2025
by
lishen
Browse files
Merge branch 'updates' into 'main'
Updates See merge request dcutoolkit/deeplearing/DeepEP!10
parents
7e8acdf7
4f828c59
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
53 additions
and
16 deletions
+53
-16
csrc/config.hpp
csrc/config.hpp
+11
-8
csrc/deep_ep.cu
csrc/deep_ep.cu
+9
-0
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+1
-0
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+1
-0
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+24
-6
deep_ep/buffer.py
deep_ep/buffer.py
+6
-1
tests/test_low_latency_new.py
tests/test_low_latency_new.py
+1
-1
No files found.
csrc/config.hpp
View file @
1b00b9d8
...
@@ -150,6 +150,8 @@ struct LowLatencyLayout {
...
@@ -150,6 +150,8 @@ struct LowLatencyLayout {
size_t
num_bytes_per_dispatch_msg
=
size_t
num_bytes_per_dispatch_msg
=
sizeof
(
int4
)
+
sizeof
(
int4
)
+
std
::
max
(
hidden
*
sizeof
(
hip_bfloat16
),
hidden
+
num_scales
*
sizeof
(
float
));
std
::
max
(
hidden
*
sizeof
(
hip_bfloat16
),
hidden
+
num_scales
*
sizeof
(
float
));
// 与internode_ll::combine 中的 num_bytes_per_slot 相等
size_t
num_bytes_per_combine_msg
=
hidden
*
sizeof
(
hip_bfloat16
);
size_t
num_bytes_per_combine_msg
=
hidden
*
sizeof
(
hip_bfloat16
);
// Send buffer
// Send buffer
...
@@ -176,7 +178,8 @@ struct LowLatencyLayout {
...
@@ -176,7 +178,8 @@ struct LowLatencyLayout {
size_t
dispatch_recv_count_buffer_bytes
=
num_experts
*
sizeof
(
int64_t
);
size_t
dispatch_recv_count_buffer_bytes
=
num_experts
*
sizeof
(
int64_t
);
size_t
combine_recv_flag_buffer_bytes
=
dispatch_recv_count_buffer_bytes
;
size_t
combine_recv_flag_buffer_bytes
=
dispatch_recv_count_buffer_bytes
;
size_t
signaling_buffer_bytes
=
std
::
max
(
dispatch_recv_count_buffer_bytes
,
combine_recv_flag_buffer_bytes
);
size_t
signaling_buffer_bytes
=
std
::
max
(
dispatch_recv_count_buffer_bytes
,
combine_recv_flag_buffer_bytes
);
total_bytes
+=
signaling_buffer_bytes
*
2
;
size_t
signaling_buffer_bytes_aligned
=
ALIGN
<
size_t
>
(
signaling_buffer_bytes
,
128
);
total_bytes
+=
signaling_buffer_bytes_aligned
*
2
;
// Assign pointers
// Assign pointers
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
...
@@ -185,15 +188,15 @@ struct LowLatencyLayout {
...
@@ -185,15 +188,15 @@ struct LowLatencyLayout {
buffers
[
i
]
=
{
buffers
[
i
]
=
{
static_cast
<
int
>
(
signaling_buffer_bytes
/
sizeof
(
int64_t
)),
static_cast
<
int
>
(
signaling_buffer_bytes
/
sizeof
(
int64_t
)),
// dispatch:send_buffer + recv_buffer + recv_count
// dispatch:send_buffer + recv_buffer + recv_count
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
signaling_buffer_bytes
_aligned
*
i
),
// combine:send_buffer + recv_buffer + recv_count
// combine:send_buffer + recv_buffer + recv_count
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
signaling_buffer_bytes
_aligned
*
i
),
// combine_rdma_send_buffer_data_start
// combine_rdma_send_buffer_data_start
advance
(
rdma_buffer
,
s
end
_buffer_bytes
*
i
+
s
izeof
(
int4
)
),
advance
(
rdma_buffer
,
s
ignaling
_buffer_bytes
_aligned
*
2
+
s
end_buffer_bytes
*
i
),
//
//
num_bytes_per_combine_msg
num_bytes_per_combine_msg
};
};
...
...
csrc/deep_ep.cu
View file @
1b00b9d8
...
@@ -1397,6 +1397,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1397,6 +1397,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
Buffer
::
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
std
::
optional
<
torch
::
Tensor
>&
combine_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
const
std
::
optional
<
torch
::
Tensor
>&
out
)
{
const
std
::
optional
<
torch
::
Tensor
>&
out
)
{
...
@@ -1418,6 +1419,13 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
...
@@ -1418,6 +1419,13 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
EP_HOST_ASSERT
(
layout_range
.
dim
()
==
2
and
layout_range
.
is_contiguous
());
EP_HOST_ASSERT
(
layout_range
.
dim
()
==
2
and
layout_range
.
is_contiguous
());
EP_HOST_ASSERT
(
layout_range
.
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
layout_range
.
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
layout_range
.
size
(
0
)
==
num_experts
/
num_ranks
and
layout_range
.
size
(
1
)
==
num_ranks
);
EP_HOST_ASSERT
(
layout_range
.
size
(
0
)
==
num_experts
/
num_ranks
and
layout_range
.
size
(
1
)
==
num_ranks
);
if
(
combine_wait_recv_cost_stats
.
has_value
())
{
EP_HOST_ASSERT
(
combine_wait_recv_cost_stats
->
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
combine_wait_recv_cost_stats
->
dim
()
==
1
and
combine_wait_recv_cost_stats
->
is_contiguous
());
EP_HOST_ASSERT
(
combine_wait_recv_cost_stats
->
size
(
0
)
==
num_ranks
);
}
auto
hidden
=
static_cast
<
int
>
(
x
.
size
(
2
));
auto
hidden
=
static_cast
<
int
>
(
x
.
size
(
2
));
auto
num_local_experts
=
num_experts
/
num_ranks
,
num_topk
=
static_cast
<
int
>
(
topk_weights
.
size
(
1
));
auto
num_local_experts
=
num_experts
/
num_ranks
,
num_topk
=
static_cast
<
int
>
(
topk_weights
.
size
(
1
));
auto
num_combined_tokens
=
static_cast
<
int
>
(
topk_weights
.
size
(
0
));
auto
num_combined_tokens
=
static_cast
<
int
>
(
topk_weights
.
size
(
0
));
...
@@ -1456,6 +1464,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
...
@@ -1456,6 +1464,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
topk_weights
.
data_ptr
<
float
>
(),
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
topk_weights
.
data_ptr
<
float
>
(),
src_info
.
data_ptr
<
int
>
(),
layout_range
.
data_ptr
<
int64_t
>
(),
src_info
.
data_ptr
<
int
>
(),
layout_range
.
data_ptr
<
int64_t
>
(),
global_atomic_counter
.
data_ptr
<
int
>
(),
global_atomic_counter
.
data_ptr
<
int
>
(),
combine_wait_recv_cost_stats
.
has_value
()
?
combine_wait_recv_cost_stats
->
data_ptr
<
int64_t
>
()
:
nullptr
,
next_clean_meta
.
first
,
next_clean_meta
.
second
,
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
...
...
csrc/deep_ep.hpp
View file @
1b00b9d8
...
@@ -183,6 +183,7 @@ public:
...
@@ -183,6 +183,7 @@ public:
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
std
::
optional
<
torch
::
Tensor
>&
combine_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
const
std
::
optional
<
torch
::
Tensor
>&
out
=
std
::
nullopt
);
const
std
::
optional
<
torch
::
Tensor
>&
out
=
std
::
nullopt
);
...
...
csrc/kernels/api.cuh
View file @
1b00b9d8
...
@@ -155,6 +155,7 @@ void combine(void* combined_x,
...
@@ -155,6 +155,7 @@ void combine(void* combined_x,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
int
*
global_atomic_counter
,
int
*
global_atomic_counter
,
int64_t
*
combine_wait_recv_cost_stats
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
...
...
csrc/kernels/internode_ll.cu
View file @
1b00b9d8
...
@@ -549,6 +549,7 @@ combine(void* combined_x,
...
@@ -549,6 +549,7 @@ combine(void* combined_x,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
int
*
global_atomic_counter
,
int
*
global_atomic_counter
,
int64_t
*
combine_wait_recv_cost_stats
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
*
atomic_clean_flag
,
int
*
atomic_clean_flag
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
...
@@ -572,7 +573,7 @@ combine(void* combined_x,
...
@@ -572,7 +573,7 @@ combine(void* combined_x,
// Message package
// Message package
EP_STATIC_ASSERT
(
kHidden
%
FP8_QUANTIZATION_NUM_PER_CHANNEL
==
0
,
"Invalid hidden"
);
EP_STATIC_ASSERT
(
kHidden
%
FP8_QUANTIZATION_NUM_PER_CHANNEL
==
0
,
"Invalid hidden"
);
constexpr
size_t
num_bytes_per_slot
=
sizeof
(
int4
)
+
kHidden
*
sizeof
(
hip_bfloat16
);
constexpr
size_t
num_bytes_per_slot
=
kHidden
*
sizeof
(
hip_bfloat16
);
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
// 16 is the max possible number of warps in AMD GPUs
// 16 is the max possible number of warps in AMD GPUs
...
@@ -627,12 +628,12 @@ combine(void* combined_x,
...
@@ -627,12 +628,12 @@ combine(void* combined_x,
for
(
int
token_idx
=
offset
+
sub_warp_id
;
token_idx
<
offset
+
num_tokens_to_send
;
token_idx
+=
num_warps_per_group
)
{
for
(
int
token_idx
=
offset
+
sub_warp_id
;
token_idx
<
offset
+
num_tokens_to_send
;
token_idx
+=
num_warps_per_group
)
{
const
auto
x_int4
=
local_x
+
token_idx
*
hidden_bf16_int4
;
const
auto
x_int4
=
local_x
+
token_idx
*
hidden_bf16_int4
;
const
auto
rdma_send_type_row
=
reinterpret_cast
<
int
*>
(
rdma_send_x_vec
+
token_idx
*
num_bytes_per_slot
);
const
auto
rdma_send_type_row
=
reinterpret_cast
<
int
*>
(
rdma_send_x_vec
+
token_idx
*
num_bytes_per_slot
);
const
auto
rdma_send_x_vec_row
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_type_row
+
4
);
const
auto
rdma_send_x_vec_row
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_type_row
);
// Copy directly to local rank, or copy to buffer and issue RDMA
// Copy directly to local rank, or copy to buffer and issue RDMA
const
auto
src_idx
=
__ldg
(
local_src_info
+
token_idx
);
const
auto
src_idx
=
__ldg
(
local_src_info
+
token_idx
);
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
+
sizeof
(
int4
)
;
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
;
if
(
dst_rank
==
rank
)
{
if
(
dst_rank
==
rank
)
{
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst_ptr
);
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst_ptr
);
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
...
@@ -724,8 +725,23 @@ LOW_LATENCY_COMBINE_RECV:
...
@@ -724,8 +725,23 @@ LOW_LATENCY_COMBINE_RECV:
// Wait all ranks to arrive and notify PCIe usage
// Wait all ranks to arrive and notify PCIe usage
if
(
responsible_expert_idx
<
num_experts
)
{
if
(
responsible_expert_idx
<
num_experts
)
{
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
);
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
);
if
(
sub_warp_id
==
0
and
lane_id
==
0
){
if
(
sub_warp_id
==
0
and
lane_id
==
0
)
{
while
(
ld_acquire_global
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
responsible_expert_idx
))
==
0
);
const
auto
src_rank
=
responsible_expert_idx
/
num_local_experts
;
auto
start_time
=
wall_clock64
();
uint64_t
wait_recv_cost
=
0
;
while
(
ld_acquire_global
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
responsible_expert_idx
))
==
0
// recv not ready
&&
(
wait_recv_cost
=
wall_clock64
()
-
start_time
)
<=
NUM_TIMEOUT_CYCLES
// not timeout
);
// Mask rank if timeout
if
(
wait_recv_cost
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"Warning: DeepEP timeout for combine receive, rank %d, local_expert_idx %d, src_rank %d
\n
"
,
rank
,
responsible_expert_idx
%
num_local_experts
,
src_rank
);
}
if
(
combine_wait_recv_cost_stats
!=
nullptr
)
{
atomicAdd
(
reinterpret_cast
<
unsigned
long
long
*>
(
combine_wait_recv_cost_stats
+
src_rank
),
wait_recv_cost
);
}
}
}
}
}
grid_barrier
(
global_atomic_counter
,
num_sms
);
grid_barrier
(
global_atomic_counter
,
num_sms
);
...
@@ -750,7 +766,7 @@ LOW_LATENCY_COMBINE_RECV:
...
@@ -750,7 +766,7 @@ LOW_LATENCY_COMBINE_RECV:
// Read from sources
// Read from sources
auto
rdma_buffer_type
=
reinterpret_cast
<
const
int
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
auto
rdma_buffer_type
=
reinterpret_cast
<
const
int
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
(
reg_topk_idx
[
i
]
*
num_max_dispatch_tokens_per_rank
+
token_idx
)
*
num_bytes_per_slot
);
(
reg_topk_idx
[
i
]
*
num_max_dispatch_tokens_per_rank
+
token_idx
)
*
num_bytes_per_slot
);
auto
rdma_buffer_row
=
reinterpret_cast
<
const
uint8_t
*>
(
rdma_buffer_type
+
4
);
auto
rdma_buffer_row
=
reinterpret_cast
<
const
uint8_t
*>
(
rdma_buffer_type
);
// Reduce
// Reduce
auto
x_vec
=
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_buffer_row
)
+
thread_id
);
auto
x_vec
=
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_buffer_row
)
+
thread_id
);
...
@@ -776,6 +792,7 @@ void combine(void* combined_x,
...
@@ -776,6 +792,7 @@ void combine(void* combined_x,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
int
*
global_atomic_counter
,
int
*
global_atomic_counter
,
int64_t
*
combine_wait_recv_cost_stats
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
...
@@ -803,6 +820,7 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
...
@@ -803,6 +820,7 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
x, topk_idx, topk_weights, src_info, layout_range, \
global_atomic_counter, \
global_atomic_counter, \
combine_wait_recv_cost_stats, \
next_clean, num_next_clean_int, \
next_clean, num_next_clean_int, \
atomic_clean_flag, \
atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, \
num_combined_tokens, hidden, num_topk, \
...
...
deep_ep/buffer.py
View file @
1b00b9d8
...
@@ -901,7 +901,8 @@ class Buffer:
...
@@ -901,7 +901,8 @@ class Buffer:
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
handle
:
tuple
,
zero_copy
:
bool
=
False
,
async_finish
:
bool
=
False
,
handle
:
tuple
,
zero_copy
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
,
out
:
Optional
[
torch
.
Tensor
]
=
None
)
->
\
return_recv_hook
:
bool
=
False
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
combine_wait_recv_cost_stats
:
Optional
[
torch
.
Tensor
]
=
None
)
->
\
Tuple
[
torch
.
Tensor
,
EventOverlap
,
Callable
]:
Tuple
[
torch
.
Tensor
,
EventOverlap
,
Callable
]:
"""
"""
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
...
@@ -927,6 +928,9 @@ class Buffer:
...
@@ -927,6 +928,9 @@ class Buffer:
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
This is useful for detecting and pre-cisely localizing slow anomalies.
Returns:
Returns:
combined_x: the reduced token tensor, with shape `[num_combined_tokens, num_topk]` and type `torch.bfloat16`.
combined_x: the reduced token tensor, with shape `[num_combined_tokens, num_topk]` and type `torch.bfloat16`.
...
@@ -935,6 +939,7 @@ class Buffer:
...
@@ -935,6 +939,7 @@ class Buffer:
"""
"""
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
=
handle
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
=
handle
combined_x
,
event
,
hook
=
self
.
runtime
.
low_latency_combine
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
,
event
,
hook
=
self
.
runtime
.
low_latency_combine
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combine_wait_recv_cost_stats
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
zero_copy
,
async_finish
,
return_recv_hook
,
out
)
zero_copy
,
async_finish
,
return_recv_hook
,
out
)
tensors_to_record
=
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
)
tensors_to_record
=
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
)
...
...
tests/test_low_latency_new.py
View file @
1b00b9d8
...
@@ -140,7 +140,7 @@ def test_main(num_tokens: int,
...
@@ -140,7 +140,7 @@ def test_main(num_tokens: int,
topk_weights
,
topk_weights
,
handle
,
handle
,
async_finish
=
not
return_recv_hook
,
async_finish
=
not
return_recv_hook
,
#
zero_copy=zero_copy,
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
,
return_recv_hook
=
return_recv_hook
,
out
=
out
)
out
=
out
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment