Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
8a34a9bd
Commit
8a34a9bd
authored
Feb 10, 2026
by
lishen
Browse files
modify internode notify
parent
17d9c844
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
36 deletions
+35
-36
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+35
-36
No files found.
csrc/kernels/internode.cu
View file @
8a34a9bd
...
@@ -116,13 +116,10 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
...
@@ -116,13 +116,10 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
if
(
sm_id
==
0
)
{
if
(
sm_id
==
0
)
{
// Communication with others
// Communication with others
// Global barrier: the first warp do intra-node sync, the second warp do internode sync
// Global barrier: the first warp do intra-node sync, the second warp do internode sync
EP_DEVICE_ASSERT
(
num_warps
>
1
);
EP_DEVICE_ASSERT
(
kNumRDMARanks
<=
num_threads
);
if
(
thread_id
==
kWarpSize
)
if
(
thread_id
==
kWarpSize
)
dushmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
dushmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
__syncthreads
();
// Send numbers of tokens per rank/expert to RDMA ranks
// Send numbers of tokens per rank/expert to RDMA ranks
auto
rdma_buffer_ptr_int
=
reinterpret_cast
<
int
*>
(
rdma_buffer_ptr
);
auto
rdma_buffer_ptr_int
=
reinterpret_cast
<
int
*>
(
rdma_buffer_ptr
);
...
@@ -152,14 +149,25 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
...
@@ -152,14 +149,25 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
// Issue send
// Issue send
// TODO: more light fence or barrier or signaling
// TODO: more light fence or barrier or signaling
// TODO: overlap EP barrier and NVL cleaning
// TODO: overlap EP barrier and NVL cleaning
if
(
thread_id
<
kNumRDMARanks
)
{
for
(
int
i
=
warp_id
;
i
<
kNumRDMARanks
;
i
+=
num_warps
)
{
shmem_int_put_nbi
(
if
(
i
!=
rdma_rank
)
{
shmemx_int_put_nbi_warp
(
rdma_recv_num_tokens_mixed
.
recv_buffer
(
rdma_rank
),
rdma_recv_num_tokens_mixed
.
recv_buffer
(
rdma_rank
),
rdma_recv_num_tokens_mixed
.
send_buffer
(
thread_id
),
rdma_recv_num_tokens_mixed
.
send_buffer
(
i
),
NUM_MAX_NVL_PEERS
+
num_rdma_experts
+
1
,
NUM_MAX_NVL_PEERS
+
num_rdma_experts
+
1
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
thread_id
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
i
,
nvl_rank
));
}
else
{
UNROLLED_WARP_COPY
(
1
,
lane_id
,
NUM_MAX_NVL_PEERS
+
num_rdma_experts
+
1
,
rdma_recv_num_tokens_mixed
.
recv_buffer
(
rdma_rank
),
rdma_recv_num_tokens_mixed
.
send_buffer
(
i
),
ld_volatile_global
,
st_na_global
);
}
}
}
__syncthreads
();
__syncthreads
();
if
(
thread_id
==
0
)
if
(
thread_id
==
0
)
dushmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
dushmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
...
@@ -215,7 +223,6 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
...
@@ -215,7 +223,6 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
}
}
// Send numbers of tokens per rank/expert to NVL ranks
// Send numbers of tokens per rank/expert to NVL ranks
EP_DEVICE_ASSERT
(
NUM_MAX_NVL_PEERS
<=
num_threads
);
if
(
thread_id
<
NUM_MAX_NVL_PEERS
)
{
if
(
thread_id
<
NUM_MAX_NVL_PEERS
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
kNumRDMARanks
;
++
i
)
for
(
int
i
=
0
;
i
<
kNumRDMARanks
;
++
i
)
...
@@ -225,10 +232,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
...
@@ -225,10 +232,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
nvl_send_num_tokens_per_expert
.
buffer
(
nvl_rank
)[
i
]
=
nvl_send_num_tokens_per_expert
.
buffer
(
nvl_rank
)[
i
]
=
nvl_reduced_num_tokens_per_expert
[
thread_id
*
num_nvl_experts
+
i
];
nvl_reduced_num_tokens_per_expert
[
thread_id
*
num_nvl_experts
+
i
];
}
}
memory_fence
();
__syncthreads
();
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
__syncthreads
();
// Reduce number of tokens per rank/expert
// Reduce number of tokens per rank/expert
EP_DEVICE_ASSERT
(
num_nvl_experts
<=
num_threads
);
EP_DEVICE_ASSERT
(
num_nvl_experts
<=
num_threads
);
...
@@ -255,7 +259,6 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
...
@@ -255,7 +259,6 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
}
}
// Finally barrier
// Finally barrier
__syncthreads
();
if
(
thread_id
==
kWarpSize
)
if
(
thread_id
==
kWarpSize
)
dushmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
dushmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
...
@@ -355,12 +358,13 @@ void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mappe
...
@@ -355,12 +358,13 @@ void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mappe
auto
nvl_clean_meta
=
auto
nvl_clean_meta
=
get_nvl_clean_meta
(
hidden_int4
,
num_scales
,
num_topk
,
num_topk
,
num_rdma_ranks
,
get_nvl_clean_meta
(
hidden_int4
,
num_scales
,
num_topk
,
num_topk
,
num_rdma_ranks
,
NUM_MAX_NVL_PEERS
,
num_max_nvl_chunked_recv_tokens
,
num_channels
);
NUM_MAX_NVL_PEERS
,
num_max_nvl_chunked_recv_tokens
,
num_channels
);
EP_HOST_ASSERT
((
rdma_clean_meta
.
first
+
rdma_clean_meta
.
second
)
*
sizeof
(
int
)
<=
EP_HOST_ASSERT
((
rdma_clean_meta
.
first
+
rdma_clean_meta
.
second
)
*
sizeof
(
int
)
<=
num_rdma_bytes
);
num_rdma_bytes
);
EP_HOST_ASSERT
((
nvl_clean_meta
.
first
+
nvl_clean_meta
.
second
)
*
sizeof
(
int
)
<=
num_nvl_bytes
);
EP_HOST_ASSERT
((
nvl_clean_meta
.
first
+
nvl_clean_meta
.
second
)
*
sizeof
(
int
)
<=
num_nvl_bytes
);
EP_HOST_ASSERT
(
num_rdma_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_rdma_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_nvl_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_nvl_bytes
<
std
::
numeric_limits
<
int
>::
max
());
// add assert origin kernel
EP_HOST_ASSERT
(
num_rdma_ranks
<=
kNumThreads
);
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kNumThreads
,
"Assert NUM_MAX_NVL_PEERS <= kNumThreads"
);
// Launch kernel
// Launch kernel
SETUP_LAUNCH_CONFIG
(
1
+
num_rdma_ranks
,
kNumThreads
,
stream
);
SETUP_LAUNCH_CONFIG
(
1
+
num_rdma_ranks
,
kNumThreads
,
stream
);
...
@@ -1202,37 +1206,31 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
...
@@ -1202,37 +1206,31 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
// Using two SMs, which clean the RDMA/NVL buffer respectively
// Using two SMs, which clean the RDMA/NVL buffer respectively
if
(
sm_id
==
0
)
{
if
(
sm_id
==
0
)
{
// Barrier for RDMA
// Barrier for RDMA
if
(
thread_id
==
0
)
if
(
thread_id
==
kWarpSize
)
dushmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
dushmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
__syncthreads
();
// Barrier for NVL
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
// Clean
// Clean
RDMA buffer
auto
rdma_buffer_ptr_int
=
reinterpret_cast
<
int
*>
(
rdma_buffer_ptr
);
auto
rdma_buffer_ptr_int
=
reinterpret_cast
<
int
*>
(
rdma_buffer_ptr
);
for
(
int
i
=
thread_id
;
i
<
rdma_num_int_clean
;
i
+=
num_threads
)
for
(
int
i
=
thread_id
;
i
<
rdma_num_int_clean
;
i
+=
num_threads
)
rdma_buffer_ptr_int
[
rdma_clean_offset
+
i
]
=
0
;
rdma_buffer_ptr_int
[
rdma_clean_offset
+
i
]
=
0
;
shmem_fence
();
__syncthreads
();
// Barrier again
// Clean NVL buffer
if
(
thread_id
==
0
)
dushmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
}
else
if
(
sm_id
==
1
)
{
// Barrier for NVL
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
__syncthreads
();
// Clean
auto
nvl_buffer_ptr_int
=
reinterpret_cast
<
int
*>
(
buffer_ptrs
[
nvl_rank
]);
auto
nvl_buffer_ptr_int
=
reinterpret_cast
<
int
*>
(
buffer_ptrs
[
nvl_rank
]);
for
(
int
i
=
thread_id
;
i
<
nvl_num_int_clean
;
i
+=
num_threads
)
for
(
int
i
=
thread_id
;
i
<
nvl_num_int_clean
;
i
+=
num_threads
)
nvl_buffer_ptr_int
[
nvl_clean_offset
+
i
]
=
0
;
nvl_buffer_ptr_int
[
nvl_clean_offset
+
i
]
=
0
;
memory_fence
();
__syncthreads
();
__syncthreads
();
// Barrier again
if
(
thread_id
==
kWarpSize
)
dushmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
// Barrier again
// Barrier again
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
}
else
if
(
sm_id
==
2
)
{
}
else
if
(
sm_id
==
1
)
{
if
(
is_cached_dispatch
)
if
(
is_cached_dispatch
)
return
;
return
;
...
@@ -1265,10 +1263,11 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
...
@@ -1265,10 +1263,11 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
EP_DEVICE_ASSERT
(
rdma_channel_prefix_matrix
!=
nullptr
and
EP_DEVICE_ASSERT
(
rdma_channel_prefix_matrix
!=
nullptr
and
rdma_rank_prefix_sum
!=
nullptr
);
rdma_rank_prefix_sum
!=
nullptr
);
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"Too many NVL peers"
);
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"Too many NVL peers"
);
constexpr
int
num_clean_sms
=
2
;
if
(
lane_id
<
NUM_MAX_NVL_PEERS
and
warp_id
<
num_channels
)
{
if
(
lane_id
<
NUM_MAX_NVL_PEERS
and
warp_id
<
num_channels
)
{
for
(
int
dst_rdma_rank
=
sm_id
-
3
;
dst_rdma_rank
<
num_rdma_ranks
;
for
(
int
dst_rdma_rank
=
sm_id
-
num_clean_sms
;
dst_rdma_rank
<
num_rdma_ranks
;
dst_rdma_rank
+=
num_channels
*
2
-
3
)
{
dst_rdma_rank
+=
num_channels
*
2
-
num_clean_sms
)
{
// Iterate in reverse order
// Iterate in reverse order
int
token_start_idx
=
int
token_start_idx
=
warp_id
==
0
warp_id
==
0
...
@@ -1319,7 +1318,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
...
@@ -1319,7 +1318,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
num_nvl_bytes
);
num_nvl_bytes
);
EP_HOST_ASSERT
(
num_rdma_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_rdma_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_nvl_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_nvl_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_channels
*
2
>
3
);
EP_HOST_ASSERT
(
num_channels
*
2
>
2
);
// Launch kernel
// Launch kernel
auto
cached_notify_func
=
low_latency_mode
?
cached_notify
<
true
>
:
cached_notify
<
false
>
;
auto
cached_notify_func
=
low_latency_mode
?
cached_notify
<
true
>
:
cached_notify
<
false
>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment