Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
1a35d640
Commit
1a35d640
authored
May 21, 2026
by
root
Browse files
fix dtk26.04 4nodes core dump.
Signed-off-by:
root
<
root@host-10-212-17-3.cluster.local
>
parent
95e46992
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
936 additions
and
952 deletions
+936
-952
csrc/config.hpp
csrc/config.hpp
+2
-2
csrc/deep_ep.cu
csrc/deep_ep.cu
+4
-4
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+930
-946
No files found.
csrc/config.hpp
View file @
1a35d640
...
...
@@ -47,7 +47,7 @@ struct Config {
EP_HOST_ASSERT
(
num_ranks
<=
NUM_MAX_NVL_PEERS
or
num_sms
%
(
2
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
)
==
0
);
const
auto
num_rdma_ranks
=
std
::
max
(
num_ranks
/
NUM_MAX_NVL_PEERS
,
1
);
const
auto
num_nvl_ranks
=
std
::
min
(
num_ranks
,
NUM_MAX_NVL_PEERS
);
const
int
num_channels
=
num_
ranks
<=
8
?
num_sms
/
2
:
num_sms
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
const
int
num_channels
=
num_
sms
/
2
;
// 计算每个nvl通信数据包的数据量
size_t
num_single_nvl_bag_bytes
=
...
...
@@ -83,7 +83,7 @@ struct Config {
EP_HOST_ASSERT
(
num_ranks
%
NUM_MAX_NVL_PEERS
==
0
);
EP_HOST_ASSERT
(
num_sms
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
);
const
int
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
const
int
num_channels
=
num_sms
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
const
int
num_channels
=
num_sms
/
2
;
// 计算每个rdma通信数据包的数据量
size_t
num_single_rdma_bag_bytes
=
...
...
csrc/deep_ep.cu
View file @
1a35d640
...
...
@@ -809,8 +809,8 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
// here.
pybind11
::
gil_scoped_release
release
;
const
int
num_channels
=
config
.
num_sms
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
EP_HOST_ASSERT
(
config
.
num_sms
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
);
const
int
num_channels
=
config
.
num_sms
/
2
;
//
EP_HOST_ASSERT(config.num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0);
EP_HOST_ASSERT
(
0
<
get_num_rdma_ranks
()
and
get_num_rdma_ranks
()
<=
NUM_MAX_RDMA_PEERS
);
bool
cached_mode
=
cached_rdma_channel_prefix_matrix
.
has_value
();
...
...
@@ -1130,8 +1130,8 @@ Buffer::internode_combine(
const
torch
::
Tensor
&
combined_nvl_head
,
const
Config
&
config
,
std
::
optional
<
EventHandle
>
&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
)
{
#ifndef DISABLE_ROCSHMEM
const
int
num_channels
=
config
.
num_sms
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
EP_HOST_ASSERT
(
config
.
num_sms
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
);
const
int
num_channels
=
config
.
num_sms
/
2
;
//
EP_HOST_ASSERT(config.num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0);
// Shape and contiguous checks
EP_HOST_ASSERT
(
x
.
dim
()
==
2
and
x
.
is_contiguous
());
...
...
csrc/kernels/internode.cu
View file @
1a35d640
...
...
@@ -7,17 +7,6 @@
#ifndef DISABLE_ROCSHMEM
// 安全检查:确保宏已定义
#ifndef HIP_VERSION_PATCH
#error "HIP_VERSION_PATCH not defined! Check your HIP installation."
#endif
// TODO: fix unroll warnings
// #ifdef __clang__
// #pragma clang diagnostic push
// #pragma clang diagnostic ignored "-Wpass-failed"
// #pragma clang diagnostic ignored "-Wdeprecated-volatile"
// #endif // __clang__
namespace
deep_ep
{
namespace
internode
{
...
...
@@ -25,7 +14,7 @@ namespace internode {
extern
shmem_team_t
cpu_rdma_team
;
struct
SourceMeta
{
int
src_rdma_rank
,
is_token_in_nvl_rank_bits
;
// sizeof(SourceMeta) = 8
int
src_rdma_rank
,
is_token_in_nvl_rank_bits
;
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
==
8
,
"Invalid number of maximum NVL peers"
);
...
...
@@ -60,18 +49,16 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
get_rdma_clean_meta
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
int
num_rdma_ranks
,
int
num_rdma_recv_buffer_tokens
,
int
num_channels
)
{
int
num_rdma_ranks
,
int
num_rdma_recv_buffer_tokens
,
int
num_sms
)
{
// Return `int32_t` offset and count to clean
return
{(
get_num_bytes_per_rdma_token
(
hidden_int4
,
num_scales
,
num_topk_idx
,
num_topk_weights
)
*
num_rdma_recv_buffer_tokens
*
num_rdma_ranks
*
2
*
num_
channel
s
)
/
sizeof
(
int
),
(
NUM_MAX_NVL_PEERS
*
2
+
4
)
*
num_rdma_ranks
*
2
*
num_
channel
s
};
num_rdma_recv_buffer_tokens
*
num_rdma_ranks
*
2
*
num_
sm
s
)
/
sizeof
(
int
),
(
NUM_MAX_NVL_PEERS
*
2
+
4
)
*
num_rdma_ranks
*
2
*
num_
sm
s
};
}
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
get_nvl_clean_meta
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
int
num_rdma_ranks
,
int
num_nvl_ranks
,
int
num_nvl_recv_buffer_tokens
,
int
num_
channel
s
)
{
int
num_
sm
s
)
{
// Return `int32_t` offset and to clean
EP_STATIC_ASSERT
(
sizeof
(
SourceMeta
)
%
sizeof
(
int
)
==
0
,
"Invalid size of `SourceMeta`"
);
...
...
@@ -79,8 +66,8 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to
(
num_nvl_recv_buffer_tokens
*
(
hidden_int4
*
sizeof
(
int4
)
+
num_scales
*
sizeof
(
float
)
+
num_topk_idx
*
sizeof
(
int
)
+
num_topk_weights
*
sizeof
(
float
)
+
sizeof
(
SourceMeta
))
*
num_nvl_ranks
*
num_
channel
s
)
/
sizeof
(
int
),
num_nvl_ranks
*
(
2
*
num_rdma_ranks
+
2
)
*
num_
channel
s
,
num_nvl_ranks
*
num_
sm
s
)
/
sizeof
(
int
),
num_nvl_ranks
*
(
2
*
num_rdma_ranks
+
2
)
*
num_
sm
s
,
};
}
...
...
@@ -92,9 +79,10 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
template
<
bool
kLowLatencyMode
>
__forceinline__
__device__
void
du
shmem_
barrier
_with_same_gpu_idx
(
const
shmem_team_t
&
rdma_team
)
{
shmem_
sync
_with_same_gpu_idx
(
const
shmem_team_t
&
rdma_team
)
{
// NOTE: shmem_device_barrier_all() might be an issue as
// it doesn't follow OpenSHMEM specification on ROCm
// kLowLatencyMode ? shmem_device_sync(rdma_team) : shmem_device_sync_all();
kLowLatencyMode
?
shmem_barrier
(
rdma_team
)
:
shmem_device_barrier_all
();
}
...
...
@@ -123,7 +111,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
// Communication with others
// Global barrier: the first warp do intra-node sync, the second warp do internode sync
if
(
thread_id
==
kWarpSize
)
du
shmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
shmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
...
...
@@ -175,7 +163,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
__syncthreads
();
if
(
thread_id
==
0
)
du
shmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
shmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
__syncthreads
();
...
...
@@ -266,7 +254,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
// Finally barrier
if
(
thread_id
==
kWarpSize
)
du
shmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
shmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
}
else
{
...
...
@@ -383,14 +371,12 @@ constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {
return
num_rdma_ranks
<
8
?
num_rdma_ranks
:
8
;
}
template
<
bool
kLowLatencyMode
,
int
kNumRDMARanks
,
bool
kCachedMode
,
template
<
bool
kLowLatencyMode
,
int
kNumRDMARanks
,
bool
kCachedMode
,
int
kNumDispatchRDMASenderWarps
,
int
kNumTopkRDMARanks
=
get_num_topk_rdma_ranks
(
kNumRDMARanks
)>
__global__
void
__launch_bounds__
(((
1
+
NUM_MAX_NVL_PEERS
)
*
kWarpSize
),
1
)
dispatch
(
int4
*
recv_x
,
float
*
recv_x_scales
,
int64_t
*
recv_topk_idx
,
float
*
recv_topk_weights
,
__global__
void
__launch_bounds__
(((
kNumDispatchRDMASenderWarps
+
1
+
NUM_MAX_NVL_PEERS
)
*
kWarpSize
),
1
)
dispatch
(
int4
*
recv_x
,
float
*
recv_x_scales
,
int64_t
*
recv_topk_idx
,
float
*
recv_topk_weights
,
SourceMeta
*
recv_src_meta
,
const
int4
*
x
,
const
float
*
x_scales
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
int
*
send_rdma_head
,
int
*
send_nvl_head
,
int
*
recv_rdma_channel_prefix_matrix
,
...
...
@@ -403,739 +389,736 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
int
rank
,
int
num_ranks
)
{
enum
class
WarpRole
{
kRDMASender
,
// 从x写入到RDMA发送缓存
kRDMASenderCoordinator
,
// 从RDMA发送缓存写入到远端rdma_rank接收缓存
kRDMAAndNVLForwarder
,
// 从RDMA接收缓存转写到ipc nvl缓存
kForwarderCoordinator
,
// 向远端RDMA确认接收
kNVLReceivers
// 从nvl缓存写入到recv_x
kRDMASender
,
kRDMASenderCoordinator
,
kRDMAAndNVLForwarder
,
kForwarderCoordinator
,
kNVLReceivers
};
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
__shared__
shmem_ctx_t
ctx
;
shmem_wg_ctx_create
(
&
ctx
);
#endif
__shared__
rocshmem
::
rocshmem_ctx_t
ctx
;
rocshmem
::
rocshmem_wg_ctx_create
(
0
,
&
ctx
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
),
num_warps
=
num_threads
/
kWarpSize
;
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
),
warp_id
=
thread_id
/
kWarpSize
,
lane_id
=
get_lane_id
();
const
auto
num_channels
=
static_cast
<
int
>
(
gridDim
.
x
)
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
,
channel_id
=
sm_id
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
),
warp_id
=
thread_id
/
kWarpSize
,
lane_id
=
get_lane_id
();
const
auto
num_channels
=
static_cast
<
int
>
(
gridDim
.
x
)
/
2
,
channel_id
=
sm_id
/
2
;
const
bool
is_forwarder
=
sm_id
%
2
==
0
;
const
auto
rdma_rank
=
rank
/
NUM_MAX_NVL_PEERS
,
nvl_rank
=
rank
%
NUM_MAX_NVL_PEERS
;
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
*
sizeof
(
bool
)
==
sizeof
(
uint64_t
),
"Invalid number of NVL peers"
);
EP_DEVICE_ASSERT
(
num_warps
==
1
+
NUM_MAX_NVL_PEERS
);
const
auto
role_meta
=
[
=
]()
->
std
::
pair
<
WarpRole
,
int
>
{
if
(
sm_id
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
)
{
if
(
warp_id
<
kNumDispatchRDMASenderWarps
)
{
return
{
WarpRole
::
kRDMASender
,
-
1
};
}
else
if
(
warp_id
==
kNumDispatchRDMASenderWarps
)
{
return
{
WarpRole
::
kRDMASenderCoordinator
,
-
1
};
}
}
else
if
(
sm_id
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
1
)
{
if
(
warp_id
<
NUM_MAX_NVL_PEERS
)
{
if
(
is_forwarder
)
{
if
(
warp_id
<
NUM_MAX_NVL_PEERS
)
{
return
{
WarpRole
::
kRDMAAndNVLForwarder
,
(
warp_id
+
channel_id
)
%
NUM_MAX_NVL_PEERS
};
}
else
{
return
{
WarpRole
::
kForwarderCoordinator
,
warp_id
-
NUM_MAX_NVL_PEERS
};
}
}
else
if
(
warp_id
<
kNumDispatchRDMASenderWarps
)
{
return
{
WarpRole
::
kRDMASender
,
-
1
};
}
else
if
(
warp_id
==
kNumDispatchRDMASenderWarps
)
{
return
{
WarpRole
::
kRDMASenderCoordinator
,
-
1
};
}
else
{
return
{
WarpRole
::
kNVLReceivers
,
(
warp_id
+
channel_id
+
1
)
%
NUM_MAX_NVL_PEERS
};
return
{
WarpRole
::
kNVLReceivers
,
(
warp_id
+
channel_id
-
kNumDispatchRDMASenderWarps
)
%
NUM_MAX_NVL_PEERS
};
}
}();
auto
warp_role
=
role_meta
.
first
;
auto
target_rank
=
role_meta
.
second
;
// Not applicable for RDMA senders
EP_DEVICE_ASSERT
(
num_warps
==
kNumDispatchRDMASenderWarps
+
1
+
NUM_MAX_NVL_PEERS
);
// Data checks
EP_DEVICE_ASSERT
(
num_topk
<=
kWarpSize
);
// RDMA symmetric layout
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
*
sizeof
(
bool
)
==
sizeof
(
uint64_t
),
"Invalid number of NVL peers"
);
auto
hidden_bytes
=
hidden_int4
*
sizeof
(
int4
);
auto
num_bytes_per_rdma_token
=
get_num_bytes_per_rdma_token
(
hidden_int4
,
num_scales
,
num_topk
,
num_topk
);
auto
rdma_channel_data
=
SymBuffer
<
int8_t
>
(
rdma_buffer_ptr
,
num_max_rdma_chunked_recv_tokens
*
num_bytes_per_rdma_token
,
kNumRDMARanks
,
channel_id
,
num_channels
);
auto
rdma_channel_meta
=
SymBuffer
<
int
>
(
rdma_buffer_ptr
,
NUM_MAX_NVL_PEERS
*
2
+
2
,
kNumRDMARanks
,
channel_id
,
num_channels
);
auto
rdma_channel_head
=
SymBuffer
<
uint64_t
,
false
>
(
rdma_buffer_ptr
,
1
,
kNumRDMARanks
,
channel_id
,
num_channels
);
auto
rdma_channel_tail
=
SymBuffer
<
uint64_t
,
false
>
(
rdma_buffer_ptr
,
1
,
kNumRDMARanks
,
channel_id
,
num_channels
);
auto
num_bytes_per_rdma_token
=
get_num_bytes_per_rdma_token
(
hidden_int4
,
num_scales
,
num_topk
,
num_topk
);
auto
rdma_channel_data
=
SymBuffer
<
int8_t
>
(
rdma_buffer_ptr
,
num_max_rdma_chunked_recv_tokens
*
num_bytes_per_rdma_token
,
kNumRDMARanks
,
channel_id
,
num_channels
);
auto
rdma_channel_meta
=
SymBuffer
<
int
>
(
rdma_buffer_ptr
,
NUM_MAX_NVL_PEERS
*
2
+
2
,
kNumRDMARanks
,
channel_id
,
num_channels
);
auto
rdma_channel_head
=
SymBuffer
<
uint64_t
,
false
>
(
rdma_buffer_ptr
,
1
,
kNumRDMARanks
,
channel_id
,
num_channels
);
auto
rdma_channel_tail
=
SymBuffer
<
uint64_t
,
false
>
(
rdma_buffer_ptr
,
1
,
kNumRDMARanks
,
channel_id
,
num_channels
);
// NVL buffer layouts
// NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers"
// NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr`
// means "Write for Senders, Read for Receivers"
void
*
rs_wr_buffer_ptr
=
nullptr
,
*
ws_rr_buffer_ptr
=
nullptr
;
int
rs_wr_rank
=
0
,
ws_rr_rank
=
0
;
if
(
warp_role
==
WarpRole
::
kRDMAAndNVLForwarder
)
rs_wr_buffer_ptr
=
buffer_ptrs
[
nvl_rank
],
ws_rr_buffer_ptr
=
buffer_ptrs
[
target_rank
],
rs_wr_rank
=
nvl_rank
,
ws_rr_rank
=
target_rank
;
rs_wr_buffer_ptr
=
buffer_ptrs
[
nvl_rank
],
ws_rr_buffer_ptr
=
buffer_ptrs
[
target_rank
],
rs_wr_rank
=
nvl_rank
,
ws_rr_rank
=
target_rank
;
if
(
warp_role
==
WarpRole
::
kNVLReceivers
)
rs_wr_buffer_ptr
=
buffer_ptrs
[
target_rank
],
ws_rr_buffer_ptr
=
buffer_ptrs
[
nvl_rank
],
rs_wr_rank
=
target_rank
,
ws_rr_rank
=
nvl_rank
;
rs_wr_buffer_ptr
=
buffer_ptrs
[
target_rank
],
ws_rr_buffer_ptr
=
buffer_ptrs
[
nvl_rank
],
rs_wr_rank
=
target_rank
,
ws_rr_rank
=
nvl_rank
;
// Allocate buffers
auto
nvl_channel_x
=
AsymBuffer
<
int4
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
*
hidden_int4
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
).
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_src_meta
=
AsymBuffer
<
SourceMeta
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
).
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_x_scales
=
AsymBuffer
<
float
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
*
num_scales
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
).
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_topk_idx
=
AsymBuffer
<
int
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
*
num_topk
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
).
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_topk_weights
=
AsymBuffer
<
float
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
*
num_topk
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
).
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_prefix_start
=
AsymBuffer
<
int
>
(
ws_rr_buffer_ptr
,
kNumRDMARanks
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
).
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_prefix_end
=
AsymBuffer
<
int
>
(
ws_rr_buffer_ptr
,
kNumRDMARanks
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
).
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_head
=
AsymBuffer
<
int
>
(
rs_wr_buffer_ptr
,
1
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
ws_rr_rank
).
advance_also
(
ws_rr_buffer_ptr
);
auto
nvl_channel_tail
=
AsymBuffer
<
int
>
(
ws_rr_buffer_ptr
,
1
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
).
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_x
=
AsymBuffer
<
int4
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
*
hidden_int4
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
)
.
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_src_meta
=
AsymBuffer
<
SourceMeta
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
)
.
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_x_scales
=
AsymBuffer
<
float
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
*
num_scales
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
)
.
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_topk_idx
=
AsymBuffer
<
int
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
*
num_topk
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
)
.
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_topk_weights
=
AsymBuffer
<
float
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
*
num_topk
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
)
.
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_prefix_start
=
AsymBuffer
<
int
>
(
ws_rr_buffer_ptr
,
kNumRDMARanks
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
)
.
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_prefix_end
=
AsymBuffer
<
int
>
(
ws_rr_buffer_ptr
,
kNumRDMARanks
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
)
.
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_head
=
AsymBuffer
<
int
>
(
rs_wr_buffer_ptr
,
1
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
ws_rr_rank
)
.
advance_also
(
ws_rr_buffer_ptr
);
auto
nvl_channel_tail
=
AsymBuffer
<
int
>
(
ws_rr_buffer_ptr
,
1
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
)
.
advance_also
(
rs_wr_buffer_ptr
);
// RDMA sender warp synchronization
__shared__
volatile
int
rdma_send_next_token_idx
;
__shared__
volatile
int
rdma_send_channel_tail
[
kNumRDMARanks
];
__shared__
volatile
int
rdma_send_channel_next_tail
[
kNumRDMARanks
];
__shared__
volatile
int
rdma_sender_counter
[
1
];
__shared__
volatile
int
rdma_forwarder_counter
[
1
];
if
(
threadIdx
.
x
==
0
)
{
rdma_sender_counter
[
0
]
=
0
;
rdma_forwarder_counter
[
0
]
=
0
;
}
__syncthreads
();
// NVL and RDMA coordinate Forward warp synchronization
auto
sync_rdma_sender_smem
=
[
&
]()
{
if
(
lane_id
==
0
)
{
volatile
int
ret
=
__hip_atomic_fetch_add
(
&
rdma_sender_counter
[
0
],
1
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
// volatile int ret = atomicAdd((int*)&rdma_sender_counter[0], 1);
}
syncwarp
();
while
(
rdma_sender_counter
[
0
]
<
(
kNumDispatchRDMASenderWarps
+
1
))
{
}
};
// Forward warp synchronization
__shared__
volatile
int
forward_channel_head
[
NUM_MAX_NVL_PEERS
][
kNumRDMARanks
];
__shared__
volatile
bool
forward_channel_retired
[
NUM_MAX_NVL_PEERS
];
// NOTE: Not sure that __syncthreads() is a suitable replacement
auto
sync_forwarder_smem
=
[
&
]()
{
if
(
lane_id
==
0
)
{
volatile
int
ret
=
__hip_atomic_fetch_add
(
&
rdma_forwarder_counter
[
0
],
1
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
// volatile int ret = atomicAdd((int*)&rdma_forwarder_counter[0], 1);
}
syncwarp
();
while
(
rdma_forwarder_counter
[
0
]
<
(
NUM_MAX_NVL_PEERS
+
1
))
{
}
};
// Place the main logic of your kernel here, using the parameters above.
if
(
warp_role
==
WarpRole
::
kRDMASender
)
{
/*
这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
它首先获取当前通道的任务范围,然后清理共享内存,接着计算并发送本通道中的令牌数量。
然后,它遍历所有的令牌,读取每个令牌的RDMA秩的存在性,获取顺序锁,计算下一个尾部位置,存储RDMA头部,更新最后一个令牌尾部,释放顺序锁,并广播尾部位置。
最后,它复制相关的数据到对称发送缓冲区。
kRDMASender主要目的是将发送信息x, x_scale,source_meta, topk_idx, topk_weight等信息填充进入rdma发送缓存,
期间要同步warp直接对token的依序操作,以及和kForwarderCoordinator, kRDMASenderCoordinator内存同步。
同时在复制操作时, 使用ld.global.nc.L1::no_allocate.L2::256B, st.global.L1::no_allocate减少L1/L2缓存使用。
*/
// 获取任务范围
if
(
warp_role
==
WarpRole
::
kRDMASender
)
{
// Get tasks
int
token_start_idx
,
token_end_idx
;
get_channel_task_range
(
num_tokens
,
num_channels
,
channel_id
,
token_start_idx
,
token_end_idx
);
get_channel_task_range
(
num_tokens
,
num_channels
,
channel_id
,
token_start_idx
,
token_end_idx
);
// 清理共享内存
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"无效的RDMA秩数量"
);
if
(
warp_id
==
0
&&
lane_id
==
0
)
{
rdma_send_next_token_idx
=
token_start_idx
;
}
if
(
warp_id
==
0
&&
lane_id
<
kNumRDMARanks
)
{
rdma_send_channel_tail
[
lane_id
]
=
0
;
rdma_send_channel_next_tail
[
lane_id
]
=
0
;
}
// Clean shared memory
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"Invalid number of RDMA ranks"
);
(
warp_id
==
0
and
lane_id
==
0
)
?
(
rdma_send_next_token_idx
=
token_start_idx
)
:
0
;
(
warp_id
==
0
and
lane_id
<
kNumRDMARanks
)
?
(
rdma_send_channel_tail
[
lane_id
]
=
0
)
:
0
;
(
warp_id
==
0
and
lane_id
<
kNumRDMARanks
)
?
(
rdma_send_channel_next_tail
[
lane_id
]
=
0
)
:
0
;
// 发送本通道中的令牌数量,通过 `-value - 1` 表示
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
*
2
+
2
<=
kWarpSize
,
"无效的NVL对等体数量"
);
// 对于每个目标RDMA秩,以warp为单位进行迭代。计算发送缓冲区的值,并存储在rdma_channel_meta.send_buffer中
// 用于填充rdma_channel_meta.send_buffer本节点发送到远端rank, rdma_rank的起始index和结束index
for
(
int
dst_rdma_rank
=
warp_id
;
dst_rdma_rank
<
kNumRDMARanks
;
dst_rdma_rank
+=
kNumDispatchRDMASenderWarps
)
{
auto
dst_ptr
=
dst_rdma_rank
==
rdma_rank
?
rdma_channel_meta
.
recv_buffer
(
dst_rdma_rank
)
:
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
);
// Send number of tokens in this channel by `-value - 1`
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
*
2
+
2
<=
kWarpSize
,
"Invalid number of NVL peers"
);
for
(
int
dst_rdma_rank
=
warp_id
;
dst_rdma_rank
<
kNumRDMARanks
;
dst_rdma_rank
+=
kNumDispatchRDMASenderWarps
)
{
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
dst_ptr
[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
)
*
num_channels
+
channel_id
-
1
])
-
1
;
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
)
*
num_channels
+
channel_id
-
1
])
-
1
;
}
else
if
(
lane_id
<
NUM_MAX_NVL_PEERS
*
2
)
{
dst_ptr
[
lane_id
]
=
-
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
-
NUM_MAX_NVL_PEERS
)
*
num_channels
+
channel_id
]
-
1
;
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)[
lane_id
]
=
-
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
-
NUM_MAX_NVL_PEERS
)
*
num_channels
+
channel_id
]
-
1
;
}
else
if
(
lane_id
==
NUM_MAX_NVL_PEERS
*
2
)
{
dst_ptr
[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
-
1
])
-
1
;
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
-
1
])
-
1
;
}
else
if
(
lane_id
==
NUM_MAX_NVL_PEERS
*
2
+
1
)
{
dst_ptr
[
lane_id
]
=
-
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
]
-
1
;
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)[
lane_id
]
=
-
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
]
-
1
;
}
syncwarp
();
if
(
dst_rdma_rank
!=
rdma_rank
)
{
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_int_put_nbi_warp
(
ctx
,
#else
shmemx_int_put_nbi_warp
(
#endif
rdma_channel_meta
.
recv_buffer
(
rdma_rank
),
rocshmem
::
rocshmem_ctx_int_put_nbi_wave
(
ctx
,
rdma_channel_meta
.
recv_buffer
(
rdma_rank
),
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
),
NUM_MAX_NVL_PEERS
*
2
+
2
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
}
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet
(
ctx
);
#else
shmem_fence
();
#endif
rocshmem
::
rocshmem_ctx_quiet
(
ctx
);
sync_rdma_sender_smem
();
// sync_rdma_sender_smem();
__syncthreads
();
// 遍历令牌并复制到缓冲区
// Iterate over tokens and copy into buffer
int64_t
token_idx
;
int
cached_rdma_channel_head
=
0
,
last_rdma_tail_idx
=
-
1
;
auto
send_buffer
=
lane_id
==
rdma_rank
?
rdma_channel_data
.
recv_buffer
(
lane_id
)
:
rdma_channel_data
.
send_buffer
(
lane_id
);
for
(
token_idx
=
token_start_idx
+
warp_id
;
token_idx
<
token_end_idx
;
token_idx
+=
kNumDispatchRDMASenderWarps
)
{
// 读取RDMA秩的存在性
auto
send_buffer
=
lane_id
==
rdma_rank
?
rdma_channel_data
.
recv_buffer
(
lane_id
)
:
rdma_channel_data
.
send_buffer
(
lane_id
);
for
(
token_idx
=
token_start_idx
+
warp_id
;
token_idx
<
token_end_idx
;
token_idx
+=
kNumDispatchRDMASenderWarps
)
{
// Read RDMA rank existence
uint64_t
is_token_in_rank_uint64
=
0
;
if
(
lane_id
<
kNumRDMARanks
)
{
is_token_in_rank_uint64
=
*
reinterpret_cast
<
const
uint64_t
*>
(
is_token_in_rank
+
token_idx
*
num_ranks
+
lane_id
*
NUM_MAX_NVL_PEERS
);
}
if
(
lane_id
<
kNumRDMARanks
)
is_token_in_rank_uint64
=
*
reinterpret_cast
<
const
uint64_t
*>
(
is_token_in_rank
+
token_idx
*
num_ranks
+
lane_id
*
NUM_MAX_NVL_PEERS
);
// 获得处理数据的自旋锁,获得锁后才会处理一些数据信息
while
(
lane_id
==
0
&&
rdma_send_next_token_idx
!=
token_idx
)
{
// 等待
}
// Acquire sequential lock
while
(
lane_id
==
0
and
rdma_send_next_token_idx
!=
token_idx
)
;
syncwarp
();
//
获取下一个尾部位置
//
Acquire next tail
int
rdma_tail_idx
=
-
1
;
if
(
is_token_in_rank_uint64
!=
0
)
{
if
(
is_token_in_rank_uint64
!=
0
)
{
rdma_tail_idx
=
rdma_send_channel_next_tail
[
lane_id
]
++
;
// 与kForwarderCoordinator相互配合,调节发送数据的频率
while
(
rdma_tail_idx
-
cached_rdma_channel_head
>=
num_max_rdma_chunked_recv_tokens
)
{
cached_rdma_channel_head
=
static_cast
<
int
>
(
ld_volatile_global
(
rdma_channel_head
.
buffer
(
lane_id
)));
}
while
(
rdma_tail_idx
-
cached_rdma_channel_head
>=
num_max_rdma_chunked_recv_tokens
)
cached_rdma_channel_head
=
static_cast
<
int
>
(
ld_volatile_global
(
rdma_channel_head
.
buffer
(
lane_id
)));
}
syncwarp
();
//
存储RDMA头部以供合并
if
(
lane_id
<
kNumRDMARanks
&&
!
kCachedMode
)
{
//
Store RDMA head for combine
if
(
lane_id
<
kNumRDMARanks
and
not
kCachedMode
)
send_rdma_head
[
token_idx
*
kNumRDMARanks
+
lane_id
]
=
rdma_tail_idx
;
}
//
更新最后一个令牌尾部
if
(
last_rdma_tail_idx
>=
0
)
{
st_release_cta
(
const_cast
<
int
*>
(
rdma_send_channel_tail
+
lane_id
),
last_rdma_tail_idx
+
1
);
}
//
Update last token tail
if
(
last_rdma_tail_idx
>=
0
)
st_release_cta
(
const_cast
<
const
int
*>
(
rdma_send_channel_tail
+
lane_id
),
last_rdma_tail_idx
+
1
);
last_rdma_tail_idx
=
rdma_tail_idx
;
// 释放顺序锁
if
(
lane_id
==
0
)
{
rdma_send_next_token_idx
+=
1
;
}
// Release sequential lock
lane_id
==
0
?
(
rdma_send_next_token_idx
+=
1
)
:
0
;
//
广播尾部位置
//
Broadcast tails
SourceMeta
src_meta
;
int
num_topk_ranks
=
0
,
topk_ranks
[
kNumTopkRDMARanks
];
void
*
dst_send_buffers
[
kNumTopkRDMARanks
];
/*
该for循环主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作
*/
#pragma unroll
for
(
int
i
=
0
,
slot_idx
;
i
<
kNumRDMARanks
;
++
i
)
{
// 使用__shfl_sync函数在warp内同步并广播rdma_tail_idx的值
if
((
slot_idx
=
shfl_sync
(
rdma_tail_idx
,
i
))
>=
0
)
{
// warp 所有线程参与,rdma_tail_idx默认为-1, 只有对应rdma rank需要发送时, rdma_tail_idx才会>=0
// 计算slot_idx在接收缓冲区中的位置
void
*
dst_send_buffers
[
kNumTopkRDMARanks
];
#pragma unroll
for
(
int
i
=
0
,
slot_idx
;
i
<
kNumRDMARanks
;
++
i
)
if
((
slot_idx
=
shfl_sync
(
rdma_tail_idx
,
i
))
>=
0
)
{
slot_idx
=
slot_idx
%
num_max_rdma_chunked_recv_tokens
;
// 存储当前RDMA秩到topk_ranks数组中
topk_ranks
[
num_topk_ranks
]
=
i
;
// 广播is_token_in_rank_uint64的值到所有线程,并解释为布尔数组
auto
recv_is_token_in_rank_uint64
=
broadcast
(
is_token_in_rank_uint64
,
i
);
auto
recv_is_token_in_rank_values
=
reinterpret_cast
<
const
bool
*>
(
&
recv_is_token_in_rank_uint64
);
// 如果当前lane_id等于num_topk_ranks,则更新src_meta
if
(
lane_id
==
num_topk_ranks
)
{
auto
recv_is_token_in_rank_values
=
reinterpret_cast
<
const
bool
*>
(
&
recv_is_token_in_rank_uint64
);
if
(
lane_id
==
num_topk_ranks
)
src_meta
=
SourceMeta
(
rdma_rank
,
recv_is_token_in_rank_values
);
}
// 计算目标发送缓冲区的地址,并存储在dst_send_buffers数组中
// 获取到发送地址, num_topk_ranks-1 是需要发送的ranks数
dst_send_buffers
[
num_topk_ranks
++
]
=
reinterpret_cast
<
uint8_t
*>
(
broadcast
(
send_buffer
,
i
))
+
slot_idx
*
num_bytes_per_rdma_token
;
}
dst_send_buffers
[
num_topk_ranks
++
]
=
reinterpret_cast
<
uint8_t
*>
(
broadcast
(
send_buffer
,
i
))
+
slot_idx
*
num_bytes_per_rdma_token
;
}
EP_DEVICE_ASSERT
(
num_topk_ranks
<=
kNumTopkRDMARanks
);
//////////////// 复制数据到发送缓冲区 ////////////////
// 复制源元数据到对称发送缓冲区
if
(
lane_id
<
num_topk_ranks
)
{
st_na_global
(
reinterpret_cast
<
SourceMeta
*>
(
dst_send_buffers
[
lane_id
]),
src_meta
);
}
// Copy `x` into symmetric send buffer
auto
st_broadcast
=
[
=
](
const
int
key
,
const
int4
&
value
)
{
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
st_na_global
(
reinterpret_cast
<
int4
*>
(
dst_send_buffers
[
j
])
+
key
,
value
);
};
UNROLLED_WARP_COPY
(
5
,
lane_id
,
hidden_int4
,
0
,
x
+
token_idx
*
hidden_int4
,
ld_nc_global
,
st_broadcast
);
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
dst_send_buffers
[
i
]
=
reinterpret_cast
<
int4
*>
(
dst_send_buffers
[
i
])
+
hidden_int4
;
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
{
dst_send_buffers
[
i
]
=
reinterpret_cast
<
SourceMeta
*>
(
dst_send_buffers
[
i
])
+
1
;
}
// Copy source metadata into symmetric send buffer
if
(
lane_id
<
num_topk_ranks
)
st_na_global
(
reinterpret_cast
<
SourceMeta
*>
(
dst_send_buffers
[
lane_id
]),
src_meta
);
// 复制 `x` 到对称发送缓冲区
auto
st_broadcast
=
[
=
](
const
int
key
,
const
int4
&
value
)
{
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
{
st_na_global
(
reinterpret_cast
<
int4
*>
(
dst_send_buffers
[
j
])
+
key
,
value
);
}
};
UNROLLED_WARP_COPY
(
5
,
lane_id
,
hidden_int4
,
0
,
x
+
token_idx
*
hidden_int4
,
ld_nc_global
,
st_broadcast
);
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
{
dst_send_buffers
[
i
]
=
reinterpret_cast
<
int4
*>
(
dst_send_buffers
[
i
])
+
hidden_int4
;
}
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
dst_send_buffers
[
i
]
=
reinterpret_cast
<
SourceMeta
*>
(
dst_send_buffers
[
i
])
+
1
;
//
复制
`x_scales`
到对称发送缓冲区
for
(
int
i
=
lane_id
;
i
<
num_scales
;
i
+=
kWarpSize
)
{
//
Copy
`x_scales`
into symmetric send buffer
for
(
int
i
=
lane_id
;
i
<
num_scales
;
i
+=
kWarpSize
)
{
auto
value
=
ld_nc_global
(
x_scales
+
token_idx
*
num_scales
+
i
);
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
{
st_na_global
(
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
j
])
+
i
,
value
);
}
}
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
{
dst_send_buffers
[
i
]
=
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
i
])
+
num_scales
;
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
st_na_global
(
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
j
])
+
i
,
value
);
}
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
dst_send_buffers
[
i
]
=
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
i
])
+
num_scales
;
//
复制
`topk_idx`
和
`topk_weights`
到对称发送缓冲区
for
(
int
i
=
lane_id
;
i
<
num_topk
*
num_topk_ranks
;
i
+=
kWarpSize
)
{
//
Copy
`topk_idx`
and
`topk_weights`
into symmetric send buffer
for
(
int
i
=
lane_id
;
i
<
num_topk
*
num_topk_ranks
;
i
+=
kWarpSize
)
{
auto
rank_idx
=
i
/
num_topk
,
copy_idx
=
i
%
num_topk
;
auto
idx_value
=
static_cast
<
int
>
(
ld_nc_global
(
topk_idx
+
token_idx
*
num_topk
+
copy_idx
));
auto
idx_value
=
static_cast
<
int
>
(
ld_nc_global
(
topk_idx
+
token_idx
*
num_topk
+
copy_idx
));
auto
weight_value
=
ld_nc_global
(
topk_weights
+
token_idx
*
num_topk
+
copy_idx
);
st_na_global
(
reinterpret_cast
<
int
*>
(
dst_send_buffers
[
rank_idx
])
+
copy_idx
,
idx_value
);
st_na_global
(
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
rank_idx
])
+
num_topk
+
copy_idx
,
weight_value
);
st_na_global
(
reinterpret_cast
<
int
*>
(
dst_send_buffers
[
rank_idx
])
+
copy_idx
,
idx_value
);
st_na_global
(
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
rank_idx
])
+
num_topk
+
copy_idx
,
weight_value
);
}
}
// 结尾部分
// 获取顺序锁
while
(
lane_id
==
0
&&
rdma_send_next_token_idx
!=
token_idx
)
{
// 等待
}
// Epilogue
// Acquire sequential lock
while
(
lane_id
==
0
and
rdma_send_next_token_idx
!=
token_idx
)
;
syncwarp
();
//
更新最后一个令牌尾部
if
(
last_rdma_tail_idx
>=
0
)
{
st_release_cta
(
const_cast
<
int
*>
(
rdma_send_channel_tail
+
lane_id
),
last_rdma_tail_idx
+
1
);
}
//
Update last token tail
if
(
last_rdma_tail_idx
>=
0
)
st_release_cta
(
const_cast
<
const
int
*>
(
rdma_send_channel_tail
+
lane_id
),
last_rdma_tail_idx
+
1
);
// 释放顺序锁
if
(
lane_id
==
0
)
{
rdma_send_next_token_idx
+=
1
;
}
}
else
if
(
warp_role
==
WarpRole
::
kRDMASenderCoordinator
)
{
/*
这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
它首先计算每个RDMA秩需要发送的令牌数,然后在所有RDMA秩之间循环,检查是否有令牌需要发送。
如果有,它将计算本次需要发出的令牌数,并发出相应的RDMA发送请求。
最后,它更新相关的尾部位置,以便下次循环时可以正确地计算需要发送的令牌数。
// Release sequential lock
lane_id
==
0
?
(
rdma_send_next_token_idx
+=
1
)
:
0
;
}
else
if
(
warp_role
==
WarpRole
::
kRDMASenderCoordinator
)
{
// NOTES: in case of splitting the issued put at the end of the buffer
EP_DEVICE_ASSERT
(
num_max_rdma_chunked_recv_tokens
%
num_max_rdma_chunked_send_tokens
==
0
);
kRDMASenderCoordinator使用了同sm内存一致性(ld.acquire.cta.s32),
dushmem内存一致性(dushmem_fence)和原子操作(dushmemx_signal_op),减少硬同步,提升整体效率。
*/
if
(
warp_id
>
kNumDispatchRDMASenderWarps
)
{
return
;
}
// 确保最大接收令牌数可以被最大发送令牌数整除,以避免缓冲区分割问题
EP_DEVICE_ASSERT
(
num_max_rdma_chunked_recv_tokens
%
num_max_rdma_chunked_send_tokens
==
0
);
// Synchronize shared memory
sync_rdma_sender_smem
();
// 同步共享内存,确保所有线程在继续之前都达到了这一点
// sync_rdma_sender_smem();
__syncthreads
();
// 计算当前通道需要发送的令牌数
// Get number of tokens to send for each RDMA rank
int
num_tokens_to_send
=
0
;
if
(
lane_id
<
kNumRDMARanks
)
{
if
(
lane_id
<
kNumRDMARanks
)
{
num_tokens_to_send
=
rdma_channel_prefix_matrix
[
lane_id
*
num_channels
+
channel_id
];
if
(
channel_id
>
0
)
num_tokens_to_send
-=
rdma_channel_prefix_matrix
[
lane_id
*
num_channels
+
channel_id
-
1
];
if
(
channel_id
>
0
)
num_tokens_to_send
-=
rdma_channel_prefix_matrix
[
lane_id
*
num_channels
+
channel_id
-
1
];
}
//
记录上次发出的尾部位置
//
Iterate all RDMA ranks
int
last_issued_tail
=
0
;
// 当有任何RDMA秩需要发送令牌时,继续循环
while
(
__any_sync
(
kFullWarpMask
,
num_tokens_to_send
>
0
))
{
for
(
int
i
=
0
,
synced_num_tokens_to_send
;
i
<
kNumRDMARanks
;
++
i
)
{
// 计算目标RDMA秩
while
(
__any_sync
(
kFullWarpMask
,
num_tokens_to_send
>
0
))
{
for
(
int
i
=
0
,
synced_num_tokens_to_send
;
i
<
kNumRDMARanks
;
++
i
)
{
int
dst_rdma_rank
=
(
i
+
channel_id
)
%
kNumRDMARanks
;
// 获取同步后的需要发送的令牌数
synced_num_tokens_to_send
=
shfl_sync
(
num_tokens_to_send
,
dst_rdma_rank
);
if
(
synced_num_tokens_to_send
==
0
)
continue
;
if
(
synced_num_tokens_to_send
==
0
)
continue
;
// 如果没有令牌需要发送,则跳过
// 读取进度
// Read progress
auto
synced_last_issued_tail
=
shfl_sync
(
last_issued_tail
,
dst_rdma_rank
);
auto
processed_tail
=
ld_acquire_cta
(
const_cast
<
const
int
*>
(
rdma_send_channel_tail
+
dst_rdma_rank
));
auto
processed_tail
=
ld_acquire_cta
(
const_cast
<
const
int
*>
(
rdma_send_channel_tail
+
dst_rdma_rank
));
auto
num_tokens_processed
=
processed_tail
-
synced_last_issued_tail
;
// 如果处理的令牌数不等于需要发送的令牌数,并且处理的令牌数小于最大发送令牌数,则跳过
if
(
num_tokens_processed
!=
synced_num_tokens_to_send
&&
num_tokens_processed
<
num_max_rdma_chunked_send_tokens
)
if
(
num_tokens_processed
!=
synced_num_tokens_to_send
and
num_tokens_processed
<
num_max_rdma_chunked_send_tokens
)
continue
;
//
计算本次需要发出的令牌数
auto
num_tokens_to_issue
=
min
(
num_tokens_processed
,
num_max_rdma_chunked_send_tokens
);
EP_DEVICE_ASSERT
(
num_tokens_to_issue
>=
0
&&
num_tokens_to_issue
<=
synced_num_tokens_to_send
);
// 发出RDMA发送请求
if
(
dst_rdma_rank
!=
rdma_rank
)
{
//
Issue RDMA send
auto
num_tokens_to_issue
=
min
(
num_tokens_processed
,
num_max_rdma_chunked_send_tokens
);
EP_DEVICE_ASSERT
(
num_tokens_to_issue
>=
0
and
num_tokens_to_issue
<=
synced_num_tokens_to_send
);
if
(
dst_rdma_rank
!=
rdma_rank
)
{
auto
dst_slot_idx
=
synced_last_issued_tail
%
num_max_rdma_chunked_recv_tokens
;
EP_DEVICE_ASSERT
(
dst_slot_idx
+
num_tokens_to_issue
<=
num_max_rdma_chunked_recv_tokens
);
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_schar_put_nbi_warp
(
ctx
,
#else
shmemx_int8_put_nbi_warp
(
#endif
EP_DEVICE_ASSERT
(
dst_slot_idx
+
num_tokens_to_issue
<=
num_max_rdma_chunked_recv_tokens
);
rocshmem
::
rocshmem_ctx_schar_put_nbi_wave
(
ctx
,
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
,
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
,
num_bytes_per_rdma_token
*
num_tokens_to_issue
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet
(
ctx
);
#else
shmem_fence
();
#endif
rocshmem
::
rocshmem_ctx_quiet
(
ctx
);
}
else
{
//
对于本地RDMA秩,使用较轻的内存屏障
//
Lighter fence for local RDMA rank
memory_fence
();
}
//
更新尾部位置
//
Update tails
syncwarp
();
if
(
lane_id
==
dst_rdma_rank
)
{
if
(
lane_id
==
dst_rdma_rank
)
{
last_issued_tail
+=
num_tokens_to_issue
;
num_tokens_to_send
-=
num_tokens_to_issue
;
// 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add
(
ctx
,
#else
shmem_signal_op_add
(
#endif
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_tokens_to_issue
,
rocshmem
::
rocshmem_ctx_ulong_atomic_add
(
ctx
,
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_tokens_to_issue
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
}
}
// while(__any(num_tokens_to_send > 0))
}
else
if
(
warp_role
==
WarpRole
::
kRDMAAndNVLForwarder
)
{
/*
这段代码的主要功能是在一个CUDA内核中协调从RDMA消费者到NVL生产者的转发操作。
它首先计算目标NVL秩和目标秩,然后等待相关的计数器到达。
接着,它检查目标队列是否为空,或者等待一个缓冲区被释放。
然后,它找到下一个源RDMA秩,并遍历RDMA缓冲区中的每一个令牌,复制相关的数据到NVL缓冲区。
最后,它同步头部和尾部索引,并标记通道为退役状态。
*/
// RDMA消费者和NVL生产者
const
auto
dst_nvl_rank
=
target_rank
;
// 目标NVL秩
const
auto
dst_rank
=
rdma_rank
*
NUM_MAX_NVL_PEERS
+
dst_nvl_rank
;
// 目标秩
const
auto
dst_rank_expert_begin
=
dst_rank
*
(
num_experts
/
num_ranks
);
// 目标秩专家开始
const
auto
dst_rank_expert_end
=
dst_rank_expert_begin
+
(
num_experts
/
num_ranks
);
// 目标秩专家结束
// 等待计数器到达
}
}
else
if
(
warp_role
==
WarpRole
::
kRDMAAndNVLForwarder
)
{
// RDMA consumers and NVL producers
const
auto
dst_nvl_rank
=
target_rank
;
const
auto
dst_rank
=
rdma_rank
*
NUM_MAX_NVL_PEERS
+
dst_nvl_rank
;
const
auto
dst_rank_expert_begin
=
dst_rank
*
(
num_experts
/
num_ranks
);
const
auto
dst_rank_expert_end
=
dst_rank_expert_begin
+
(
num_experts
/
num_ranks
);
// Wait counters to arrive
int
num_tokens_to_recv_from_rdma
=
0
,
src_rdma_channel_prefix
=
0
;
EP_DEVICE_ASSERT
(
kNumRDMARanks
<=
kWarpSize
);
auto
start_time
=
wall_clock64
();
if
(
lane_id
<
kNumRDMARanks
)
{
while
(
true
)
{
// 对应于kRDMASender中的数据写入
auto
meta_0
=
ld_volatile_global
(
rdma_channel_meta
.
recv_buffer
(
lane_id
)
+
dst_nvl_rank
);
// 是nvl节点的起始地址
auto
meta_1
=
ld_volatile_global
(
rdma_channel_meta
.
recv_buffer
(
lane_id
)
+
NUM_MAX_NVL_PEERS
+
dst_nvl_rank
);
// nvl节点的结束地址
auto
meta_2
=
ld_volatile_global
(
rdma_channel_meta
.
recv_buffer
(
lane_id
)
+
NUM_MAX_NVL_PEERS
*
2
);
// 本rdma节点的起始地址
auto
meta_3
=
ld_volatile_global
(
rdma_channel_meta
.
recv_buffer
(
lane_id
)
+
NUM_MAX_NVL_PEERS
*
2
+
1
);
// 本节点的结束地址
if
(
meta_0
<
0
&&
meta_1
<
0
&&
meta_2
<
0
&&
meta_3
<
0
)
{
// 通知NVL秩
if
(
lane_id
<
kNumRDMARanks
)
{
while
(
true
)
{
auto
meta_0
=
ld_volatile_global
(
rdma_channel_meta
.
recv_buffer
(
lane_id
)
+
dst_nvl_rank
);
auto
meta_1
=
ld_volatile_global
(
rdma_channel_meta
.
recv_buffer
(
lane_id
)
+
NUM_MAX_NVL_PEERS
+
dst_nvl_rank
);
auto
meta_2
=
ld_volatile_global
(
rdma_channel_meta
.
recv_buffer
(
lane_id
)
+
NUM_MAX_NVL_PEERS
*
2
);
auto
meta_3
=
ld_volatile_global
(
rdma_channel_meta
.
recv_buffer
(
lane_id
)
+
NUM_MAX_NVL_PEERS
*
2
+
1
);
if
(
meta_0
<
0
and
meta_1
<
0
and
meta_2
<
0
and
meta_3
<
0
)
{
// Notify NVL ranks
int
start_sum
=
-
meta_0
-
1
,
end_sum
=
-
meta_1
-
1
;
EP_DEVICE_ASSERT
(
start_sum
>=
0
&&
end_sum
>=
0
&&
end_sum
>=
start_sum
);
st_relaxed_sys_global
(
nvl_channel_prefix_start
.
buffer
()
+
lane_id
,
-
start_sum
-
1
);
EP_DEVICE_ASSERT
(
start_sum
>=
0
and
end_sum
>=
0
and
end_sum
>=
start_sum
);
st_relaxed_sys_global
(
nvl_channel_prefix_start
.
buffer
()
+
lane_id
,
-
start_sum
-
1
);
st_relaxed_sys_global
(
nvl_channel_prefix_end
.
buffer
()
+
lane_id
,
-
end_sum
-
1
);
//
保存从RDMA通道接收的令牌计数
//
Save RDMA channel received token count
src_rdma_channel_prefix
=
-
meta_2
-
1
;
auto
src_rdma_channel_prefix_1
=
-
meta_3
-
1
;
num_tokens_to_recv_from_rdma
=
src_rdma_channel_prefix_1
-
src_rdma_channel_prefix
;
// 是远端 rdma_rank 会发送给当前节点的token数量
if
(
!
kCachedMode
)
recv_rdma_channel_prefix_matrix
[
lane_id
*
num_channels
+
channel_id
]
=
src_rdma_channel_prefix_1
;
src_rdma_channel_prefix
+=
lane_id
==
0
?
0
:
recv_rdma_rank_prefix_sum
[
lane_id
-
1
];
// 对应的远端 rdma_rank 的起始index, 存在线程0之中
num_tokens_to_recv_from_rdma
=
src_rdma_channel_prefix_1
-
src_rdma_channel_prefix
;
if
(
not
kCachedMode
)
recv_rdma_channel_prefix_matrix
[
lane_id
*
num_channels
+
channel_id
]
=
src_rdma_channel_prefix_1
;
src_rdma_channel_prefix
+=
lane_id
==
0
?
0
:
recv_rdma_rank_prefix_sum
[
lane_id
-
1
];
EP_DEVICE_ASSERT
(
num_tokens_to_recv_from_rdma
>=
0
);
break
;
}
// 超时检查
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
dst_nvl_rank
,
meta_0
,
meta_1
,
meta_2
,
meta_3
);
// Timeout check
long
long
int
elapsed_time
=
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, "
"nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
dst_nvl_rank
,
meta_0
,
meta_1
,
meta_2
,
meta_3
);
trap
();
}
}
}
syncwarp
();
// 移动缓存的头部
// Shift cached head
send_nvl_head
+=
src_rdma_channel_prefix
*
NUM_MAX_NVL_PEERS
+
dst_nvl_rank
;
// 等待共享内存被清理
// sync_forwarder_smem();
__syncthreads
();
// Wait shared memory to be cleaned
sync_forwarder_smem
();
// 开始准备处理接受数据,直到所有的数据接受完成。
// 转发从RDMA缓冲区的令牌
// 注意:总是从本地秩开始
// Forward tokens from RDMA buffer
// NOTES: always start from the local rank
int
src_rdma_rank
=
sm_id
%
kNumRDMARanks
;
int
cached_rdma_channel_head
=
0
,
cached_rdma_channel_tail
=
0
;
int
cached_nvl_channel_head
=
0
,
cached_nvl_channel_tail
=
0
,
rdma_nvl_token_idx
=
0
;
while
(
__any_sync
(
kFullWarpMask
,
num_tokens_to_recv_from_rdma
>
0
))
{
//
检查nvl目标队列是否为空,或者等待一个缓冲区被释放
while
(
__any_sync
(
kFullWarpMask
,
num_tokens_to_recv_from_rdma
>
0
))
{
//
Check destination queue emptiness, or wait a buffer to be released
start_time
=
wall_clock64
();
// 用于给kNVLReceivers进行互动,控制数据的传输速度
while
(
lane_id
==
0
)
{
while
(
lane_id
==
0
)
{
int
num_used_slots
=
cached_nvl_channel_tail
-
cached_nvl_channel_head
;
if
(
num_max_nvl_chunked_recv_tokens
-
num_used_slots
>=
num_max_nvl_chunked_send_tokens
)
if
(
num_max_nvl_chunked_recv_tokens
-
num_used_slots
>=
num_max_nvl_chunked_send_tokens
)
break
;
cached_nvl_channel_head
=
ld_volatile_global
(
nvl_channel_head
.
buffer
());
// 超时检查
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
dst_nvl_rank
,
ld_volatile_global
(
nvl_channel_head
.
buffer
()),
cached_nvl_channel_tail
);
// Timeout check
long
long
int
elapsed_time
=
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, "
"nvl: %d, dst NVL: %d, head: %d, tail: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
dst_nvl_rank
,
ld_volatile_global
(
nvl_channel_head
.
buffer
()),
cached_nvl_channel_tail
);
trap
();
}
}
syncwarp
();
//
找到下一个源RDMA秩(轮询)
//
Find next source RDMA rank (round-robin)
start_time
=
wall_clock64
();
while
(
true
)
{
while
(
true
)
{
src_rdma_rank
=
(
src_rdma_rank
+
1
)
%
kNumRDMARanks
;
if
(
shfl_sync
(
num_tokens_to_recv_from_rdma
,
src_rdma_rank
)
>
0
)
{
if
(
lane_id
==
src_rdma_rank
&&
cached_rdma_channel_head
==
cached_rdma_channel_tail
)
cached_rdma_channel_tail
=
static_cast
<
int
>
(
ld_acquire_sys_global
(
rdma_channel_tail
.
buffer
(
src_rdma_rank
)));
if
(
shfl_sync
(
cached_rdma_channel_tail
>
cached_rdma_channel_head
,
src_rdma_rank
))
{
if
(
shfl_sync
(
num_tokens_to_recv_from_rdma
,
src_rdma_rank
)
>
0
)
{
if
(
lane_id
==
src_rdma_rank
and
cached_rdma_channel_head
==
cached_rdma_channel_tail
)
cached_rdma_channel_tail
=
static_cast
<
int
>
(
ld_relaxed_sys_global
(
rdma_channel_tail
.
buffer
(
src_rdma_rank
)));
if
(
shfl_sync
(
cached_rdma_channel_tail
>
cached_rdma_channel_head
,
src_rdma_rank
))
break
;
}
}
// 超时检查
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
and
lane_id
<
kNumRDMARanks
)
{
printf
(
"DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
dst_nvl_rank
,
lane_id
,
cached_rdma_channel_head
,
cached_rdma_channel_tail
,
num_tokens_to_recv_from_rdma
);
// Timeout check
long
long
int
elapsed_time
=
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
and
lane_id
<
kNumRDMARanks
)
{
printf
(
"DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, "
"nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: "
"%d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
dst_nvl_rank
,
lane_id
,
cached_rdma_channel_head
,
cached_rdma_channel_tail
,
num_tokens_to_recv_from_rdma
);
trap
();
}
}
auto
src_rdma_head
=
shfl_sync
(
cached_rdma_channel_head
,
src_rdma_rank
);
auto
src_rdma_tail
=
shfl_sync
(
cached_rdma_channel_tail
,
src_rdma_rank
);
//
遍历RDMA缓冲区中的每一个令牌
for
(
int
i
=
src_rdma_head
,
num_tokens_sent
=
0
;
i
<
src_rdma_tail
;
++
i
)
{
//
Iterate over every token from the RDMA buffer
for
(
int
i
=
src_rdma_head
,
num_tokens_sent
=
0
;
i
<
src_rdma_tail
;
++
i
)
{
auto
rdma_slot_idx
=
i
%
num_max_rdma_chunked_recv_tokens
;
// 首先读取SourceMeta,对应到kRDMASenderCoordinator中 kRDMASender 的数据远程写入
void
*
shifted
=
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
;
auto
src_meta
=
ld_nc_global
(
reinterpret_cast
<
SourceMeta
*>
(
reinterpret_cast
<
int8_t
*>
(
shifted
)));
if
(
lane_id
==
src_rdma_rank
)
{
num_tokens_to_recv_from_rdma
-=
1
;
}
void
*
shifted
=
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
;
auto
src_meta
=
ld_nc_global
(
reinterpret_cast
<
SourceMeta
*>
(
reinterpret_cast
<
int8_t
*>
(
shifted
)
+
hidden_bytes
));
lane_id
==
src_rdma_rank
?
(
num_tokens_to_recv_from_rdma
-=
1
)
:
0
;
bool
is_in_dst_nvl_rank
=
src_meta
.
is_token_in_nvl_rank
(
dst_nvl_rank
);
if
(
lane_id
==
src_rdma_rank
)
{
if
(
lane_id
==
src_rdma_rank
)
{
auto
cached_head
=
is_in_dst_nvl_rank
?
rdma_nvl_token_idx
:
-
1
;
rdma_nvl_token_idx
+=
is_in_dst_nvl_rank
;
if
(
!
kCachedMode
)
if
(
not
kCachedMode
)
send_nvl_head
[
i
*
NUM_MAX_NVL_PEERS
]
=
cached_head
;
}
if
(
!
is_in_dst_nvl_rank
)
if
(
not
is_in_dst_nvl_rank
)
continue
;
//
获取一个空闲槽位
//
Get an empty slot
int
dst_slot_idx
=
(
cached_nvl_channel_tail
++
)
%
num_max_nvl_chunked_recv_tokens
;
//
设置 src和dst 位置
auto
src_gpu_buffer_x
=
reinterpret_cast
<
int4
*>
(
reinterpret_cast
<
int8_t
*>
(
shifted
)
+
sizeof
(
SourceMeta
));
auto
src_gpu_buffer_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
int8_t
*>
(
src_gpu_buffer_x
)
+
hidden_
bytes
);
auto
src_gpu_buffer_topk_idx
=
reinterpret_cast
<
int
*>
(
reinterpret_cast
<
int8_t
*>
(
src_gpu_buffer_scales
)
+
num_scales
*
sizeof
(
float
)
);
auto
src_gpu_buffer_topk_weights
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
int8_t
*>
(
src_gpu_buffer_topk_idx
)
+
num_topk
*
sizeof
(
int
))
;
//
Copy data
UNROLLED_WARP_COPY
(
5
,
lane_id
,
hidden_int4
,
nvl_channel_x
.
buffer
()
+
dst_slot_idx
*
hidden_
int4
,
reinterpret_cast
<
int
4
*>
(
shifted
),
ld_nc_global
,
st_na_global
);
shifted
=
reinterpret_cast
<
int4
*>
(
shifted
)
+
hidden_
int
4
;
auto
dst_gpu_buffer_x
=
nvl_channel_x
.
buffer
()
+
dst_slot_idx
*
hidden_int4
;
auto
dst_gpu_buffer_scales
=
nvl_channel_x_scales
.
buffer
()
+
dst_slot_idx
*
num_scales
;
auto
dst_gpu_buffer_topk_idx
=
nvl_channel_
topk_idx
.
buffer
()
+
dst_slot_idx
*
num_topk
;
auto
dst_gpu_buffer_topk_weights
=
nvl_channel_topk_weights
.
buffer
()
+
dst_slot_idx
*
num_topk
;
// Copy source meta
if
(
lane_id
==
0
)
st_na_global
(
nvl_channel_
src_meta
.
buffer
()
+
dst_slot_idx
,
src_meta
)
;
shifted
=
reinterpret_cast
<
SourceMeta
*>
(
shifted
)
+
1
;
if
(
lane_id
==
0
)
{
st_na_global
(
reinterpret_cast
<
int64_t
*>
(
nvl_channel_src_meta
.
buffer
()
+
dst_slot_idx
),
*
reinterpret_cast
<
int64_t
*>
(
&
src_meta
));
}
// Copy `x_scales`
UNROLLED_WARP_COPY
(
1
,
lane_id
,
num_scales
,
nvl_channel_x_scales
.
buffer
()
+
dst_slot_idx
*
num_scales
,
reinterpret_cast
<
float
*>
(
shifted
),
ld_nc_global
,
st_na_global
);
shifted
=
reinterpret_cast
<
float
*>
(
shifted
)
+
num_scales
;
UNROLLED_WARP_COPY
(
5
,
lane_id
,
hidden_int4
,
dst_gpu_buffer_x
,
src_gpu_buffer_x
,
ld_direct_global
,
st_na_global
);
// Copy `topk_idx` and `topk_weights`
// NOTES: do not use `shifted` after this `if`, because only several lanes are
// shifted
if
(
lane_id
<
num_topk
)
{
// Read
auto
idx_value
=
ld_nc_global
(
reinterpret_cast
<
int
*>
(
shifted
)
+
lane_id
);
shifted
=
reinterpret_cast
<
int
*>
(
shifted
)
+
num_topk
;
auto
weight_value
=
ld_nc_global
(
reinterpret_cast
<
float
*>
(
shifted
)
+
lane_id
);
// Transform and write
idx_value
=
(
idx_value
>=
dst_rank_expert_begin
and
idx_value
<
dst_rank_expert_end
)
?
idx_value
-
dst_rank_expert_begin
:
-
1
;
st_na_global
(
nvl_channel_topk_idx
.
buffer
()
+
dst_slot_idx
*
num_topk
+
lane_id
,
idx_value
);
weight_value
=
idx_value
>=
0
?
weight_value
:
0.0
f
;
st_na_global
(
nvl_channel_topk_weights
.
buffer
()
+
dst_slot_idx
*
num_topk
+
lane_id
,
weight_value
);
}
UNROLLED_WARP_COPY
(
1
,
lane_id
,
num_scales
,
dst_gpu_buffer_scales
,
src_gpu_buffer_scales
,
ld_direct_global
,
st_na_global
);
for
(
int
t
=
lane_id
;
t
<
num_topk
;
t
+=
kWarpSize
)
{
int
idx_val
=
ld_direct_global
(
reinterpret_cast
<
int
*>
(
src_gpu_buffer_topk_idx
)
+
t
);
float
w_val
=
ld_direct_global
(
reinterpret_cast
<
float
*>
(
src_gpu_buffer_topk_weights
)
+
t
);
int
new_idx
=
(
idx_val
>=
dst_rank_expert_begin
&&
idx_val
<
dst_rank_expert_end
)
?
(
idx_val
-
dst_rank_expert_begin
)
:
-
1
;
float
new_w
=
(
new_idx
!=
-
1
)
?
w_val
:
0.0
f
;
dst_gpu_buffer_topk_idx
[
t
]
=
new_idx
;
dst_gpu_buffer_topk_weights
[
t
]
=
new_w
;
}
// 在NVL缓冲区不足的情况下,提前停止
if
((
++
num_tokens_sent
)
==
num_max_nvl_chunked_send_tokens
)
// In case of insufficient NVL buffers, early stopping
if
((
++
num_tokens_sent
)
==
num_max_nvl_chunked_send_tokens
)
src_rdma_tail
=
i
+
1
;
}
// 同步头部索引
if
(
lane_id
==
src_rdma_rank
)
forward_channel_head
[
dst_nvl_rank
][
src_rdma_rank
]
=
(
cached_rdma_channel_head
=
src_rdma_tail
);
// Sync head index
if
(
lane_id
==
src_rdma_rank
)
forward_channel_head
[
dst_nvl_rank
][
src_rdma_rank
]
=
(
cached_rdma_channel_head
=
src_rdma_tail
);
//
移动尾部索引,与kNVLReceivers互相通信使用
//
Move tail index
syncwarp
();
if
(
lane_id
==
0
)
{
st_release_sys_global
(
nvl_channel_tail
.
buffer
(),
cached_nvl_channel_tail
);
}
if
(
lane_id
==
0
)
st_relaxed_sys_global
(
nvl_channel_tail
.
buffer
(),
cached_nvl_channel_tail
);
}
// Retired
syncwarp
();
if
(
lane_id
==
0
)
{
if
(
lane_id
==
0
)
forward_channel_retired
[
dst_nvl_rank
]
=
true
;
}
}
else
if
(
warp_role
==
WarpRole
::
kForwarderCoordinator
)
{
/*
这段代码的主要功能是在一个CUDA内核中协调转发器的逻辑。
它首先检查当前warp是否是额外的转发器协调warp,如果是,则直接退出。
然后,它清理共享内存,并初始化转发通道的头部和退役状态。
接着,它进入一个无限循环,在循环中,它找到最小的头部,如果所有的通道都已退役,则退出循环。
否则,它更新远程头部,并进行纳秒级睡眠,以让其他warp工作。
*/
}
else
if
(
warp_role
==
WarpRole
::
kForwarderCoordinator
)
{
// Extra warps for forwarder coordinator should exit directly
if
(
w
ar
p_id
>
NUM_MAX_NVL_PEERS
)
if
(
t
ar
get_rank
>
0
)
return
;
//
转发warp协调器
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"
无效的RDMA对等体数量
"
);
// 清理共享内存
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"无效的NVL对等体数量"
);
#pragma unroll
for
(
int
i
=
lane_id
;
i
<
kNumRDMARanks
*
NUM_MAX_NVL_PEERS
;
i
+=
kWarpSize
)
//
Forward warp coordinator
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"
Invalid number of RDMA peers
"
);
// Clean shared memory
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"Invalid number of NVL peers"
);
for
(
int
i
=
lane_id
;
i
<
kNumRDMARanks
*
NUM_MAX_NVL_PEERS
;
i
+=
kWarpSize
)
forward_channel_head
[
i
%
NUM_MAX_NVL_PEERS
][
i
/
NUM_MAX_NVL_PEERS
]
=
0
;
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
forward_channel_retired
[
lane_id
]
=
false
;
// sync_forwarder_smem();
__syncthreads
();
sync_forwarder_smem
();
int
last_head
=
0
,
target_rdma
=
lane_id
<
kNumRDMARanks
?
lane_id
:
0
;
while
(
true
)
{
// 找到最小的头部
while
(
true
)
{
// Find minimum head
int
min_head
=
std
::
numeric_limits
<
int
>::
max
();
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_MAX_NVL_PEERS
;
++
i
)
if
(
!
forward_channel_retired
[
i
])
for
(
int
i
=
0
;
i
<
NUM_MAX_NVL_PEERS
;
++
i
)
if
(
not
forward_channel_retired
[
i
])
min_head
=
min
(
min_head
,
forward_channel_head
[
i
][
target_rdma
]);
if
(
__all_sync
(
kFullWarpMask
,
min_head
==
std
::
numeric_limits
<
int
>::
max
()))
{
if
(
__all_sync
(
kFullWarpMask
,
min_head
==
std
::
numeric_limits
<
int
>::
max
()))
break
;
}
// 更新远程头部
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
&&
min_head
>=
last_head
+
num_max_rdma_chunked_send_tokens
&&
lane_id
<
kNumRDMARanks
){
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add
(
ctx
,
#else
shmem_signal_op_add
(
#endif
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_head
,
// Update remote head
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>=
last_head
+
num_max_rdma_chunked_send_tokens
and
lane_id
<
kNumRDMARanks
)
{
rocshmem
::
rocshmem_ctx_ulong_atomic_add
(
ctx
,
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_head
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
lane_id
,
nvl_rank
));
last_head
=
min_head
;
}
//
纳秒级睡眠并让其他warp工作 //
Nanosleep and let other warps work
// Nanosleep and let other warps work
__builtin_amdgcn_s_sleep
(
NUM_WAIT_CYCLES_TIMES_64
);
}
}
else
if
(
warp_role
==
WarpRole
::
kNVLReceivers
)
{
if
(
warp_id
>=
NUM_MAX_NVL_PEERS
)
{
return
;
}
// Place the main logic of your kernel here, using the parameters above.
// NVL消费者
// 从屏障结果中检索秩偏移(每个通道的寄存器存储一个RDMA秩)
}
else
{
// NVL consumers
// Retrieve rank offset from barrier results (each lane's register stores an RDMA rank)
int
src_nvl_rank
=
target_rank
,
total_offset
=
0
;
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"
无效的RDMA对等体数量
"
);
if
(
lane_id
<
kNumRDMARanks
&&
lane_id
*
NUM_MAX_NVL_PEERS
+
src_nvl_rank
>
0
)
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"
Invalid number of RDMA peers
"
);
if
(
lane_id
<
kNumRDMARanks
and
lane_id
*
NUM_MAX_NVL_PEERS
+
src_nvl_rank
>
0
)
total_offset
=
recv_gbl_rank_prefix_sum
[
lane_id
*
NUM_MAX_NVL_PEERS
+
src_nvl_rank
-
1
];
//
接收通道偏移
//
Receive channel offsets
int
start_offset
=
0
,
end_offset
=
0
,
num_tokens_to_recv
;
auto
start_time
=
wall_clock64
();
while
(
lane_id
<
kNumRDMARanks
)
{
while
(
lane_id
<
kNumRDMARanks
)
{
start_offset
=
ld_volatile_global
(
nvl_channel_prefix_start
.
buffer
()
+
lane_id
);
end_offset
=
ld_volatile_global
(
nvl_channel_prefix_end
.
buffer
()
+
lane_id
);
if
(
start_offset
<
0
&&
end_offset
<
0
)
{
if
(
start_offset
<
0
and
end_offset
<
0
)
{
start_offset
=
-
start_offset
-
1
,
end_offset
=
-
end_offset
-
1
;
total_offset
+=
start_offset
;
break
;
}
// 超时检查
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
src_nvl_rank
,
start_offset
,
end_offset
);
// Timeout check
long
long
int
elapsed_time
=
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src "
"RDMA: %d, src nvl: %d, start: %d, end: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
src_nvl_rank
,
start_offset
,
end_offset
);
trap
();
}
}
num_tokens_to_recv
=
warp_reduce_sum
(
end_offset
-
start_offset
);
// 保存以供合并使用
if
(
lane_id
<
kNumRDMARanks
&&
!
kCachedMode
)
recv_gbl_channel_prefix_matrix
[(
lane_id
*
NUM_MAX_NVL_PEERS
+
src_nvl_rank
)
*
num_channels
+
channel_id
]
=
total_offset
;
// Save for combine usage
if
(
lane_id
<
kNumRDMARanks
and
not
kCachedMode
)
recv_gbl_channel_prefix_matrix
[(
lane_id
*
NUM_MAX_NVL_PEERS
+
src_nvl_rank
)
*
num_channels
+
channel_id
]
=
total_offset
;
syncwarp
();
int
cached_channel_head_idx
=
0
,
cached_channel_tail_idx
=
0
;
while
(
num_tokens_to_recv
>
0
)
{
//
通过通道0检查通道状态
while
(
num_tokens_to_recv
>
0
)
{
//
Check channel status by lane 0
start_time
=
wall_clock64
();
while
(
lane_id
==
0
)
{
//
准备复制
if
(
cached_channel_head_idx
!=
cached_channel_tail_idx
)
while
(
lane_id
==
0
)
{
//
Ready to copy
if
(
cached_channel_head_idx
!=
cached_channel_tail_idx
)
break
;
cached_channel_tail_idx
=
ld_acquire_sys_global
(
nvl_channel_tail
.
buffer
());
// 超时检查
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
src_nvl_rank
,
cached_channel_head_idx
,
cached_channel_tail_idx
);
cached_channel_tail_idx
=
ld_relaxed_sys_global
(
nvl_channel_tail
.
buffer
());
// Timeout check
long
long
int
elapsed_time
=
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, "
"src NVL: %d, head: %d, tail: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
src_nvl_rank
,
cached_channel_head_idx
,
cached_channel_tail_idx
);
trap
();
}
}
//
同步队列尾部
//
Sync queue tail
cached_channel_tail_idx
=
shfl_sync
(
cached_channel_tail_idx
,
0
);
//
复制数据
//
Copy data
int
num_recv_tokens
=
cached_channel_tail_idx
-
cached_channel_head_idx
;
for
(
int
chunk_idx
=
0
;
chunk_idx
<
num_recv_tokens
;
++
chunk_idx
,
--
num_tokens_to_recv
)
{
int
token_idx_in_buffer
=
(
cached_channel_head_idx
++
)
%
num_max_nvl_chunked_recv_tokens
;
for
(
int
chunk_idx
=
0
;
chunk_idx
<
num_recv_tokens
;
++
chunk_idx
,
--
num_tokens_to_recv
)
{
int
token_idx_in_buffer
=
(
cached_channel_head_idx
++
)
%
num_max_nvl_chunked_recv_tokens
;
auto
meta
=
ld_nc_global
(
nvl_channel_src_meta
.
buffer
()
+
token_idx_in_buffer
);
int64_t
recv_token_idx
=
shfl_sync
(
total_offset
,
meta
.
src_rdma_rank
);
(
lane_id
==
meta
.
src_rdma_rank
)
?
(
total_offset
+=
1
)
:
0
;
// 复制数据
UNROLLED_WARP_COPY
(
5
,
lane_id
,
hidden_int4
,
recv_x
+
recv_token_idx
*
hidden_int4
,
// Copy data
UNROLLED_WARP_COPY
(
5
,
lane_id
,
hidden_int4
,
recv_x
+
recv_token_idx
*
hidden_int4
,
nvl_channel_x
.
buffer
()
+
token_idx_in_buffer
*
hidden_int4
,
ld_nc_global
,
st_na_global
);
ld_nc_global
,
st_na_global
);
//
复制源元数据
if
(
lane_id
==
0
&&
!
kCachedMode
)
//
Copy source meta
if
(
lane_id
==
0
and
not
kCachedMode
)
st_na_global
(
recv_src_meta
+
recv_token_idx
,
meta
);
// 复制比例
UNROLLED_WARP_COPY
(
1
,
lane_id
,
num_scales
,
// Copy scales
UNROLLED_WARP_COPY
(
1
,
lane_id
,
num_scales
,
recv_x_scales
+
recv_token_idx
*
num_scales
,
nvl_channel_x_scales
.
buffer
()
+
token_idx_in_buffer
*
num_scales
,
ld_nc_global
,
st_na_global
);
ld_nc_global
,
st_na_global
);
//
复制
`topk_idx`
和
`topk_weights`
if
(
lane_id
<
num_topk
)
{
//
Copy
`topk_idx`
and
`topk_weights`
if
(
lane_id
<
num_topk
)
{
auto
recv_idx
=
recv_token_idx
*
num_topk
+
lane_id
;
auto
buffer_idx
=
token_idx_in_buffer
*
num_topk
+
lane_id
;
st_na_global
(
recv_topk_idx
+
recv_idx
,
static_cast
<
int64_t
>
(
ld_nc_global
(
nvl_channel_topk_idx
.
buffer
()
+
buffer_idx
)));
st_na_global
(
recv_topk_weights
+
recv_idx
,
ld_nc_global
(
nvl_channel_topk_weights
.
buffer
()
+
buffer_idx
));
st_na_global
(
recv_topk_idx
+
recv_idx
,
static_cast
<
int64_t
>
(
ld_nc_global
(
nvl_channel_topk_idx
.
buffer
()
+
buffer_idx
)));
st_na_global
(
recv_topk_weights
+
recv_idx
,
ld_nc_global
(
nvl_channel_topk_weights
.
buffer
()
+
buffer_idx
));
}
}
//
移动队列
//
Move queue
syncwarp
();
if
(
lane_id
==
0
)
{
if
(
lane_id
==
0
)
st_relaxed_sys_global
(
nvl_channel_head
.
buffer
(),
cached_channel_head_idx
);
}
}
// while(num_tokens_to_recv > 0)
}
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_wg_ctx_destroy
(
&
ctx
);
#endif
rocshmem
::
rocshmem_wg_ctx_destroy
(
&
ctx
);
}
void
dispatch
(
void
*
recv_x
,
float
*
recv_x_scales
,
int64_t
*
recv_topk_idx
,
float
*
recv_topk_weights
,
...
...
@@ -1152,8 +1135,6 @@ void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float
int
num_ranks
,
bool
is_cached_dispatch
,
hipStream_t
stream
,
int
num_channels
,
bool
low_latency_mode
)
{
constexpr
int
kNumDispatchRDMASenderWarps
=
7
;
// Make sure never OOB
EP_HOST_ASSERT
(
static_cast
<
int64_t
>
(
num_scales
)
*
scale_hidden_stride
<
std
::
numeric_limits
<
int
>::
max
());
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) \
{ \
...
...
@@ -1181,8 +1162,8 @@ void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float
EP_HOST_ASSERT
((
topk_idx
==
nullptr
)
==
(
topk_weights
==
nullptr
));
EP_HOST_ASSERT
((
recv_topk_idx
==
nullptr
)
==
(
recv_topk_weights
==
nullptr
));
SETUP_LAUNCH_CONFIG
(
num_channels
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
,
(
1
+
NUM_MAX_NVL_PEERS
)
*
kWarpSize
,
stream
);
SETUP_LAUNCH_CONFIG
(
num_channels
*
2
,
(
kNumDispatchRDMASenderWarps
+
1
+
NUM_MAX_NVL_PEERS
)
*
kWarpSize
,
stream
);
SWITCH_RDMA_RANKS
(
DISPATCH_LAUNCH_CASE
);
#undef DISPATCH_LAUNCH_CASE
}
...
...
@@ -1209,7 +1190,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
if
(
sm_id
==
0
)
{
// Barrier for RDMA
if
(
thread_id
==
kWarpSize
)
du
shmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
shmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
// Barrier for NVL
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
...
...
@@ -1228,7 +1209,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
// Barrier again
if
(
thread_id
==
kWarpSize
)
du
shmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
shmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
// Barrier again
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
...
...
@@ -1236,13 +1217,13 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
if
(
is_cached_dispatch
)
return
;
EP_DEVICE_ASSERT
(
num_warps
>=
num_channels
);
EP_DEVICE_ASSERT
(
num_rdma_ranks
<=
kWarpSize
);
// Iterate in reverse order
for
(
int
channel_id
=
warp_id
;
channel_id
<
num_channels
;
channel_id
+=
num_warps
)
{
if
(
lane_id
<
num_rdma_ranks
)
{
if
(
lane_id
<
num_rdma_ranks
and
warp_id
<
num_channels
)
{
int
token_start_idx
,
token_end_idx
;
get_channel_task_range
(
num_combined_tokens
,
num_channels
,
channel
_id
,
token_start_idx
,
get_channel_task_range
(
num_combined_tokens
,
num_channels
,
warp
_id
,
token_start_idx
,
token_end_idx
);
// NOTES: `1 << 25` is a heuristic large number
...
...
@@ -1257,26 +1238,26 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
}
}
}
}
}
else
{
if
(
is_cached_dispatch
)
return
;
EP_DEVICE_ASSERT
(
rdma_channel_prefix_matrix
!=
nullptr
and
rdma_rank_prefix_sum
!=
nullptr
);
EP_DEVICE_ASSERT
(
num_warps
>=
num_channels
);
EP_DEVICE_ASSERT
(
rdma_channel_prefix_matrix
!=
nullptr
and
rdma_rank_prefix_sum
!=
nullptr
);
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"Too many NVL peers"
);
constexpr
int
num_clean_sms
=
2
;
for
(
int
channel_id
=
warp_id
;
channel_id
<
num_channels
;
channel_id
+=
num_warps
)
{
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
if
(
lane_id
<
NUM_MAX_NVL_PEERS
and
warp_id
<
num_channels
)
{
for
(
int
dst_rdma_rank
=
sm_id
-
num_clean_sms
;
dst_rdma_rank
<
num_rdma_ranks
;
dst_rdma_rank
+=
num_channels
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
-
num_clean_sms
)
{
dst_rdma_rank
+=
num_channels
*
2
-
num_clean_sms
)
{
// Iterate in reverse order
int
token_start_idx
=
channel
_id
==
0
warp
_id
==
0
?
0
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel
_id
-
1
];
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
warp
_id
-
1
];
int
token_end_idx
=
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel
_id
];
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
warp
_id
];
int
shift
=
dst_rdma_rank
==
0
?
0
:
rdma_rank_prefix_sum
[
dst_rdma_rank
-
1
];
token_start_idx
+=
shift
,
token_end_idx
+=
shift
;
...
...
@@ -1294,7 +1275,6 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
}
}
}
}
}
void
cached_notify
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
...
...
@@ -1305,7 +1285,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
int
num_max_nvl_chunked_recv_tokens
,
int
**
barrier_signal_ptrs
,
int
rank
,
hipStream_t
stream
,
int64_t
num_rdma_bytes
,
int64_t
num_nvl_bytes
,
bool
is_cached_dispatch
,
bool
low_latency_mode
)
{
const
int
num_threads
=
::
min
(
1024
,
::
max
(
128
,
kWarpSize
*
num_channels
)
)
;
const
int
num_threads
=
::
max
(
128
,
kWarpSize
*
num_channels
);
const
auto
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
// Get clean meta
...
...
@@ -1321,11 +1301,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
num_nvl_bytes
);
EP_HOST_ASSERT
(
num_rdma_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_nvl_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_channels
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
>
2
);
EP_HOST_ASSERT
(
num_channels
*
2
>
2
);
// Launch kernel
auto
cached_notify_func
=
low_latency_mode
?
cached_notify
<
true
>
:
cached_notify
<
false
>
;
SETUP_LAUNCH_CONFIG
(
num_channels
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
,
num_threads
,
stream
);
SETUP_LAUNCH_CONFIG
(
num_channels
*
2
,
num_threads
,
stream
);
LAUNCH_KERNEL_NON_COOPERATIVE
(
&
cfg
,
cached_notify_func
,
rdma_clean_meta
.
first
,
rdma_clean_meta
.
second
,
nvl_clean_meta
.
first
,
nvl_clean_meta
.
second
,
combined_rdma_head
,
num_combined_tokens
,
...
...
@@ -1334,45 +1314,49 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
cpu_rdma_team
);
}
template
<
int
kNumRanks
,
typename
dtype_t
,
int
kMaxNumRanks
,
bool
kUseMLS
,
typename
GetAddr
Fn
,
typename
ReceiveTWFn
>
template
<
int
kNumRanks
,
typename
dtype_t
,
int
kMaxNumRanks
,
int
kWidth
,
typename
Receive
Fn
,
typename
ReceiveTWFn
>
__device__
int
combine_token
(
bool
is_token_in_rank
,
int
head_idx
,
int
lane_id
,
int
hidden_int4
,
int
num_topk
,
int4
*
combined_row
,
float
*
combined_topk_weights
,
int
num_max_recv_tokens
,
const
GetAddrFn
&
get_addr_fn
,
const
ReceiveTWFn
&
recv_tw_fn
)
{
int
num_max_recv_tokens
,
const
ReceiveFn
&
recv_fn
,
const
ReceiveTWFn
&
recv_tw_fn
)
{
constexpr
auto
kDtypePerInt4
=
sizeof
(
int4
)
/
sizeof
(
dtype_t
);
// Broadcast current heads
// Lane `i` holds the head of rank `i` and `is_token_in_rank`
EP_STATIC_ASSERT
(
kMaxNumRanks
<=
kW
arpSize
,
"Too many ranks"
);
EP_STATIC_ASSERT
(
kMaxNumRanks
<=
kW
idth
,
"Too many ranks"
);
int
num_topk_ranks
=
0
,
topk_ranks
[
kMaxNumRanks
],
slot_indices
[
kMaxNumRanks
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kNumRanks
;
++
i
)
if
(
shfl_sync
(
is_token_in_rank
,
i
))
{
slot_indices
[
num_topk_ranks
]
=
shfl_sync
(
head_idx
,
i
)
%
num_max_recv_tokens
;
for
(
int
i
=
0
;
i
<
kNumRanks
;
++
i
)
if
(
shfl_sync
(
is_token_in_rank
,
i
,
kWidth
))
{
slot_indices
[
num_topk_ranks
]
=
shfl_sync
(
head_idx
,
i
,
kWidth
)
%
num_max_recv_tokens
;
topk_ranks
[
num_topk_ranks
++
]
=
i
;
}
EP_DEVICE_ASSERT
(
num_topk_ranks
<=
kMaxNumRanks
);
// Reduce data
#pragma unroll
for
(
int
i
=
lane_id
;
i
<
hidden_int4
;
i
+=
kWarpSize
)
{
// Read buffers
float
values
[
kDtypePerInt4
]
=
{
0
};
// 8 × 4B = 32B
for
(
int
i
=
lane_id
;
i
<
hidden_int4
;
i
+=
kWidth
)
{
float
values
[
kDtypePerInt4
]
=
{
0
};
// Temporary buffer
int4
temp
;
#pragma unroll
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
{
int4
recv_value
=
ld_nc_global
(
get_addr_fn
(
topk_ranks
[
j
],
slot_indices
[
j
],
i
));
auto
recv_dtypes
=
reinterpret_cast
<
const
dtype_t
*>
(
&
recv_value
);
temp
=
recv_fn
(
topk_ranks
[
j
],
slot_indices
[
j
],
i
);
const
dtype_t
*
d
=
reinterpret_cast
<
const
dtype_t
*>
(
&
temp
);
#pragma unroll
for
(
int
k
=
0
;
k
<
kDtypePerInt4
;
++
k
)
values
[
k
]
+=
static_cast
<
float
>
(
recv_dtypes
[
k
]);
values
[
k
]
+=
static_cast
<
float
>
(
d
[
k
]);
}
// Cast back to `dtype_t` and write
int4
out_int4
;
auto
out_dtypes
=
reinterpret_cast
<
dtype_t
*>
(
&
out_int4
);
dtype_t
*
out_dtypes
=
reinterpret_cast
<
dtype_t
*>
(
&
out_int4
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kDtypePerInt4
;
++
j
)
for
(
int
j
=
0
;
j
<
kDtypePerInt4
;
++
j
)
out_dtypes
[
j
]
=
static_cast
<
dtype_t
>
(
values
[
j
]);
st_na_global
(
combined_row
+
i
,
out_int4
);
}
...
...
@@ -1389,87 +1373,98 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
return
topk_ranks
[
0
];
}
template
<
bool
kLowLatencyMode
,
int
kNumRDMARanks
,
typename
dtype_t
,
template
<
bool
kLowLatencyMode
,
int
kNumRDMARanks
,
typename
dtype_t
,
int
kNumCombineForwarderWarps
,
int
kNumTopkRDMARanks
=
get_num_topk_rdma_ranks
(
kNumRDMARanks
),
int
kNumWarpsPerForwarder
=
(
kNumCombineForwarderWarps
/
kNumRDMARanks
>
0
)
?
kNumCombineForwarderWarps
/
kNumRDMARanks
:
1
,
int
kNumForwarders
=
kNumRDMARanks
*
kNumWarpsPerForwarder
,
int
kNumRDMAReceivers
=
kNumForwarders
>
__global__
void
__launch_bounds__
((
1
+
NUM_MAX_NVL_PEERS
)
*
kWarpSize
,
1
)
combine
(
int4
*
combined_x
,
float
*
combined_topk_weights
,
const
bool
*
is_combined_token_in_rank
,
const
int4
*
x
,
const
float
*
topk_weights
,
const
int4
*
bias_0
,
const
int4
*
bias_1
,
const
int
*
combined_rdma_head
,
const
int
*
combined_nvl_head
,
const
SourceMeta
*
src_meta
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
int
num_tokens
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_send_tokens
,
int
num_max_rdma_chunked_recv_tokens
,
void
**
buffer_ptrs
,
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
int
rank
,
int
num_ranks
)
{
int
kNumRDMAReceivers
=
kNumRDMARanks
<=
8
?
kNumForwarders
+
NUM_MAX_NVL_PEERS
/
2
:
kNumForwarders
+
NUM_MAX_NVL_PEERS
,
int
kBlockThreads
=
(
kNumRDMARanks
>
8
)
?
((
NUM_MAX_NVL_PEERS
+
kNumForwarders
)
*
kEmulatedWarpSize
+
kWarpSize
)
:
((
NUM_MAX_NVL_PEERS
/
2
+
1
+
kNumForwarders
)
*
kWarpSize
)
>
__global__
void
__launch_bounds__
(
kBlockThreads
,
1
)
combine
(
int4
*
combined_x
,
float
*
combined_topk_weights
,
const
bool
*
is_combined_token_in_rank
,
const
int4
*
x
,
const
float
*
topk_weights
,
const
int4
*
bias_0
,
const
int4
*
bias_1
,
const
int
*
combined_rdma_head
,
const
int
*
combined_nvl_head
,
const
SourceMeta
*
src_meta
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
int
num_tokens
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_send_tokens
,
int
num_max_rdma_chunked_recv_tokens
,
void
**
buffer_ptrs
,
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
int
rank
,
int
num_ranks
)
{
enum
class
WarpRole
{
kNVLSender
,
kNVLAndRDMAForwarder
,
kRDMAReceiver
,
kRDMACoordinator
,
kNVLCoordinator
kCoordinator
};
constexpr
auto
kNVLPeersHyb
=
(
kNumRDMARanks
>
8
)
?
NUM_MAX_NVL_PEERS
:
NUM_MAX_NVL_PEERS
/
2
;
constexpr
auto
kWarpHyb
=
kNumRDMARanks
>
8
?
kEmulatedWarpSize
:
kWarpSize
;
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
);
const
int
num_warps
=
kNumRDMARanks
>
8
?
(
num_threads
/
kEmulatedWarpSize
-
1
)
:
(
num_threads
/
kWarpSize
);
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
num_channels
=
static_cast
<
int
>
(
gridDim
.
x
)
/
2
,
channel_id
=
sm_id
/
2
;
const
bool
is_rdma_receiver_sm
=
sm_id
%
2
==
1
;
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
__shared__
shmem_ctx_t
ctx
;
shmem_wg_ctx_create
(
&
ctx
);
#endif
EP_STATIC_ASSERT
(
kNumCombineForwarderWarps
<=
kWarpSize
,
"Invalid number of forwarder warps"
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
),
num_warps
=
num_threads
/
kWarpSize
;
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
),
warp_id
=
thread_id
/
kWarpSize
,
lane_id
=
get_lane_id
();
const
auto
num_channels
=
static_cast
<
int
>
(
gridDim
.
x
)
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
,
channel_id
=
sm_id
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
#endif
EP_DEVICE_ASSERT
(
num_topk
<=
kEmulatedWarpSize
);
EP_DEVICE_ASSERT
(
hidden
%
(
sizeof
(
int4
)
/
sizeof
(
dtype_t
))
==
0
);
const
auto
hidden_int4
=
hidden
/
(
sizeof
(
int4
)
/
sizeof
(
dtype_t
));
// NOTES: we decouple a channel into 2 SMs
const
auto
rdma_rank
=
rank
/
NUM_MAX_NVL_PEERS
,
nvl_rank
=
rank
%
NUM_MAX_NVL_PEERS
;
const
auto
role_meta
=
[
=
]()
->
std
::
pair
<
WarpRole
,
int
>
{
if
(
sm_id
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
1
)
{
return
{
WarpRole
::
kNVLSender
,
(
warp_id
+
channel_id
)
%
NUM_MAX_NVL_PEERS
};
}
else
if
(
sm_id
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
)
{
if
(
warp_id
<
kNumForwarders
)
{
return
{
WarpRole
::
kNVLAndRDMAForwarder
,
(
warp_id
+
channel_id
)
%
kNumForwarders
};
auto
role_meta
=
[
=
]()
->
std
::
pair
<
WarpRole
,
int
>
{
auto
warp_id
=
thread_id
/
kWarpHyb
;
if
(
not
is_rdma_receiver_sm
)
{
if
(
warp_id
<
kNVLPeersHyb
)
{
auto
shuffled_warp_id
=
warp_id
;
shuffled_warp_id
=
(
shuffled_warp_id
+
channel_id
)
%
kNVLPeersHyb
;
return
{
WarpRole
::
kNVLSender
,
shuffled_warp_id
};
}
else
if
(
warp_id
<
kNVLPeersHyb
+
kNumForwarders
)
{
auto
shuffled_warp_id
=
warp_id
-
kNVLPeersHyb
;
shuffled_warp_id
=
(
shuffled_warp_id
+
channel_id
)
%
kNumForwarders
;
return
{
WarpRole
::
kNVLAndRDMAForwarder
,
shuffled_warp_id
};
}
else
{
return
{
WarpRole
::
k
RDMA
Coordinator
,
0
};
return
{
WarpRole
::
kCoordinator
,
0
};
}
}
else
{
if
(
warp_id
<
kNumForwarders
)
{
if
(
warp_id
<
kNVLPeersHyb
+
kNumForwarders
)
{
return
{
WarpRole
::
kRDMAReceiver
,
warp_id
};
}
else
{
return
{
WarpRole
::
k
NVL
Coordinator
,
0
};
return
{
WarpRole
::
kCoordinator
,
0
};
}
}
}();
auto
warp_role
=
role_meta
.
first
;
auto
t
ar
get_rank
=
role_meta
.
second
;
// Not applicable for RDMA senders
auto
w
ar
p_id
=
role_meta
.
second
;
EP_DEVICE_ASSERT
(
num_warps
==
NUM_MAX_NVL_PEERS
+
1
);
auto
num_max_nvl_chunked_recv_tokens_per_rdma
=
num_max_nvl_chunked_recv_tokens
/
kNumRDMARanks
;
// This approach is designed to sync multiple warps in a loop
constexpr
int
num_sync_large_iteration
=
64
;
constexpr
int
rdma_warp_counters
=
kNumRDMARanks
*
num_sync_large_iteration
;
__shared__
volatile
int
sync_large_warp_counters
[
2
*
rdma_warp_counters
];
for
(
int
i
=
thread_id
;
i
<
2
*
rdma_warp_counters
;
i
+=
num_threads
)
{
__shared__
volatile
int
rdma_receiver_counter
[
1
];
__shared__
volatile
int
rdma_forwarder_counter
[
1
];
__shared__
volatile
uint8_t
sync_large_warp_counters
[
2
*
kNumRDMARanks
*
num_sync_large_iteration
];
if
(
threadIdx
.
x
==
0
){
rdma_receiver_counter
[
0
]
=
0
;
rdma_forwarder_counter
[
0
]
=
0
;
}
for
(
int
i
=
thread_id
;
i
<
2
*
kNumRDMARanks
*
num_sync_large_iteration
;
i
+=
num_threads
)
{
sync_large_warp_counters
[
i
]
=
0
;
}
__syncthreads
();
if
(
warp_role
==
WarpRole
::
kNVLSender
)
{
if
(
warp_id
>=
NUM_MAX_NVL_PEERS
)
{
return
;
}
// NVL producers
const
int
dst_nvl_rank
=
kNumRDMARanks
<=
8
?
(
warp_id
*
2
+
(
thread_id
%
kWarpSize
)
/
kEmulatedWarpSize
)
:
warp_id
;
auto
lane_id
=
get_lane_id
()
%
kEmulatedWarpSize
;
const
auto
dst_nvl_rank
=
target_rank
;
// NVL layouts
// NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources
auto
dst_buffer_ptr
=
buffer_ptrs
[
dst_nvl_rank
],
local_buffer_ptr
=
buffer_ptrs
[
nvl_rank
];
...
...
@@ -1481,7 +1476,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Get tasks for each RDMA lane
int
token_start_idx
=
0
,
token_end_idx
=
0
;
if
(
lane_id
<
kNumRDMARanks
)
{
if
(
lane_id
<
kNumRDMARanks
)
{
int
prefix_idx
=
(
lane_id
*
NUM_MAX_NVL_PEERS
+
dst_nvl_rank
)
*
num_channels
+
channel_id
;
token_start_idx
=
gbl_channel_prefix_matrix
[
prefix_idx
];
token_end_idx
=
(
prefix_idx
==
num_channels
*
num_ranks
-
1
)
?
num_tokens
:
gbl_channel_prefix_matrix
[
prefix_idx
+
1
];
...
...
@@ -1490,94 +1485,81 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// NOTES: here the cached value of each lane is only responsible for a single RDMA buffer
int
cached_channel_head_idx
=
0
,
cached_channel_tail_idx
=
0
;
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"Invalid number of RDMA peers"
);
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
k
Emulated
WarpSize
,
"Invalid number of RDMA peers"
);
// Iterate over all tokens and send by chunks
while
(
true
)
{
while
(
true
)
{
// Exit if possible
if
(
__all_sync
(
kFullWarpMask
,
token_start_idx
>=
token_end_idx
))
if
(
__all_sync
(
kFullWarpMask
,
token_start_idx
>=
token_end_idx
))
break
;
// Decide next RDMA buffer to send
bool
is_lane_ready
=
false
;
auto
start_time
=
wall_clock64
();
while
(
true
)
{
while
(
true
)
{
int
num_used_slots
=
cached_channel_tail_idx
-
cached_channel_head_idx
;
is_lane_ready
=
lane_id
<
kNumRDMARanks
and
token_start_idx
<
token_end_idx
and
num_max_nvl_chunked_recv_tokens_per_rdma
-
num_used_slots
>=
num_max_nvl_chunked_send_tokens
;
if
(
__any_sync
(
k
FullWarp
Mask
,
is_lane_ready
))
is_lane_ready
=
lane_id
<
kNumRDMARanks
and
token_start_idx
<
token_end_idx
and
num_max_nvl_chunked_recv_tokens_per_rdma
-
num_used_slots
>=
num_max_nvl_chunked_send_tokens
;
if
(
__any_sync
(
kFirstHalfMask
,
is_lane_ready
))
break
;
if
(
__any_sync
(
k
SecondHalf
Mask
,
is_lane_ready
))
break
;
// Retry
if
(
lane_id
<
kNumRDMARanks
and
token_start_idx
<
token_end_idx
)
if
(
lane_id
<
kNumRDMARanks
and
token_start_idx
<
token_end_idx
)
cached_channel_head_idx
=
ld_volatile_global
(
nvl_channel_head
.
buffer
()
+
lane_id
);
// Timeout check
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
and
lane_id
<
kNumRDMARanks
)
{
printf
(
"DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, "
"RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
dst_nvl_rank
,
lane_id
,
ld_volatile_global
(
nvl_channel_head
.
buffer
()
+
lane_id
),
cached_channel_tail_idx
,
token_start_idx
,
token_end_idx
);
long
long
int
elapsed_time
=
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
and
lane_id
<
kNumRDMARanks
)
{
printf
(
"DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
dst_nvl_rank
,
lane_id
,
ld_volatile_global
(
nvl_channel_head
.
buffer
()
+
lane_id
),
cached_channel_tail_idx
,
token_start_idx
,
token_end_idx
);
trap
();
}
__builtin_amdgcn_s_sleep
(
1
);
}
// Sync token start index and count
for
(
int
current_rdma_idx
=
0
;
current_rdma_idx
<
kNumRDMARanks
;
++
current_rdma_idx
)
{
if
(
shfl_sync
((
token_start_idx
>=
token_end_idx
)
or
(
not
is_lane_ready
),
current_rdma_idx
))
for
(
int
current_rdma_idx
=
0
;
current_rdma_idx
<
kNumRDMARanks
;
++
current_rdma_idx
)
{
if
(
shfl_sync
((
token_start_idx
>=
token_end_idx
)
or
(
not
is_lane_ready
),
current_rdma_idx
,
kEmulatedWarpSize
))
continue
;
// Sync token start index
auto
token_idx
=
static_cast
<
int64_t
>
(
shfl_sync
(
token_start_idx
,
current_rdma_idx
));
int
num_tokens_in_chunk
=
shfl_sync
(
min
(
num_max_nvl_chunked_send_tokens
,
token_end_idx
-
token_start_idx
),
current_rdma_idx
);
auto
token_idx
=
static_cast
<
int64_t
>
(
shfl_sync
(
token_start_idx
,
current_rdma_idx
,
kEmulatedWarpSize
));
int
num_tokens_in_chunk
=
shfl_sync
(
min
(
num_max_nvl_chunked_send_tokens
,
token_end_idx
-
token_start_idx
),
current_rdma_idx
,
kEmulatedWarpSize
);
// Send by chunk
for
(
int
chunk_idx
=
0
;
chunk_idx
<
num_tokens_in_chunk
;
++
chunk_idx
,
++
token_idx
)
{
for
(
int
chunk_idx
=
0
;
chunk_idx
<
num_tokens_in_chunk
;
++
chunk_idx
,
++
token_idx
)
{
// Get an empty slot
int
dst_slot_idx
=
0
;
if
(
lane_id
==
current_rdma_idx
)
{
dst_slot_idx
=
(
cached_channel_tail_idx
++
)
%
num_max_nvl_chunked_recv_tokens_per_rdma
;
if
(
lane_id
==
current_rdma_idx
)
{
dst_slot_idx
=
(
cached_channel_tail_idx
++
)
%
num_max_nvl_chunked_recv_tokens_per_rdma
;
dst_slot_idx
=
current_rdma_idx
*
num_max_nvl_chunked_recv_tokens_per_rdma
+
dst_slot_idx
;
}
dst_slot_idx
=
shfl_sync
(
dst_slot_idx
,
current_rdma_idx
);
dst_slot_idx
=
shfl_sync
(
dst_slot_idx
,
current_rdma_idx
,
kEmulatedWarpSize
);
// Copy data
auto
shifted_x_buffers
=
nvl_channel_x
.
buffer
()
+
dst_slot_idx
*
hidden_int4
;
auto
shifted_x
=
x
+
token_idx
*
hidden_int4
;
UNROLLED_WARP_COPY
(
5
,
lane_id
,
hidden_int4
,
shifted_x_buffers
,
shifted_x
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY
_EMULATED
(
5
,
lane_id
,
hidden_int4
,
shifted_x_buffers
,
shifted_x
,
ld_nc_global
,
st_na_global
);
// Copy source meta
if
(
lane_id
==
0
)
if
(
lane_id
==
num_topk
)
st_na_global
(
nvl_channel_src_meta
.
buffer
()
+
dst_slot_idx
,
ld_nc_global
(
src_meta
+
token_idx
));
// Copy `topk_weights`
if
(
lane_id
<
num_topk
)
st_na_global
(
nvl_channel_topk_weights
.
buffer
()
+
dst_slot_idx
*
num_topk
+
lane_id
,
ld_nc_global
(
topk_weights
+
token_idx
*
num_topk
+
lane_id
));
if
(
lane_id
<
num_topk
)
st_na_global
(
nvl_channel_topk_weights
.
buffer
()
+
dst_slot_idx
*
num_topk
+
lane_id
,
ld_nc_global
(
topk_weights
+
token_idx
*
num_topk
+
lane_id
));
}
lane_id
==
current_rdma_idx
?
(
token_start_idx
=
static_cast
<
int
>
(
token_idx
))
:
0
;
}
// Move queue tail
syncwarp
();
if
(
lane_id
<
kNumRDMARanks
and
is_lane_ready
)
{
st_release_sys_global
(
nvl_channel_tail
.
buffer
()
+
lane_id
,
cached_channel_tail_idx
);
}
if
(
lane_id
<
kNumRDMARanks
and
is_lane_ready
)
st_relaxed_sys_global
(
nvl_channel_tail
.
buffer
()
+
lane_id
,
cached_channel_tail_idx
);
}
}
else
{
if
(
warp_id
>
kNumForwarders
)
{
return
;
}
auto
lane_id
=
get_lane_id
()
%
kWarpHyb
;
// Combiners and coordinators
// RDMA symmetric layout
auto
hidden_bytes
=
hidden_int4
*
sizeof
(
int4
);
...
...
@@ -1604,38 +1586,48 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
__shared__
volatile
int
rdma_receiver_rdma_head
[
kNumRDMAReceivers
][
kNumRDMARanks
];
__shared__
volatile
bool
rdma_receiver_retired
[
kNumRDMAReceivers
];
auto
sync_forwarder_smem
=
[
&
]()
{
if
(
lane_id
==
0
)
{
volatile
int
ret
=
__hip_atomic_fetch_add
(
&
rdma_forwarder_counter
[
0
],
1
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
}
syncwarp
();
while
(
rdma_forwarder_counter
[
0
]
<
(
kNumForwarders
+
1
)){}
};
auto
sync_rdma_receiver_smem
=
[
&
]()
{
if
(
lane_id
==
0
)
{
volatile
int
ret
=
__hip_atomic_fetch_add
(
&
rdma_receiver_counter
[
0
],
1
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
}
syncwarp
();
while
(
rdma_receiver_counter
[
0
]
<
(
kNumRDMAReceivers
+
1
)){}
};
if
(
warp_role
==
WarpRole
::
kNVLAndRDMAForwarder
)
{
// Receive from NVL ranks and forward to RDMA ranks
// NOTES: this part is using "large warps" for each RDMA ranks
const
auto
dst_rdma_rank
=
t
ar
get_rank
/
kNumWarpsPerForwarder
;
const
auto
sub_warp_id
=
t
ar
get_rank
%
kNumWarpsPerForwarder
;
const
auto
dst_rdma_rank
=
w
ar
p_id
/
kNumWarpsPerForwarder
;
const
auto
sub_warp_id
=
w
ar
p_id
%
kNumWarpsPerForwarder
;
auto
send_buffer
=
dst_rdma_rank
==
rdma_rank
?
rdma_channel_data
.
recv_buffer
(
dst_rdma_rank
)
:
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
);
// auto sync_large_warp = [=]() {
// if(kNumWarpsPerForwarder == 1) {
// syncwarp();
// } else {
// // asm volatile("bar.sync %0, %1;" ::"r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * kWarpSize));
// // __syncthreads();
// syncwarp();
// }
// };
auto
sync_large_warp
=
[
=
](
const
int
iter
,
const
int
mode
)
{
if
(
kNumWarpsPerForwarder
==
1
)
{
syncwarp
();
}
else
{
// LDS index to store for sync
int
lds_dst_rdma_rank
=
dst_rdma_rank
+
(
iter
%
num_sync_large_iteration
)
*
kNumRDMARanks
+
mode
*
rdma_warp_counters
;
int
lds_dst_rdma_rank
=
dst_rdma_rank
+
(
iter
%
num_sync_large_iteration
)
*
kNumRDMARanks
+
mode
*
kNumRDMARanks
*
num_sync_large_iteration
;
//reset index in the LDS to avoid race condition due to warp scheduling
int
reset_idx
=
dst_rdma_rank
+
((
iter
+
num_sync_large_iteration
/
2
)
%
num_sync_large_iteration
)
*
kNumRDMARanks
+
mode
*
rdma_warp_counters
;
auto
start_time
=
wall_clock64
();
int
reset_idx
=
dst_rdma_rank
+
((
iter
+
num_sync_large_iteration
/
2
)
%
num_sync_large_iteration
)
*
kNumRDMARanks
+
mode
*
kNumRDMARanks
*
num_sync_large_iteration
;
// if (lane_id==0)
// printf("rank %d dst_rdma_rank %d iter %d warp_id %d val %d\n", rank, dst_rdma_rank, iter, warp_id, sync_large_warp_counters[lds_dst_rdma_rank]);
auto
start_time
=
clock64
();
if
(
lane_id
==
0
){
volatile
int
ret
=
atomicAdd
((
int
*
)
&
sync_large_warp_counters
[
lds_dst_rdma_rank
],
1
);
volatile
int
ret
=
__hip_atomic_fetch_add
(
&
sync_large_warp_counters
[
lds_dst_rdma_rank
],
1
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
}
syncwarp
();
//The while(...) loop polls the counter until all warps have arrived
if
(
lane_id
==
0
){
while
(
sync_large_warp_counters
[
lds_dst_rdma_rank
]
<
(
kNumWarpsPerForwarder
)){
if
(
wall_
clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
if
(
clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP combine sync timeout. current num_sync_large_iteration %d. double it.
\n
"
,
num_sync_large_iteration
);
trap
();
}
...
...
@@ -1648,9 +1640,11 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
syncwarp
();
}
};
EP_STATIC_ASSERT
(
kNumWarpsPerForwarder
==
1
or
kNumRDMARanks
+
2
<=
kNumCombineForwarderWarps
,
"Barriers are not enough"
);
EP_STATIC_ASSERT
(
kNumWarpsPerForwarder
==
1
or
kNumRDMARanks
+
2
<=
16
,
"Barriers are not enough"
);
// Advance to the corresponding NVL buffer, 基于原本指针进行的地址偏移
// In case of running less than 8 nodes
constexpr
bool
kUseWave
=
(
kNumRDMARanks
<=
8
);
// Advance to the corresponding NVL buffer
nvl_channel_x
.
advance
(
dst_rdma_rank
*
num_max_nvl_chunked_recv_tokens_per_rdma
*
hidden_int4
);
nvl_channel_src_meta
.
advance
(
dst_rdma_rank
*
num_max_nvl_chunked_recv_tokens_per_rdma
);
nvl_channel_topk_weights
.
advance
(
dst_rdma_rank
*
num_max_nvl_chunked_recv_tokens_per_rdma
*
num_topk
);
...
...
@@ -1659,10 +1653,9 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Clean shared memory and sync
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"Invalid number of NVL peers"
);
lane_id
<
NUM_MAX_NVL_PEERS
?
(
forwarder_nvl_head
[
target_rank
][
lane_id
]
=
0
)
:
0
;
lane_id
==
0
?
(
forwarder_retired
[
target_rank
]
=
false
)
:
false
;
// sync_forwarder_smem();
__syncthreads
();
lane_id
<
NUM_MAX_NVL_PEERS
?
(
forwarder_nvl_head
[
warp_id
][
lane_id
]
=
0
)
:
0
;
lane_id
==
0
?
(
forwarder_retired
[
warp_id
]
=
false
)
:
false
;
sync_forwarder_smem
();
// Get count and cached head
int
cached_nvl_channel_tail_idx
=
0
;
...
...
@@ -1673,87 +1666,104 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
combined_nvl_head
+=
num_tokens_prefix
*
NUM_MAX_NVL_PEERS
;
// Iterate over all tokens and combine by chunks
for
(
int
token_start_idx
=
0
;
token_start_idx
<
num_tokens_to_combine
;
token_start_idx
+=
num_max_rdma_chunked_send_tokens
)
{
for
(
int
token_start_idx
=
0
;
token_start_idx
<
num_tokens_to_combine
;
token_start_idx
+=
num_max_rdma_chunked_send_tokens
)
{
// Check destination queue emptiness, or wait a buffer to be released
auto
token_end_idx
=
min
(
token_start_idx
+
num_max_rdma_chunked_send_tokens
,
num_tokens_to_combine
);
auto
num_chunked_tokens
=
token_end_idx
-
token_start_idx
;
auto
start_time
=
wall_clock64
();
while
(
sub_warp_id
==
0
and
lane_id
==
0
)
{
while
(
sub_warp_id
==
0
and
lane_id
==
0
)
{
// Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
// Here, `token_start_idx` is the actual tail
int
num_used_slots
=
token_start_idx
-
ld_volatile_global
(
rdma_channel_head
.
buffer
(
dst_rdma_rank
));
if
(
num_max_rdma_chunked_recv_tokens
-
num_used_slots
>=
num_chunked_tokens
)
if
(
num_max_rdma_chunked_recv_tokens
-
num_used_slots
>=
num_chunked_tokens
)
break
;
// Timeout check
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
long
long
int
elapsed_time
=
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
dst_rdma_rank
,
ld_volatile_global
(
rdma_channel_head
.
buffer
(
dst_rdma_rank
)),
token_start_idx
,
num_chunked_tokens
);
trap
();
}
}
// sync_large_warp();
sync_large_warp
(
token_start_idx
,
0
);
// Combine and write to the RDMA buffer
for
(
int
token_idx
=
token_start_idx
+
sub_warp_id
;
token_idx
<
token_end_idx
;
token_idx
+=
kNumWarpsPerForwarder
)
{
for
(
int
token_idx
=
token_start_idx
+
sub_warp_id
;
token_idx
<
token_end_idx
;
token_idx
+=
kNumWarpsPerForwarder
)
{
// Read expected head
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"Invalid number of RDMA peers"
);
int
expected_head
=
-
1
;
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
expected_head
=
ld_nc_global
(
combined_nvl_head
+
token_idx
*
NUM_MAX_NVL_PEERS
+
lane_id
);
// Wait lanes to be ready
start_time
=
wall_clock64
();
while
(
cached_nvl_channel_tail_idx
<=
expected_head
)
{
cached_nvl_channel_tail_idx
=
ld_acquire_sys_global
(
nvl_channel_tail
.
buffer
(
lane_id
));
while
(
cached_nvl_channel_tail_idx
<=
expected_head
)
{
cached_nvl_channel_tail_idx
=
ld_relaxed_sys_global
(
nvl_channel_tail
.
buffer
(
lane_id
));
// Timeout check
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
and
lane_id
<
NUM_MAX_NVL_PEERS
)
{
long
long
int
elapsed_time
=
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
and
lane_id
<
NUM_MAX_NVL_PEERS
)
{
printf
(
"DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
dst_rdma_rank
,
cached_nvl_channel_tail_idx
,
token_idx
,
num_tokens_to_combine
,
sub_warp_id
,
kNumWarpsPerForwarder
,
expected_head
);
trap
();
}
__builtin_amdgcn_s_sleep
(
1
);
}
// Combine current token
auto
rdma_slot_idx
=
token_idx
%
num_max_rdma_chunked_recv_tokens
;
void
*
shifted
=
send_buffer
+
rdma_slot_idx
*
num_bytes_per_rdma_token
;
auto
get_addr
_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
*
{
return
reinterpret_cast
<
int4
*>
(
nvl_channel_x
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
hidden_int4
)
+
hidden_int4_idx
;
};
auto
recv
_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
{
return
ld_nc_global
(
nvl_channel_x
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
hidden_int4
+
hidden_int4_idx
)
;
};
auto
recv_tw_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
nvl_channel_topk_weights
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
num_topk
+
topk_idx
);
};
combine_token
<
NUM_MAX_NVL_PEERS
,
dtype_t
,
NUM_MAX_NVL_PEERS
,
true
>
(
expected_head
>=
0
,
combine_token
<
NUM_MAX_NVL_PEERS
,
dtype_t
,
NUM_MAX_NVL_PEERS
,
kWarpHyb
>
(
expected_head
>=
0
,
expected_head
,
lane_id
,
hidden_int4
,
num_topk
,
reinterpret_cast
<
int4
*>
(
shifted
),
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
int8_t
*>
(
shifted
)
+
hidden_bytes
+
sizeof
(
SourceMeta
)),
num_max_nvl_chunked_recv_tokens_per_rdma
,
get_addr_fn
,
recv_tw_fn
);
num_max_nvl_chunked_recv_tokens_per_rdma
,
recv_fn
,
recv_tw_fn
);
// Update head
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
expected_head
<
0
?
(
forwarder_nvl_head
[
target_rank
][
lane_id
]
=
-
expected_head
-
1
)
:
(
forwarder_nvl_head
[
target_rank
][
lane_id
]
=
expected_head
+
1
);
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
expected_head
<
0
?
(
forwarder_nvl_head
[
warp_id
][
lane_id
]
=
-
expected_head
-
1
)
:
(
forwarder_nvl_head
[
warp_id
][
lane_id
]
=
expected_head
+
1
);
}
}
// sync_large_warp();
sync_large_warp
(
token_start_idx
,
1
);
// Issue RDMA send
if
(
sub_warp_id
==
kNumWarpsPerForwarder
-
1
)
{
if
(
dst_rdma_rank
!=
rdma_rank
)
{
// TODO: Switch back to put_nbi_wave function
if
(
sub_warp_id
==
kNumWarpsPerForwarder
-
1
)
{
if
(
dst_rdma_rank
!=
rdma_rank
)
{
auto
rdma_slot_idx
=
token_start_idx
%
num_max_rdma_chunked_recv_tokens
;
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_schar_put_nbi_warp
(
ctx
,
#else
#ifdef FORCE_DUSHMEM_API
shmemx_int8_put_nbi_warp
(
#endif
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
num_chunked_tokens
*
num_bytes_per_rdma_token
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
#else
#if !defined(ROCM_DISABLE_CTX)
if
constexpr
(
kUseWave
){
shmem_ctx_schar_put_nbi_warp
(
ctx
,
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
num_chunked_tokens
*
num_bytes_per_rdma_token
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
#else
if
constexpr
(
kUseWave
){
shmemx_int8_put_nbi_warp
(
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
num_chunked_tokens
*
num_bytes_per_rdma_token
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
else
{
if
(
lane_id
==
0
)
shmemx_int8_put_nbi
(
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
num_chunked_tokens
*
num_bytes_per_rdma_token
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
#endif
#endif
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet
(
ctx
);
#else
...
...
@@ -1765,7 +1775,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Write new RDMA tail
syncwarp
();
if
(
lane_id
==
0
)
{
if
(
lane_id
==
0
)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add
(
ctx
,
#else
...
...
@@ -1775,157 +1785,90 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
}
}
// Retired
syncwarp
();
if
(
lane_id
==
0
)
{
forwarder_retired
[
target_rank
]
=
true
;
}
}
else
if
(
warp_role
==
WarpRole
::
kRDMACoordinator
)
{
// Coordinator
// Sync shared memory status
// sync_forwarder_smem();
__syncthreads
();
constexpr
int
num_warps_per_rdma_rank
=
kNumForwarders
/
kNumRDMARanks
;
int
last_nvl_head
[
kNumRDMARanks
]
=
{
0
};
int
dst_nvl_rank
=
lane_id
<
NUM_MAX_NVL_PEERS
?
lane_id
:
0
;
while
(
true
)
{
// Retired
if
(
__all_sync
(
kFullWarpMask
,
lane_id
>=
kNumForwarders
or
forwarder_retired
[
lane_id
]))
break
;
{
// Find minimum head for NVL ranks
#pragma unroll
for
(
int
i
=
0
;
i
<
kNumRDMARanks
;
++
i
)
{
int
min_head
=
std
::
numeric_limits
<
int
>::
max
();
#pragma unroll
for
(
int
j
=
0
;
j
<
num_warps_per_rdma_rank
;
++
j
)
if
(
not
forwarder_retired
[
i
*
num_warps_per_rdma_rank
+
j
])
min_head
=
min
(
min_head
,
forwarder_nvl_head
[
i
*
num_warps_per_rdma_rank
+
j
][
dst_nvl_rank
]);
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>
last_nvl_head
[
i
]
and
lane_id
<
NUM_MAX_NVL_PEERS
)
{
st_relaxed_sys_global
(
nvl_channel_head
.
buffer_by
(
dst_nvl_rank
)
+
i
,
last_nvl_head
[
i
]
=
min_head
);
}
}
}
// Nanosleep and let other warps work
__builtin_amdgcn_s_sleep
(
NUM_WAIT_CYCLES_TIMES_64
);
}
}
else
if
(
warp_role
==
WarpRole
::
kRDMAReceiver
)
{
if
(
lane_id
==
0
)
forwarder_retired
[
warp_id
]
=
true
;
}
else
if
(
warp_role
==
WarpRole
::
kRDMAReceiver
)
{
// Receive from RDMA ranks and write to the output tensor
// Clean shared memory and sync
EP_DEVICE_ASSERT
(
kNumRDMARanks
<=
kWarpSize
);
lane_id
<
kNumRDMARanks
?
(
rdma_receiver_rdma_head
[
target_rank
][
lane_id
]
=
0
)
:
0
;
lane_id
==
0
?
(
rdma_receiver_retired
[
target_rank
]
=
false
)
:
0
;
// sync_rdma_receiver_smem();
__syncthreads
();
lane_id
<
kNumRDMARanks
?
(
rdma_receiver_rdma_head
[
warp_id
][
lane_id
]
=
0
)
:
0
;
lane_id
==
0
?
(
rdma_receiver_retired
[
warp_id
]
=
false
)
:
0
;
sync_rdma_receiver_smem
();
// The same tokens as the dispatch process
int
token_start_idx
,
token_end_idx
;
get_channel_task_range
(
num_combined_tokens
,
num_channels
,
channel_id
,
token_start_idx
,
token_end_idx
);
// ==================== Token 级展开 x4 ====================
constexpr
int
kTokenUnroll
=
4
;
// Iterate over all tokens and combine
int
cached_channel_tail_idx
=
0
;
for
(
int64_t
base
=
token_start_idx
+
target_rank
;
base
<
token_end_idx
;
base
+=
(
int64_t
)
kNumRDMAReceivers
*
kTokenUnroll
)
{
// ---- Phase 1: 批量预取所有 token 的 expected_head ----
int
cached_expected_head
[
kTokenUnroll
];
int
max_expected_head
=
-
1
;
#pragma unroll
for
(
int
u
=
0
;
u
<
kTokenUnroll
;
++
u
)
{
int64_t
tidx
=
base
+
(
int64_t
)
u
*
kNumRDMAReceivers
;
cached_expected_head
[
u
]
=
-
1
;
if
(
tidx
<
token_end_idx
&&
lane_id
<
kNumRDMARanks
)
{
int
expected_head
=
ld_nc_global
(
combined_rdma_head
+
tidx
*
kNumRDMARanks
+
lane_id
);
cached_expected_head
[
u
]
=
expected_head
;
if
(
expected_head
>
max_expected_head
)
max_expected_head
=
expected_head
;
}
for
(
int64_t
token_idx
=
token_start_idx
+
warp_id
;
token_idx
<
token_end_idx
;
token_idx
+=
kNumRDMAReceivers
)
{
// Read expected head
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"Invalid number of RDMA peers"
);
int
expected_head
=
-
1
;
if
(
lane_id
<
kNumRDMARanks
)
{
expected_head
=
ld_nc_global
(
combined_rdma_head
+
token_idx
*
kNumRDMARanks
+
lane_id
);
(
expected_head
<
0
)
?
(
rdma_receiver_rdma_head
[
warp_id
][
lane_id
]
=
-
expected_head
-
1
)
:
(
rdma_receiver_rdma_head
[
warp_id
][
lane_id
]
=
expected_head
);
}
// ---- Phase 2: 一次等待,覆盖所有 token ----
if
(
max_expected_head
>=
0
)
{
// Wait lanes to be ready
auto
start_time
=
wall_clock64
();
while
(
cached_channel_tail_idx
<=
max_expected_head
)
{
cached_channel_tail_idx
=
static_cast
<
int
>
(
ld_acquire_sys_global
(
rdma_channel_tail
.
buffer
(
lane_id
)));
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP combine RDMA receiver timeout (unroll x%d), "
"ch: %d, rdma: %d, nvl: %d, lane: %d, "
"tail: %d, wait: %d
\n
"
,
kTokenUnroll
,
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
cached_channel_tail_idx
,
max_expected_head
);
while
(
cached_channel_tail_idx
<=
expected_head
)
{
cached_channel_tail_idx
=
static_cast
<
int
>
(
ld_acquire_sys_global
(
rdma_channel_tail
.
buffer
(
lane_id
)));
// Timeout check
long
long
int
elapsed_time
=
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
cached_channel_tail_idx
,
token_idx
,
expected_head
);
trap
();
}
}
__builtin_amdgcn_s_sleep
(
1
);
}
syncwarp
();
// ---- Phase 3: 批量处理所有就绪 token ----
#pragma unroll
for
(
int
u
=
0
;
u
<
kTokenUnroll
;
++
u
)
{
int64_t
tidx
=
base
+
(
int64_t
)
u
*
kNumRDMAReceivers
;
if
(
tidx
<
token_end_idx
)
{
int
expected_head
=
cached_expected_head
[
u
];
// Combine current token
auto
get_addr
_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
*
{
return
reinterpret_cast
<
int4
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
)
+
hidden_int4_idx
;
};
auto
recv
_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
{
return
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
)
+
hidden_int4_idx
)
;};
auto
recv_tw_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
reinterpret_cast
<
const
float
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
+
hidden_bytes
+
sizeof
(
SourceMeta
))
+
topk_idx
);};
combine_token
<
kNumRDMARanks
,
dtype_t
,
kNumTopkRDMARanks
,
false
>
(
expected_head
>=
0
,
expected_head
,
lane_id
,
combine_token
<
kNumRDMARanks
,
dtype_t
,
kNumTopkRDMARanks
,
kWarpHyb
>
(
expected_head
>=
0
,
expected_head
,
lane_id
,
hidden_int4
,
num_topk
,
combined_x
+
tidx
*
hidden_int4
,
combined_topk_weights
+
tidx
*
num_topk
,
num_max_rdma_chunked_recv_tokens
,
get_addr_fn
,
recv_tw_fn
);
if
(
lane_id
<
kNumRDMARanks
)
{
rdma_receiver_rdma_head
[
target_rank
][
lane_id
]
=
expected_head
<
0
?
-
expected_head
-
1
:
expected_head
;
}
}
}
combined_x
+
token_idx
*
hidden_int4
,
combined_topk_weights
+
token_idx
*
num_topk
,
num_max_rdma_chunked_recv_tokens
,
recv_fn
,
recv_tw_fn
);
}
// Retired
syncwarp
();
if
(
lane_id
==
0
)
{
rdma_receiver_retired
[
t
ar
get_rank
]
=
true
;
}
}
else
if
(
warp_role
==
WarpRole
::
kNVLCoordinator
)
{
if
(
lane_id
==
0
)
rdma_receiver_retired
[
w
ar
p_id
]
=
true
;
}
else
{
auto
lane_id
=
get_lane_id
();
// Coordinator
// Sync shared memory status
// sync_rdma_receiver_smem();
__syncthreads
();
is_rdma_receiver_sm
?
sync_rdma_receiver_smem
()
:
sync_forwarder_smem
();
const
auto
num_warps_per_rdma_rank
=
kNumForwarders
/
kNumRDMARanks
;
int
last_rdma_head
=
0
;
int
last_nvl_head
[
kNumRDMARanks
]
=
{
0
};
int
dst_rdma_rank
=
lane_id
<
kNumRDMARanks
?
lane_id
:
0
;
int
dst_nvl_rank
=
lane_id
<
NUM_MAX_NVL_PEERS
?
lane_id
:
0
;
while
(
true
)
{
EP_STATIC_ASSERT
(
kNumCombineForwarderWarps
<=
kWarpSize
,
"Invalid number of forwarder warps"
);
while
(
true
)
{
// Retired
if
(
__all_sync
(
kFullWarpMask
,
lane_id
>=
kNumRDMAReceivers
or
rdma_receiver_retired
[
lane_id
]))
if
(
is_rdma_receiver_sm
and
__all_sync
(
kFullWarpMask
,
lane_id
>=
kNumRDMAReceivers
or
rdma_receiver_retired
[
lane_id
]))
break
;
if
(
not
is_rdma_receiver_sm
and
__all_sync
(
kFullWarpMask
,
lane_id
>=
kNumForwarders
or
forwarder_retired
[
lane_id
]))
break
;
// Find minimum head for RDMA ranks
{
if
(
is_rdma_receiver_sm
)
{
int
min_head
=
std
::
numeric_limits
<
int
>::
max
();
#pragma unroll
for
(
int
i
=
0
;
i
<
kNumRDMAReceivers
;
++
i
)
if
(
not
rdma_receiver_retired
[
i
])
for
(
int
i
=
0
;
i
<
kNumRDMAReceivers
;
++
i
)
if
(
not
rdma_receiver_retired
[
i
])
min_head
=
min
(
min_head
,
rdma_receiver_rdma_head
[
i
][
dst_rdma_rank
]);
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>=
last_rdma_head
+
num_max_rdma_chunked_send_tokens
and
lane_id
<
kNumRDMARanks
)
{
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add
(
ctx
,
...
...
@@ -1937,6 +1880,17 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
last_rdma_head
=
min_head
;
}
}
else
{
// Find minimum head for NVL ranks
#pragma unroll
for
(
int
i
=
0
;
i
<
kNumRDMARanks
;
++
i
)
{
int
min_head
=
std
::
numeric_limits
<
int
>::
max
();
#pragma unroll
for
(
int
j
=
0
;
j
<
num_warps_per_rdma_rank
;
++
j
)
if
(
not
forwarder_retired
[
i
*
num_warps_per_rdma_rank
+
j
])
min_head
=
min
(
min_head
,
forwarder_nvl_head
[
i
*
num_warps_per_rdma_rank
+
j
][
dst_nvl_rank
]);
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>
last_nvl_head
[
i
]
and
lane_id
<
NUM_MAX_NVL_PEERS
)
st_relaxed_sys_global
(
nvl_channel_head
.
buffer_by
(
dst_nvl_rank
)
+
i
,
last_nvl_head
[
i
]
=
min_head
);
}
}
// Nanosleep and let other warps work
...
...
@@ -1949,46 +1903,80 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
#endif
}
void
combine
(
hipDataType
type
,
void
*
combined_x
,
float
*
combined_topk_weights
,
const
bool
*
is_combined_token_in_rank
,
const
void
*
x
,
const
float
*
topk_weights
,
const
void
*
bias_0
,
const
void
*
bias_1
,
const
int
*
combined_rdma_head
,
const
int
*
combined_nvl_head
,
const
void
*
src_meta
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
int
num_tokens
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_send_tokens
,
int
num_max_rdma_chunked_recv_tokens
,
void
**
buffer_ptrs
,
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
int
rank
,
int
num_ranks
,
hipStream_t
stream
,
int
num_channels
,
bool
low_latency_mode
)
{
constexpr
int
kNumCombineForwarderWarps
=
8
;
#define COMBINE_LAUNCH_CASE(num_rdma_ranks) \
{ \
auto combine_func = \
low_latency_mode \
? combine<true, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps> \
: combine<false, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps>; \
LAUNCH_KERNEL_NON_COOPERATIVE( \
&cfg, combine_func, reinterpret_cast<int4 *>(combined_x), combined_topk_weights, \
is_combined_token_in_rank, reinterpret_cast<const int4 *>(x), topk_weights, \
reinterpret_cast<const int4 *>(bias_0), reinterpret_cast<const int4 *>(bias_1), \
combined_rdma_head, combined_nvl_head, reinterpret_cast<const SourceMeta *>(src_meta), \
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
num_tokens, num_combined_tokens, hidden, num_topk, rdma_buffer_ptr, \
num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs, \
num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks); \
} \
break
int
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
auto
num_warps_per_forwarder
=
std
::
max
(
kNumCombineForwarderWarps
/
num_rdma_ranks
,
1
);
int
num_forwarder_warps
=
num_rdma_ranks
*
num_warps_per_forwarder
;
EP_HOST_ASSERT
(
num_forwarder_warps
>=
NUM_MAX_NVL_PEERS
);
EP_HOST_ASSERT
(
num_forwarder_warps
>
0
and
num_forwarder_warps
%
num_rdma_ranks
==
0
);
void
combine
(
hipDataType
type
,
void
*
combined_x
,
float
*
combined_topk_weights
,
const
bool
*
is_combined_token_in_rank
,
const
void
*
x
,
const
float
*
topk_weights
,
const
void
*
bias_0
,
const
void
*
bias_1
,
const
int
*
combined_rdma_head
,
const
int
*
combined_nvl_head
,
const
void
*
src_meta
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
int
num_tokens
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_send_tokens
,
int
num_max_rdma_chunked_recv_tokens
,
void
**
buffer_ptrs
,
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
int
rank
,
int
num_ranks
,
hipStream_t
stream
,
int
num_channels
,
bool
low_latency_mode
)
{
const
int
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
EP_HOST_ASSERT
(
num_rdma_ranks
>
0
);
EP_HOST_ASSERT
(
num_max_nvl_chunked_recv_tokens
%
num_rdma_ranks
==
0
);
EP_HOST_ASSERT
(
num_max_nvl_chunked_recv_tokens
/
num_rdma_ranks
>
std
::
max
(
num_max_rdma_chunked_send_tokens
,
num_max_nvl_chunked_send_tokens
));
EP_HOST_ASSERT
(
type
==
HIP_R_16BF
);
SETUP_LAUNCH_CONFIG
(
num_channels
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
,
(
NUM_MAX_NVL_PEERS
+
1
)
*
kWarpSize
,
stream
);
// One case per compile-time NR specialization.
#define COMBINE_LAUNCH_CASE(NR) { \
/* Per-case compile-time constants */
\
constexpr int kNumCombineForwarderWarps = (NR < 9) ? 10 : 16; \
constexpr int kWarpsPerForwarder = (kNumCombineForwarderWarps/NR) > 0 \
? (kNumCombineForwarderWarps/NR) : 1; \
constexpr int kNumForwarders = NR * kWarpsPerForwarder; \
constexpr int kBlockThreads = (NR > 8) \
? ((NUM_MAX_NVL_PEERS + kNumForwarders) * kEmulatedWarpSize + kWarpSize) \
: ((NUM_MAX_NVL_PEERS/2 + 1 + kNumForwarders) * kWarpSize); \
\
SETUP_LAUNCH_CONFIG(num_channels * 2, kBlockThreads, stream); \
\
using scalar_t = hip_bfloat16; \
auto fn = low_latency_mode \
? combine<true, NR, scalar_t, kNumCombineForwarderWarps> \
: combine<false, NR, scalar_t, kNumCombineForwarderWarps>; \
\
/* Launch (backend-specific) */
\
\
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, fn, \
reinterpret_cast<int4*>(combined_x), combined_topk_weights, is_combined_token_in_rank, \
reinterpret_cast<const int4*>(x), topk_weights, \
reinterpret_cast<const int4 *>(bias_0), reinterpret_cast<const int4 *>(bias_1), \
combined_rdma_head, combined_nvl_head, \
reinterpret_cast<const SourceMeta*>(src_meta), rdma_channel_prefix_matrix, \
rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
num_tokens, num_combined_tokens, hidden, num_topk, \
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, \
num_max_rdma_chunked_recv_tokens, \
buffer_ptrs, num_max_nvl_chunked_send_tokens, \
num_max_nvl_chunked_recv_tokens, \
rank, num_ranks); \
} break
// Dispatch on the runtime num_rdma_ranks, but each case is compile-time specialized.
SWITCH_RDMA_RANKS
(
COMBINE_LAUNCH_CASE
);
#undef COMBINE_LAUNCH_CASE
}
...
...
@@ -1997,8 +1985,4 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
}
// namespace deep_ep
// #ifdef __clang__
// #pragma clang diagnostic pop
// #endif // __clang__
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment