Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
dcaf73e5
Commit
dcaf73e5
authored
Mar 18, 2025
by
Chenggang Zhao
Browse files
Support zero-copy for low-latency combine
parent
82dcf48f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
80 additions
and
28 deletions
+80
-28
csrc/config.hpp
csrc/config.hpp
+6
-1
csrc/deep_ep.cpp
csrc/deep_ep.cpp
+20
-3
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+5
-1
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+2
-1
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+6
-4
deep_ep/buffer.py
deep_ep/buffer.py
+22
-5
tests/test_low_latency.py
tests/test_low_latency.py
+19
-13
No files found.
csrc/config.hpp
View file @
dcaf73e5
...
...
@@ -102,6 +102,9 @@ struct LowLatencyBuffer {
void
*
combine_rdma_recv_data_buffer
=
nullptr
;
int
*
combine_rdma_recv_flag_buffer
=
nullptr
;
void
*
combine_rdma_send_buffer_data_start
=
nullptr
;
int
num_bytes_per_combine_msg
=
0
;
std
::
pair
<
int
*
,
int
>
clean_meta
()
{
EP_HOST_ASSERT
(
dispatch_rdma_recv_count_buffer
==
combine_rdma_recv_flag_buffer
);
return
{
dispatch_rdma_recv_count_buffer
,
num_clean_int
};
...
...
@@ -163,7 +166,9 @@ struct LowLatencyLayout {
advance
<
int
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
)
advance
<
int
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
+
sizeof
(
int4
)),
static_cast
<
int
>
(
num_bytes_per_combine_msg
)
};
}
}
...
...
csrc/deep_ep.cpp
View file @
dcaf73e5
...
...
@@ -1100,7 +1100,8 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
Buffer
::
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
async
,
bool
return_recv_hook
,
std
::
optional
<
torch
::
Tensor
>
out
)
{
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
const
std
::
optional
<
torch
::
Tensor
>&
out
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
// Tensor checks
...
...
@@ -1159,7 +1160,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
workspace
,
launch_stream
,
phases
);
workspace
,
launch_stream
,
phases
,
zero_copy
);
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
...
@@ -1182,6 +1184,20 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
return
{
combined_x
,
event
,
recv_hook
};
}
torch
::
Tensor
Buffer
::
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
{
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
dtype
=
torch
::
kBFloat16
;
auto
num_msg_elems
=
static_cast
<
int
>
(
buffer
.
num_bytes_per_combine_msg
/
elementSize
(
torch
::
kBFloat16
));
EP_HOST_ASSERT
(
buffer
.
num_bytes_per_combine_msg
%
elementSize
(
torch
::
kBFloat16
)
==
0
);
return
torch
::
from_blob
(
buffer
.
combine_rdma_send_buffer_data_start
,
{
num_experts
/
num_ranks
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
{
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_msg_elems
,
num_msg_elems
,
1
},
torch
::
TensorOptions
().
dtype
(
dtype
).
device
(
torch
::
kCUDA
));
}
}
// namespace deep_ep
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
...
...
@@ -1218,5 +1234,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"internode_combine"
,
&
deep_ep
::
Buffer
::
internode_combine
)
.
def
(
"clean_low_latency_buffer"
,
&
deep_ep
::
Buffer
::
clean_low_latency_buffer
)
.
def
(
"low_latency_dispatch"
,
&
deep_ep
::
Buffer
::
low_latency_dispatch
)
.
def
(
"low_latency_combine"
,
&
deep_ep
::
Buffer
::
low_latency_combine
);
.
def
(
"low_latency_combine"
,
&
deep_ep
::
Buffer
::
low_latency_combine
)
.
def
(
"get_next_low_latency_combine_buffer"
,
&
deep_ep
::
Buffer
::
get_next_low_latency_combine_buffer
);
}
csrc/deep_ep.hpp
View file @
dcaf73e5
...
...
@@ -143,7 +143,11 @@ public:
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
async
,
bool
return_recv_hook
,
std
::
optional
<
torch
::
Tensor
>
out
=
std
::
nullopt
);
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
const
std
::
optional
<
torch
::
Tensor
>&
out
=
std
::
nullopt
);
torch
::
Tensor
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
);
};
}
// namespace deep_ep
csrc/kernels/api.cuh
View file @
dcaf73e5
...
...
@@ -147,7 +147,8 @@ void combine(void* combined_x,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
void
*
workspace
,
cudaStream_t
stream
,
int
phases
);
void
*
workspace
,
cudaStream_t
stream
,
int
phases
,
bool
zero_copy
);
}
// namespace internode_ll
...
...
csrc/kernels/internode_ll.cu
View file @
dcaf73e5
...
...
@@ -353,7 +353,7 @@ combine(void* combined_x,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
phases
)
{
int
phases
,
bool
zero_copy
)
{
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
...
...
@@ -420,7 +420,8 @@ combine(void* combined_x,
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
}
else
{
const
auto
buf_int4_ptr
=
reinterpret_cast
<
int4
*>
(
buf_ptr
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
if
(
not
zero_copy
)
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
nvshmemi_ibgda_put_nbi_warp
(
dst_ptr
,
buf_ptr
,
hidden
*
sizeof
(
nv_bfloat16
),
dst_rank
,
local_expert_idx
,
lane_id
,
token_idx
-
offset
);
}
}
...
...
@@ -500,7 +501,8 @@ void combine(void* combined_x,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
void
*
workspace
,
cudaStream_t
stream
,
int
phases
)
{
void
*
workspace
,
cudaStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
constexpr
int
kNumWarpsPerGroup
=
10
;
constexpr
int
kNumWarpGroups
=
3
;
constexpr
int
kNumMaxTopk
=
9
;
...
...
@@ -524,7 +526,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \
num_combined_tokens, hidden, num_topk, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
phases); } break
phases
, zero_copy
); } break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
32
,
stream
);
SWITCH_HIDDEN
(
COMBINE_LAUNCH_CASE
);
...
...
deep_ep/buffer.py
View file @
dcaf73e5
...
...
@@ -488,7 +488,7 @@ class Buffer:
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
use_fp8
,
async_finish
,
return_recv_hook
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
num_experts
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
tensors_to_record
=
(
x
,
topk_idx
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
)
...
...
@@ -497,8 +497,8 @@ class Buffer:
# noinspection PyTypeChecker
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
handle
:
tuple
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
,
out
:
Optional
[
torch
.
Tensor
]
=
None
)
->
\
handle
:
tuple
,
zero_copy
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
,
out
:
Optional
[
torch
.
Tensor
]
=
None
)
->
\
Tuple
[
torch
.
Tensor
,
EventOverlap
,
Callable
]:
"""
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
...
...
@@ -517,6 +517,8 @@ class Buffer:
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function.
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
with `get_next_low_latency_combine_buffer`.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
...
...
@@ -528,9 +530,24 @@ class Buffer:
event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
"""
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
num_experts
=
handle
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
=
handle
combined_x
,
event
,
hook
=
self
.
runtime
.
low_latency_combine
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
async_finish
,
return_recv_hook
,
out
)
zero_copy
,
async_finish
,
return_recv_hook
,
out
)
tensors_to_record
=
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
)
return
combined_x
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
def
get_next_low_latency_combine_buffer
(
self
,
handle
:
object
):
"""
Get the raw registered RDMA buffer tensor for next low-latency combine, so that the next combine kernel can skip the copying.
Arguments:
handle: the communication handle given by the `dispatch` function.
Returns:
buffer: the raw RDMA low-latency buffer as a BF16 PyTorch tensor with shape
`[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`, you should fill this buffer
by yourself.
"""
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
=
handle
return
self
.
runtime
.
get_next_low_latency_combine_buffer
(
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
)
tests/test_low_latency.py
View file @
dcaf73e5
...
...
@@ -73,15 +73,19 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
# Check combine correctness
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
,
out
=
out
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
if
do_check
:
diff
=
calc_diff
(
x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
assert
diff
<
1e-5
,
f
'Error: diff=
{
diff
}
'
hash_value
^=
hash_tensor
(
combined_x
)
for
zero_copy
in
(
False
,
True
):
if
zero_copy
:
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
,
out
=
out
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
if
do_check
:
diff
=
calc_diff
(
x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
assert
diff
<
1e-5
,
f
'Error: diff=
{
diff
}
'
hash_value
^=
hash_tensor
(
combined_x
)
def
create_test_cast_with_outliers
(
num_outliers
):
tmp
=
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
...
...
@@ -101,13 +105,15 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
hook
()
# noinspection PyShadowingNames
def
test_func
(
return_recv_hook
):
def
test_func
(
zero_copy
:
bool
,
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
if
zero_copy
:
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
return_recv_hook
=
return_recv_hook
)
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
# Calculate bandwidth
...
...
@@ -119,14 +125,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
num_combine_comm_bytes
+=
num_bf16_bytes
*
num_selections
# Dispatch + combine testing
avg_t
,
min_t
,
max_t
=
bench
(
partial
(
test_func
,
return_recv_hook
=
False
))
avg_t
,
min_t
,
max_t
=
bench
(
partial
(
test_func
,
zero_copy
=
False
,
return_recv_hook
=
False
))
print
(
f
'[rank
{
rank
}
] Dispatch + combine bandwidth:
{
(
num_dispatch_comm_bytes
+
num_combine_comm_bytes
)
/
1e9
/
avg_t
:.
2
f
}
GB/s, '
f
'avg_t=
{
avg_t
*
1e6
:.
2
f
}
us, min_t=
{
min_t
*
1e6
:.
2
f
}
us, max_t=
{
max_t
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
# Separate profiling
for
return_recv_hook
in
(
False
,
True
):
group
.
barrier
()
dispatch_t
,
combine_t
=
bench_kineto
(
partial
(
test_func
,
return_recv_hook
=
return_recv_hook
),
dispatch_t
,
combine_t
=
bench_kineto
(
partial
(
test_func
,
zero_copy
=
True
,
return_recv_hook
=
return_recv_hook
),
kernel_names
=
(
'dispatch'
,
'combine'
),
barrier_comm_profiling
=
True
,
suppress_kineto_output
=
True
)
if
not
return_recv_hook
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment