Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
f4b3020e
Commit
f4b3020e
authored
Dec 23, 2025
by
lishen
Browse files
支持zero_copy正确性
parent
7e8acdf7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
13 deletions
+16
-13
csrc/config.hpp
csrc/config.hpp
+11
-8
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+4
-4
tests/test_low_latency_new.py
tests/test_low_latency_new.py
+1
-1
No files found.
csrc/config.hpp
View file @
f4b3020e
...
@@ -150,6 +150,8 @@ struct LowLatencyLayout {
...
@@ -150,6 +150,8 @@ struct LowLatencyLayout {
size_t
num_bytes_per_dispatch_msg
=
size_t
num_bytes_per_dispatch_msg
=
sizeof
(
int4
)
+
sizeof
(
int4
)
+
std
::
max
(
hidden
*
sizeof
(
hip_bfloat16
),
hidden
+
num_scales
*
sizeof
(
float
));
std
::
max
(
hidden
*
sizeof
(
hip_bfloat16
),
hidden
+
num_scales
*
sizeof
(
float
));
// 与internode_ll::combine 中的 num_bytes_per_slot 相等
size_t
num_bytes_per_combine_msg
=
hidden
*
sizeof
(
hip_bfloat16
);
size_t
num_bytes_per_combine_msg
=
hidden
*
sizeof
(
hip_bfloat16
);
// Send buffer
// Send buffer
...
@@ -176,7 +178,8 @@ struct LowLatencyLayout {
...
@@ -176,7 +178,8 @@ struct LowLatencyLayout {
size_t
dispatch_recv_count_buffer_bytes
=
num_experts
*
sizeof
(
int64_t
);
size_t
dispatch_recv_count_buffer_bytes
=
num_experts
*
sizeof
(
int64_t
);
size_t
combine_recv_flag_buffer_bytes
=
dispatch_recv_count_buffer_bytes
;
size_t
combine_recv_flag_buffer_bytes
=
dispatch_recv_count_buffer_bytes
;
size_t
signaling_buffer_bytes
=
std
::
max
(
dispatch_recv_count_buffer_bytes
,
combine_recv_flag_buffer_bytes
);
size_t
signaling_buffer_bytes
=
std
::
max
(
dispatch_recv_count_buffer_bytes
,
combine_recv_flag_buffer_bytes
);
total_bytes
+=
signaling_buffer_bytes
*
2
;
size_t
signaling_buffer_bytes_aligned
=
ALIGN
<
size_t
>
(
signaling_buffer_bytes
,
128
);
total_bytes
+=
signaling_buffer_bytes_aligned
*
2
;
// Assign pointers
// Assign pointers
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
...
@@ -185,15 +188,15 @@ struct LowLatencyLayout {
...
@@ -185,15 +188,15 @@ struct LowLatencyLayout {
buffers
[
i
]
=
{
buffers
[
i
]
=
{
static_cast
<
int
>
(
signaling_buffer_bytes
/
sizeof
(
int64_t
)),
static_cast
<
int
>
(
signaling_buffer_bytes
/
sizeof
(
int64_t
)),
// dispatch:send_buffer + recv_buffer + recv_count
// dispatch:send_buffer + recv_buffer + recv_count
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
i
),
// combine:send_buffer + recv_buffer + recv_count
// combine:send_buffer + recv_buffer + recv_count
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
i
),
// combine_rdma_send_buffer_data_start
// combine_rdma_send_buffer_data_start
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
+
sizeof
(
int4
)
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
i
),
// num_bytes_per_combine_msg
// num_bytes_per_combine_msg
};
};
...
...
csrc/kernels/internode_ll.cu
View file @
f4b3020e
...
@@ -572,7 +572,7 @@ combine(void* combined_x,
...
@@ -572,7 +572,7 @@ combine(void* combined_x,
// Message package
// Message package
EP_STATIC_ASSERT
(
kHidden
%
FP8_QUANTIZATION_NUM_PER_CHANNEL
==
0
,
"Invalid hidden"
);
EP_STATIC_ASSERT
(
kHidden
%
FP8_QUANTIZATION_NUM_PER_CHANNEL
==
0
,
"Invalid hidden"
);
constexpr
size_t
num_bytes_per_slot
=
sizeof
(
int4
)
+
kHidden
*
sizeof
(
hip_bfloat16
);
constexpr
size_t
num_bytes_per_slot
=
kHidden
*
sizeof
(
hip_bfloat16
);
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
// 16 is the max possible number of warps in AMD GPUs
// 16 is the max possible number of warps in AMD GPUs
...
@@ -627,12 +627,12 @@ combine(void* combined_x,
...
@@ -627,12 +627,12 @@ combine(void* combined_x,
for
(
int
token_idx
=
offset
+
sub_warp_id
;
token_idx
<
offset
+
num_tokens_to_send
;
token_idx
+=
num_warps_per_group
)
{
for
(
int
token_idx
=
offset
+
sub_warp_id
;
token_idx
<
offset
+
num_tokens_to_send
;
token_idx
+=
num_warps_per_group
)
{
const
auto
x_int4
=
local_x
+
token_idx
*
hidden_bf16_int4
;
const
auto
x_int4
=
local_x
+
token_idx
*
hidden_bf16_int4
;
const
auto
rdma_send_type_row
=
reinterpret_cast
<
int
*>
(
rdma_send_x_vec
+
token_idx
*
num_bytes_per_slot
);
const
auto
rdma_send_type_row
=
reinterpret_cast
<
int
*>
(
rdma_send_x_vec
+
token_idx
*
num_bytes_per_slot
);
const
auto
rdma_send_x_vec_row
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_type_row
+
4
);
const
auto
rdma_send_x_vec_row
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_type_row
);
// Copy directly to local rank, or copy to buffer and issue RDMA
// Copy directly to local rank, or copy to buffer and issue RDMA
const
auto
src_idx
=
__ldg
(
local_src_info
+
token_idx
);
const
auto
src_idx
=
__ldg
(
local_src_info
+
token_idx
);
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
+
sizeof
(
int4
)
;
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
;
if
(
dst_rank
==
rank
)
{
if
(
dst_rank
==
rank
)
{
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst_ptr
);
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst_ptr
);
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
...
@@ -750,7 +750,7 @@ LOW_LATENCY_COMBINE_RECV:
...
@@ -750,7 +750,7 @@ LOW_LATENCY_COMBINE_RECV:
// Read from sources
// Read from sources
auto
rdma_buffer_type
=
reinterpret_cast
<
const
int
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
auto
rdma_buffer_type
=
reinterpret_cast
<
const
int
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
(
reg_topk_idx
[
i
]
*
num_max_dispatch_tokens_per_rank
+
token_idx
)
*
num_bytes_per_slot
);
(
reg_topk_idx
[
i
]
*
num_max_dispatch_tokens_per_rank
+
token_idx
)
*
num_bytes_per_slot
);
auto
rdma_buffer_row
=
reinterpret_cast
<
const
uint8_t
*>
(
rdma_buffer_type
+
4
);
auto
rdma_buffer_row
=
reinterpret_cast
<
const
uint8_t
*>
(
rdma_buffer_type
);
// Reduce
// Reduce
auto
x_vec
=
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_buffer_row
)
+
thread_id
);
auto
x_vec
=
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_buffer_row
)
+
thread_id
);
...
...
tests/test_low_latency_new.py
View file @
f4b3020e
...
@@ -140,7 +140,7 @@ def test_main(num_tokens: int,
...
@@ -140,7 +140,7 @@ def test_main(num_tokens: int,
topk_weights
,
topk_weights
,
handle
,
handle
,
async_finish
=
not
return_recv_hook
,
async_finish
=
not
return_recv_hook
,
# zero_copy=zero_copy,
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
,
return_recv_hook
=
return_recv_hook
,
out
=
out
)
out
=
out
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment