Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
6cc3497d
Commit
6cc3497d
authored
Mar 03, 2025
by
Chenggang Zhao
Browse files
Remove all raw tensors for better P2P overlapping
parent
f6030640
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
14 deletions
+18
-14
csrc/config.hpp
csrc/config.hpp
+1
-5
csrc/deep_ep.cpp
csrc/deep_ep.cpp
+2
-2
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+1
-0
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+14
-7
No files found.
csrc/config.hpp
View file @
6cc3497d
...
@@ -93,7 +93,6 @@ struct LowLatencyBuffer {
...
@@ -93,7 +93,6 @@ struct LowLatencyBuffer {
void
*
dispatch_rdma_send_buffer
=
nullptr
;
void
*
dispatch_rdma_send_buffer
=
nullptr
;
void
*
dispatch_rdma_recv_data_buffer
=
nullptr
;
void
*
dispatch_rdma_recv_data_buffer
=
nullptr
;
int
*
dispatch_rdma_recv_count_buffer
=
nullptr
;
int
*
dispatch_rdma_recv_count_buffer
=
nullptr
;
int
*
dispatch_rdma_atomic_token_counter
=
nullptr
;
void
*
combine_rdma_send_buffer
=
nullptr
;
void
*
combine_rdma_send_buffer
=
nullptr
;
void
*
combine_rdma_recv_data_buffer
=
nullptr
;
void
*
combine_rdma_recv_data_buffer
=
nullptr
;
...
@@ -145,10 +144,8 @@ struct LowLatencyLayout {
...
@@ -145,10 +144,8 @@ struct LowLatencyLayout {
// Symmetric signaling buffers
// Symmetric signaling buffers
size_t
dispatch_recv_count_buffer_bytes
=
num_experts
*
sizeof
(
int
);
size_t
dispatch_recv_count_buffer_bytes
=
num_experts
*
sizeof
(
int
);
size_t
dispatch_recv_atomic_token_counter_bytes
=
num_local_experts
*
sizeof
(
int
);
size_t
combine_recv_flag_buffer_bytes
=
dispatch_recv_count_buffer_bytes
;
size_t
combine_recv_flag_buffer_bytes
=
dispatch_recv_count_buffer_bytes
;
size_t
signaling_buffer_bytes
=
std
::
max
(
dispatch_recv_count_buffer_bytes
+
dispatch_recv_atomic_token_counter_bytes
,
size_t
signaling_buffer_bytes
=
std
::
max
(
dispatch_recv_count_buffer_bytes
,
combine_recv_flag_buffer_bytes
);
combine_recv_flag_buffer_bytes
);
total_bytes
+=
signaling_buffer_bytes
*
2
;
total_bytes
+=
signaling_buffer_bytes
*
2
;
// Assign pointers
// Assign pointers
...
@@ -160,7 +157,6 @@ struct LowLatencyLayout {
...
@@ -160,7 +157,6 @@ struct LowLatencyLayout {
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
<
int
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
<
int
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
+
dispatch_recv_count_buffer_bytes
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
)
advance
<
int
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
)
...
...
csrc/deep_ep.cpp
View file @
6cc3497d
...
@@ -1048,8 +1048,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1048,8 +1048,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
torch
::
kFloat8_e4m3fn
));
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
torch
::
kFloat8_e4m3fn
));
auto
packed_recv_src_info
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_src_info
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_layout_range
=
torch
::
empty
({
num_local_experts
,
num_ranks
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
));
auto
packed_recv_layout_range
=
torch
::
empty
({
num_local_experts
,
num_ranks
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
));
auto
packed_recv_count
=
torch
::
from_blob
(
buffer
.
dispatch_rdma_atomic_token_counter
,
auto
packed_recv_count
=
torch
::
empty
({
num_local_experts
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
{
num_local_experts
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
// Allocate column-majored scales
// Allocate column-majored scales
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
...
@@ -1061,6 +1060,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1061,6 +1060,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto
launcher
=
[
=
](
int
phases
)
{
auto
launcher
=
[
=
](
int
phases
)
{
internode_ll
::
dispatch
(
packed_recv_x
.
data_ptr
(),
packed_recv_x_scales
.
data_ptr
<
float
>
(),
internode_ll
::
dispatch
(
packed_recv_x
.
data_ptr
(),
packed_recv_x_scales
.
data_ptr
<
float
>
(),
packed_recv_src_info
.
data_ptr
<
int
>
(),
packed_recv_layout_range
.
data_ptr
<
int64_t
>
(),
packed_recv_src_info
.
data_ptr
<
int
>
(),
packed_recv_layout_range
.
data_ptr
<
int64_t
>
(),
packed_recv_count
.
data_ptr
<
int
>
(),
buffer
.
dispatch_rdma_recv_data_buffer
,
buffer
.
dispatch_rdma_recv_count_buffer
,
buffer
.
dispatch_rdma_recv_data_buffer
,
buffer
.
dispatch_rdma_recv_count_buffer
,
buffer
.
dispatch_rdma_send_buffer
,
buffer
.
dispatch_rdma_send_buffer
,
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
...
...
csrc/kernels/api.cuh
View file @
6cc3497d
...
@@ -132,6 +132,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
...
@@ -132,6 +132,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
next_clean
,
int
num_next_clean_int
,
int
*
next_clean
,
int
num_next_clean_int
,
...
...
csrc/kernels/internode_ll.cu
View file @
6cc3497d
...
@@ -40,9 +40,10 @@ template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
...
@@ -40,9 +40,10 @@ template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
__global__
__launch_bounds__
(
kNumWarpGroups
*
kNumWarpsPerGroup
*
32
,
1
)
void
__global__
__launch_bounds__
(
kNumWarpGroups
*
kNumWarpsPerGroup
*
32
,
1
)
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
atomic_counter_per_expert
,
int
*
atomic_finish_counter_per_expert
,
int
*
atomic_counter_per_local_expert
,
int
*
atomic_counter_per_expert
,
int
*
atomic_finish_counter_per_expert
,
int
*
next_clean
,
int
num_next_clean_int
,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
...
@@ -215,6 +216,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -215,6 +216,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Clean workspace for next use
// Clean workspace for next use
atomic_counter_per_expert
[
responsible_expert_idx
]
=
0
;
atomic_counter_per_expert
[
responsible_expert_idx
]
=
0
;
atomic_finish_counter_per_expert
[
responsible_expert_idx
]
=
0
;
atomic_finish_counter_per_expert
[
responsible_expert_idx
]
=
0
;
// Clean `packed_recv_count`
if
(
dst_rank
==
0
)
packed_recv_count
[
dst_expert_local_idx
]
=
0
;
}
}
__syncwarp
();
__syncwarp
();
...
@@ -223,6 +228,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -223,6 +228,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
return
;
return
;
// For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
if
(
phases
&
LOW_LATENCY_SEND_PHASE
)
cg
::
this_grid
().
sync
();
// Receiving and packing
// Receiving and packing
if
(
responsible_expert_idx
<
num_experts
)
{
if
(
responsible_expert_idx
<
num_experts
)
{
const
auto
src_rank
=
responsible_expert_idx
/
num_local_experts
;
const
auto
src_rank
=
responsible_expert_idx
/
num_local_experts
;
...
@@ -252,7 +261,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -252,7 +261,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
while
((
num_recv_tokens
=
ld_acquire_global
(
rdma_recv_count
+
local_expert_idx
*
num_ranks
+
src_rank
))
==
0
);
while
((
num_recv_tokens
=
ld_acquire_global
(
rdma_recv_count
+
local_expert_idx
*
num_ranks
+
src_rank
))
==
0
);
}
}
num_recv_tokens
=
-
num_recv_tokens
-
1
;
num_recv_tokens
=
-
num_recv_tokens
-
1
;
recv_token_begin_idx
=
atomicAdd
(
atomic_counter_per_local_expert
+
local_expert_idx
,
num_recv_tokens
);
recv_token_begin_idx
=
atomicAdd
(
packed_recv_count
+
local_expert_idx
,
num_recv_tokens
);
shared_num_recv_tokens
[
warp_group_id
]
=
num_recv_tokens
;
shared_num_recv_tokens
[
warp_group_id
]
=
num_recv_tokens
;
shared_recv_token_begin_idx
[
warp_group_id
]
=
recv_token_begin_idx
;
shared_recv_token_begin_idx
[
warp_group_id
]
=
recv_token_begin_idx
;
recv_range
[
src_rank
]
=
pack2
<
int
,
int64_t
>
(
num_recv_tokens
,
recv_token_begin_idx
);
recv_range
[
src_rank
]
=
pack2
<
int
,
int64_t
>
(
num_recv_tokens
,
recv_token_begin_idx
);
...
@@ -290,6 +299,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -290,6 +299,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
next_clean
,
int
num_next_clean_int
,
int
*
next_clean
,
int
num_next_clean_int
,
...
@@ -311,17 +321,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -311,17 +321,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
auto
atomic_finish_counter_per_expert
=
atomic_counter_per_expert
+
num_experts
;
auto
atomic_finish_counter_per_expert
=
atomic_counter_per_expert
+
num_experts
;
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
// Use the last part `rdma_recv_count` as `atomic_counter_per_local_expert`
// NOTES: this part will be cleaned in `combine`
auto
atomic_counter_per_local_expert
=
rdma_recv_count
+
num_ranks
*
(
num_experts
/
num_ranks
);
#define DISPATCH_LAUNCH_CASE(hidden) \
#define DISPATCH_LAUNCH_CASE(hidden) \
LAUNCH_KERNEL(&cfg, dispatch<kNumWarpGroups, kNumWarpsPerGroup, hidden>, \
LAUNCH_KERNEL(&cfg, dispatch<kNumWarpGroups, kNumWarpsPerGroup, hidden>, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_count, \
rdma_recv_x, rdma_recv_count, rdma_x, \
rdma_recv_x, rdma_recv_count, rdma_x, \
x, topk_idx, \
x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, atomic_counter_per_local_expert, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, phases); break
num_topk, num_experts, rank, num_ranks, phases); break
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment