Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
f4b3020e
You need to sign in or sign up before continuing.
Commit
f4b3020e
authored
Dec 23, 2025
by
lishen
Browse files
支持zero_copy正确性
parent
7e8acdf7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
13 deletions
+16
-13
csrc/config.hpp
csrc/config.hpp
+11
-8
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+4
-4
tests/test_low_latency_new.py
tests/test_low_latency_new.py
+1
-1
No files found.
csrc/config.hpp
View file @
f4b3020e
...
@@ -150,6 +150,8 @@ struct LowLatencyLayout {
...
@@ -150,6 +150,8 @@ struct LowLatencyLayout {
size_t
num_bytes_per_dispatch_msg
=
size_t
num_bytes_per_dispatch_msg
=
sizeof
(
int4
)
+
sizeof
(
int4
)
+
std
::
max
(
hidden
*
sizeof
(
hip_bfloat16
),
hidden
+
num_scales
*
sizeof
(
float
));
std
::
max
(
hidden
*
sizeof
(
hip_bfloat16
),
hidden
+
num_scales
*
sizeof
(
float
));
// 与internode_ll::combine 中的 num_bytes_per_slot 相等
size_t
num_bytes_per_combine_msg
=
hidden
*
sizeof
(
hip_bfloat16
);
size_t
num_bytes_per_combine_msg
=
hidden
*
sizeof
(
hip_bfloat16
);
// Send buffer
// Send buffer
...
@@ -176,7 +178,8 @@ struct LowLatencyLayout {
...
@@ -176,7 +178,8 @@ struct LowLatencyLayout {
size_t
dispatch_recv_count_buffer_bytes
=
num_experts
*
sizeof
(
int64_t
);
size_t
dispatch_recv_count_buffer_bytes
=
num_experts
*
sizeof
(
int64_t
);
size_t
combine_recv_flag_buffer_bytes
=
dispatch_recv_count_buffer_bytes
;
size_t
combine_recv_flag_buffer_bytes
=
dispatch_recv_count_buffer_bytes
;
size_t
signaling_buffer_bytes
=
std
::
max
(
dispatch_recv_count_buffer_bytes
,
combine_recv_flag_buffer_bytes
);
size_t
signaling_buffer_bytes
=
std
::
max
(
dispatch_recv_count_buffer_bytes
,
combine_recv_flag_buffer_bytes
);
total_bytes
+=
signaling_buffer_bytes
*
2
;
size_t
signaling_buffer_bytes_aligned
=
ALIGN
<
size_t
>
(
signaling_buffer_bytes
,
128
);
total_bytes
+=
signaling_buffer_bytes_aligned
*
2
;
// Assign pointers
// Assign pointers
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
...
@@ -185,15 +188,15 @@ struct LowLatencyLayout {
...
@@ -185,15 +188,15 @@ struct LowLatencyLayout {
buffers
[
i
]
=
{
buffers
[
i
]
=
{
static_cast
<
int
>
(
signaling_buffer_bytes
/
sizeof
(
int64_t
)),
static_cast
<
int
>
(
signaling_buffer_bytes
/
sizeof
(
int64_t
)),
// dispatch:send_buffer + recv_buffer + recv_count
// dispatch:send_buffer + recv_buffer + recv_count
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
signaling_buffer_bytes
_aligned
*
i
),
// combine:send_buffer + recv_buffer + recv_count
// combine:send_buffer + recv_buffer + recv_count
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
signaling_buffer_bytes
_aligned
*
i
),
// combine_rdma_send_buffer_data_start
// combine_rdma_send_buffer_data_start
advance
(
rdma_buffer
,
s
end
_buffer_bytes
*
i
+
s
izeof
(
int4
)
),
advance
(
rdma_buffer
,
s
ignaling
_buffer_bytes
_aligned
*
2
+
s
end_buffer_bytes
*
i
),
//
//
num_bytes_per_combine_msg
num_bytes_per_combine_msg
};
};
...
...
csrc/kernels/internode_ll.cu
View file @
f4b3020e
...
@@ -572,7 +572,7 @@ combine(void* combined_x,
...
@@ -572,7 +572,7 @@ combine(void* combined_x,
// Message package
// Message package
EP_STATIC_ASSERT
(
kHidden
%
FP8_QUANTIZATION_NUM_PER_CHANNEL
==
0
,
"Invalid hidden"
);
EP_STATIC_ASSERT
(
kHidden
%
FP8_QUANTIZATION_NUM_PER_CHANNEL
==
0
,
"Invalid hidden"
);
constexpr
size_t
num_bytes_per_slot
=
sizeof
(
int4
)
+
kHidden
*
sizeof
(
hip_bfloat16
);
constexpr
size_t
num_bytes_per_slot
=
kHidden
*
sizeof
(
hip_bfloat16
);
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
// 16 is the max possible number of warps in AMD GPUs
// 16 is the max possible number of warps in AMD GPUs
...
@@ -627,12 +627,12 @@ combine(void* combined_x,
...
@@ -627,12 +627,12 @@ combine(void* combined_x,
for
(
int
token_idx
=
offset
+
sub_warp_id
;
token_idx
<
offset
+
num_tokens_to_send
;
token_idx
+=
num_warps_per_group
)
{
for
(
int
token_idx
=
offset
+
sub_warp_id
;
token_idx
<
offset
+
num_tokens_to_send
;
token_idx
+=
num_warps_per_group
)
{
const
auto
x_int4
=
local_x
+
token_idx
*
hidden_bf16_int4
;
const
auto
x_int4
=
local_x
+
token_idx
*
hidden_bf16_int4
;
const
auto
rdma_send_type_row
=
reinterpret_cast
<
int
*>
(
rdma_send_x_vec
+
token_idx
*
num_bytes_per_slot
);
const
auto
rdma_send_type_row
=
reinterpret_cast
<
int
*>
(
rdma_send_x_vec
+
token_idx
*
num_bytes_per_slot
);
const
auto
rdma_send_x_vec_row
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_type_row
+
4
);
const
auto
rdma_send_x_vec_row
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_type_row
);
// Copy directly to local rank, or copy to buffer and issue RDMA
// Copy directly to local rank, or copy to buffer and issue RDMA
const
auto
src_idx
=
__ldg
(
local_src_info
+
token_idx
);
const
auto
src_idx
=
__ldg
(
local_src_info
+
token_idx
);
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
+
sizeof
(
int4
)
;
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
;
if
(
dst_rank
==
rank
)
{
if
(
dst_rank
==
rank
)
{
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst_ptr
);
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst_ptr
);
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
...
@@ -750,7 +750,7 @@ LOW_LATENCY_COMBINE_RECV:
...
@@ -750,7 +750,7 @@ LOW_LATENCY_COMBINE_RECV:
// Read from sources
// Read from sources
auto
rdma_buffer_type
=
reinterpret_cast
<
const
int
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
auto
rdma_buffer_type
=
reinterpret_cast
<
const
int
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
(
reg_topk_idx
[
i
]
*
num_max_dispatch_tokens_per_rank
+
token_idx
)
*
num_bytes_per_slot
);
(
reg_topk_idx
[
i
]
*
num_max_dispatch_tokens_per_rank
+
token_idx
)
*
num_bytes_per_slot
);
auto
rdma_buffer_row
=
reinterpret_cast
<
const
uint8_t
*>
(
rdma_buffer_type
+
4
);
auto
rdma_buffer_row
=
reinterpret_cast
<
const
uint8_t
*>
(
rdma_buffer_type
);
// Reduce
// Reduce
auto
x_vec
=
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_buffer_row
)
+
thread_id
);
auto
x_vec
=
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_buffer_row
)
+
thread_id
);
...
...
tests/test_low_latency_new.py
View file @
f4b3020e
...
@@ -140,7 +140,7 @@ def test_main(num_tokens: int,
...
@@ -140,7 +140,7 @@ def test_main(num_tokens: int,
topk_weights
,
topk_weights
,
handle
,
handle
,
async_finish
=
not
return_recv_hook
,
async_finish
=
not
return_recv_hook
,
#
zero_copy=zero_copy,
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
,
return_recv_hook
=
return_recv_hook
,
out
=
out
)
out
=
out
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment