Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
1a35d640
Commit
1a35d640
authored
May 21, 2026
by
root
Browse files
fix dtk26.04 4nodes core dump.
Signed-off-by:
root
<
root@host-10-212-17-3.cluster.local
>
parent
95e46992
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
936 additions
and
952 deletions
+936
-952
csrc/config.hpp
csrc/config.hpp
+2
-2
csrc/deep_ep.cu
csrc/deep_ep.cu
+4
-4
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+930
-946
No files found.
csrc/config.hpp
View file @
1a35d640
...
@@ -47,7 +47,7 @@ struct Config {
...
@@ -47,7 +47,7 @@ struct Config {
EP_HOST_ASSERT
(
num_ranks
<=
NUM_MAX_NVL_PEERS
or
num_sms
%
(
2
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
)
==
0
);
EP_HOST_ASSERT
(
num_ranks
<=
NUM_MAX_NVL_PEERS
or
num_sms
%
(
2
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
)
==
0
);
const
auto
num_rdma_ranks
=
std
::
max
(
num_ranks
/
NUM_MAX_NVL_PEERS
,
1
);
const
auto
num_rdma_ranks
=
std
::
max
(
num_ranks
/
NUM_MAX_NVL_PEERS
,
1
);
const
auto
num_nvl_ranks
=
std
::
min
(
num_ranks
,
NUM_MAX_NVL_PEERS
);
const
auto
num_nvl_ranks
=
std
::
min
(
num_ranks
,
NUM_MAX_NVL_PEERS
);
const
int
num_channels
=
num_
ranks
<=
8
?
num_sms
/
2
:
num_sms
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
const
int
num_channels
=
num_
sms
/
2
;
// 计算每个nvl通信数据包的数据量
// 计算每个nvl通信数据包的数据量
size_t
num_single_nvl_bag_bytes
=
size_t
num_single_nvl_bag_bytes
=
...
@@ -83,7 +83,7 @@ struct Config {
...
@@ -83,7 +83,7 @@ struct Config {
EP_HOST_ASSERT
(
num_ranks
%
NUM_MAX_NVL_PEERS
==
0
);
EP_HOST_ASSERT
(
num_ranks
%
NUM_MAX_NVL_PEERS
==
0
);
EP_HOST_ASSERT
(
num_sms
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
);
EP_HOST_ASSERT
(
num_sms
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
);
const
int
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
const
int
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
const
int
num_channels
=
num_sms
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
const
int
num_channels
=
num_sms
/
2
;
// 计算每个rdma通信数据包的数据量
// 计算每个rdma通信数据包的数据量
size_t
num_single_rdma_bag_bytes
=
size_t
num_single_rdma_bag_bytes
=
...
...
csrc/deep_ep.cu
View file @
1a35d640
...
@@ -809,8 +809,8 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
...
@@ -809,8 +809,8 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
// here.
// here.
pybind11
::
gil_scoped_release
release
;
pybind11
::
gil_scoped_release
release
;
const
int
num_channels
=
config
.
num_sms
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
const
int
num_channels
=
config
.
num_sms
/
2
;
EP_HOST_ASSERT
(
config
.
num_sms
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
);
//
EP_HOST_ASSERT(config.num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0);
EP_HOST_ASSERT
(
0
<
get_num_rdma_ranks
()
and
get_num_rdma_ranks
()
<=
NUM_MAX_RDMA_PEERS
);
EP_HOST_ASSERT
(
0
<
get_num_rdma_ranks
()
and
get_num_rdma_ranks
()
<=
NUM_MAX_RDMA_PEERS
);
bool
cached_mode
=
cached_rdma_channel_prefix_matrix
.
has_value
();
bool
cached_mode
=
cached_rdma_channel_prefix_matrix
.
has_value
();
...
@@ -1130,8 +1130,8 @@ Buffer::internode_combine(
...
@@ -1130,8 +1130,8 @@ Buffer::internode_combine(
const
torch
::
Tensor
&
combined_nvl_head
,
const
Config
&
config
,
const
torch
::
Tensor
&
combined_nvl_head
,
const
Config
&
config
,
std
::
optional
<
EventHandle
>
&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
)
{
std
::
optional
<
EventHandle
>
&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
)
{
#ifndef DISABLE_ROCSHMEM
#ifndef DISABLE_ROCSHMEM
const
int
num_channels
=
config
.
num_sms
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
const
int
num_channels
=
config
.
num_sms
/
2
;
EP_HOST_ASSERT
(
config
.
num_sms
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
);
//
EP_HOST_ASSERT(config.num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0);
// Shape and contiguous checks
// Shape and contiguous checks
EP_HOST_ASSERT
(
x
.
dim
()
==
2
and
x
.
is_contiguous
());
EP_HOST_ASSERT
(
x
.
dim
()
==
2
and
x
.
is_contiguous
());
...
...
csrc/kernels/internode.cu
View file @
1a35d640
...
@@ -7,17 +7,6 @@
...
@@ -7,17 +7,6 @@
#ifndef DISABLE_ROCSHMEM
#ifndef DISABLE_ROCSHMEM
// 安全检查:确保宏已定义
#ifndef HIP_VERSION_PATCH
#error "HIP_VERSION_PATCH not defined! Check your HIP installation."
#endif
// TODO: fix unroll warnings
// #ifdef __clang__
// #pragma clang diagnostic push
// #pragma clang diagnostic ignored "-Wpass-failed"
// #pragma clang diagnostic ignored "-Wdeprecated-volatile"
// #endif // __clang__
namespace
deep_ep
{
namespace
deep_ep
{
namespace
internode
{
namespace
internode
{
...
@@ -25,7 +14,7 @@ namespace internode {
...
@@ -25,7 +14,7 @@ namespace internode {
extern
shmem_team_t
cpu_rdma_team
;
extern
shmem_team_t
cpu_rdma_team
;
struct
SourceMeta
{
struct
SourceMeta
{
int
src_rdma_rank
,
is_token_in_nvl_rank_bits
;
// sizeof(SourceMeta) = 8
int
src_rdma_rank
,
is_token_in_nvl_rank_bits
;
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
==
8
,
"Invalid number of maximum NVL peers"
);
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
==
8
,
"Invalid number of maximum NVL peers"
);
...
@@ -60,18 +49,16 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_
...
@@ -60,18 +49,16 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
get_rdma_clean_meta
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
get_rdma_clean_meta
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
int
num_rdma_ranks
,
int
num_rdma_recv_buffer_tokens
,
int
num_channels
)
{
int
num_rdma_ranks
,
int
num_rdma_recv_buffer_tokens
,
int
num_sms
)
{
// Return `int32_t` offset and count to clean
// Return `int32_t` offset and count to clean
return
{(
get_num_bytes_per_rdma_token
(
hidden_int4
,
num_scales
,
num_topk_idx
,
num_topk_weights
)
*
return
{(
get_num_bytes_per_rdma_token
(
hidden_int4
,
num_scales
,
num_topk_idx
,
num_topk_weights
)
*
num_rdma_recv_buffer_tokens
*
num_rdma_ranks
*
2
*
num_
channel
s
)
/
sizeof
(
int
),
num_rdma_recv_buffer_tokens
*
num_rdma_ranks
*
2
*
num_
sm
s
)
/
sizeof
(
int
),
(
NUM_MAX_NVL_PEERS
*
2
+
4
)
*
num_rdma_ranks
*
2
*
num_
channel
s
};
(
NUM_MAX_NVL_PEERS
*
2
+
4
)
*
num_rdma_ranks
*
2
*
num_
sm
s
};
}
}
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
get_nvl_clean_meta
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
get_nvl_clean_meta
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
int
num_rdma_ranks
,
int
num_nvl_ranks
,
int
num_nvl_recv_buffer_tokens
,
int
num_rdma_ranks
,
int
num_nvl_ranks
,
int
num_nvl_recv_buffer_tokens
,
int
num_
channel
s
)
{
int
num_
sm
s
)
{
// Return `int32_t` offset and to clean
// Return `int32_t` offset and to clean
EP_STATIC_ASSERT
(
sizeof
(
SourceMeta
)
%
sizeof
(
int
)
==
0
,
EP_STATIC_ASSERT
(
sizeof
(
SourceMeta
)
%
sizeof
(
int
)
==
0
,
"Invalid size of `SourceMeta`"
);
"Invalid size of `SourceMeta`"
);
...
@@ -79,8 +66,8 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to
...
@@ -79,8 +66,8 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to
(
num_nvl_recv_buffer_tokens
*
(
num_nvl_recv_buffer_tokens
*
(
hidden_int4
*
sizeof
(
int4
)
+
num_scales
*
sizeof
(
float
)
+
num_topk_idx
*
sizeof
(
int
)
+
(
hidden_int4
*
sizeof
(
int4
)
+
num_scales
*
sizeof
(
float
)
+
num_topk_idx
*
sizeof
(
int
)
+
num_topk_weights
*
sizeof
(
float
)
+
sizeof
(
SourceMeta
))
*
num_topk_weights
*
sizeof
(
float
)
+
sizeof
(
SourceMeta
))
*
num_nvl_ranks
*
num_
channel
s
)
/
sizeof
(
int
),
num_nvl_ranks
*
num_
sm
s
)
/
sizeof
(
int
),
num_nvl_ranks
*
(
2
*
num_rdma_ranks
+
2
)
*
num_
channel
s
,
num_nvl_ranks
*
(
2
*
num_rdma_ranks
+
2
)
*
num_
sm
s
,
};
};
}
}
...
@@ -92,9 +79,10 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
...
@@ -92,9 +79,10 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
template
<
bool
kLowLatencyMode
>
template
<
bool
kLowLatencyMode
>
__forceinline__
__device__
void
__forceinline__
__device__
void
du
shmem_
barrier
_with_same_gpu_idx
(
const
shmem_team_t
&
rdma_team
)
{
shmem_
sync
_with_same_gpu_idx
(
const
shmem_team_t
&
rdma_team
)
{
// NOTE: shmem_device_barrier_all() might be an issue as
// NOTE: shmem_device_barrier_all() might be an issue as
// it doesn't follow OpenSHMEM specification on ROCm
// it doesn't follow OpenSHMEM specification on ROCm
// kLowLatencyMode ? shmem_device_sync(rdma_team) : shmem_device_sync_all();
kLowLatencyMode
?
shmem_barrier
(
rdma_team
)
:
shmem_device_barrier_all
();
kLowLatencyMode
?
shmem_barrier
(
rdma_team
)
:
shmem_device_barrier_all
();
}
}
...
@@ -123,7 +111,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
...
@@ -123,7 +111,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
// Communication with others
// Communication with others
// Global barrier: the first warp do intra-node sync, the second warp do internode sync
// Global barrier: the first warp do intra-node sync, the second warp do internode sync
if
(
thread_id
==
kWarpSize
)
if
(
thread_id
==
kWarpSize
)
du
shmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
shmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
...
@@ -175,7 +163,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
...
@@ -175,7 +163,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
__syncthreads
();
__syncthreads
();
if
(
thread_id
==
0
)
if
(
thread_id
==
0
)
du
shmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
shmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
__syncthreads
();
__syncthreads
();
...
@@ -266,7 +254,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
...
@@ -266,7 +254,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
// Finally barrier
// Finally barrier
if
(
thread_id
==
kWarpSize
)
if
(
thread_id
==
kWarpSize
)
du
shmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
shmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
}
else
{
}
else
{
...
@@ -383,759 +371,754 @@ constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {
...
@@ -383,759 +371,754 @@ constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {
return
num_rdma_ranks
<
8
?
num_rdma_ranks
:
8
;
return
num_rdma_ranks
<
8
?
num_rdma_ranks
:
8
;
}
}
template
<
bool
kLowLatencyMode
,
int
kNumRDMARanks
,
bool
kCachedMode
,
template
<
bool
kLowLatencyMode
,
int
kNumRDMARanks
,
bool
kCachedMode
,
int
kNumDispatchRDMASenderWarps
,
int
kNumDispatchRDMASenderWarps
,
int
kNumTopkRDMARanks
=
get_num_topk_rdma_ranks
(
kNumRDMARanks
)>
int
kNumTopkRDMARanks
=
get_num_topk_rdma_ranks
(
kNumRDMARanks
)>
__global__
void
__launch_bounds__
(((
1
+
NUM_MAX_NVL_PEERS
)
*
kWarpSize
),
1
)
__global__
void
dispatch
(
int4
*
recv_x
,
float
*
recv_x_scales
,
int64_t
*
recv_topk_idx
,
float
*
recv_topk_weights
,
__launch_bounds__
(((
kNumDispatchRDMASenderWarps
+
1
+
NUM_MAX_NVL_PEERS
)
*
kWarpSize
),
1
)
SourceMeta
*
recv_src_meta
,
const
int4
*
x
,
const
float
*
x_scales
,
dispatch
(
int4
*
recv_x
,
float
*
recv_x_scales
,
int64_t
*
recv_topk_idx
,
float
*
recv_topk_weights
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
int
*
send_rdma_head
,
SourceMeta
*
recv_src_meta
,
const
int4
*
x
,
const
float
*
x_scales
,
int
*
send_nvl_head
,
int
*
recv_rdma_channel_prefix_matrix
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
int
*
send_rdma_head
,
int
*
recv_gbl_channel_prefix_matrix
,
const
int
*
rdma_channel_prefix_matrix
,
int
*
send_nvl_head
,
int
*
recv_rdma_channel_prefix_matrix
,
const
int
*
recv_rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
int
*
recv_gbl_channel_prefix_matrix
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
recv_gbl_rank_prefix_sum
,
const
bool
*
is_token_in_rank
,
int
num_tokens
,
const
int
*
recv_rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
int
hidden_int4
,
int
num_scales
,
int
num_topk
,
int
num_experts
,
int
scale_token_stride
,
const
int
*
recv_gbl_rank_prefix_sum
,
const
bool
*
is_token_in_rank
,
int
num_tokens
,
int
scale_hidden_stride
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_send_tokens
,
int
hidden_int4
,
int
num_scales
,
int
num_topk
,
int
num_experts
,
int
scale_token_stride
,
int
num_max_rdma_chunked_recv_tokens
,
void
**
buffer_ptrs
,
int
scale_hidden_stride
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_send_tokens
,
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
int
rank
,
int
num_max_rdma_chunked_recv_tokens
,
void
**
buffer_ptrs
,
int
num_ranks
)
{
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
int
rank
,
int
num_ranks
)
{
enum
class
WarpRole
{
enum
class
WarpRole
{
kRDMASender
,
// 从x写入到RDMA发送缓存
kRDMASender
,
kRDMASenderCoordinator
,
// 从RDMA发送缓存写入到远端rdma_rank接收缓存
kRDMASenderCoordinator
,
kRDMAAndNVLForwarder
,
// 从RDMA接收缓存转写到ipc nvl缓存
kRDMAAndNVLForwarder
,
kForwarderCoordinator
,
// 向远端RDMA确认接收
kForwarderCoordinator
,
kNVLReceivers
// 从nvl缓存写入到recv_x
kNVLReceivers
};
};
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
__shared__
shmem_ctx_t
ctx
;
__shared__
rocshmem
::
rocshmem_ctx_t
ctx
;
shmem_wg_ctx_create
(
&
ctx
);
rocshmem
::
rocshmem_wg_ctx_create
(
0
,
&
ctx
);
#endif
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
),
num_warps
=
num_threads
/
kWarpSize
;
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
),
num_warps
=
num_threads
/
kWarpSize
;
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
),
warp_id
=
thread_id
/
kWarpSize
,
lane_id
=
get_lane_id
();
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
),
warp_id
=
thread_id
/
kWarpSize
,
const
auto
num_channels
=
static_cast
<
int
>
(
gridDim
.
x
)
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
,
lane_id
=
get_lane_id
();
channel_id
=
sm_id
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
const
auto
num_channels
=
static_cast
<
int
>
(
gridDim
.
x
)
/
2
,
channel_id
=
sm_id
/
2
;
const
bool
is_forwarder
=
sm_id
%
2
==
0
;
const
auto
rdma_rank
=
rank
/
NUM_MAX_NVL_PEERS
,
nvl_rank
=
rank
%
NUM_MAX_NVL_PEERS
;
const
auto
rdma_rank
=
rank
/
NUM_MAX_NVL_PEERS
,
nvl_rank
=
rank
%
NUM_MAX_NVL_PEERS
;
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
*
sizeof
(
bool
)
==
sizeof
(
uint64_t
),
"Invalid number of NVL peers"
);
EP_DEVICE_ASSERT
(
num_warps
==
1
+
NUM_MAX_NVL_PEERS
);
const
auto
role_meta
=
[
=
]()
->
std
::
pair
<
WarpRole
,
int
>
{
const
auto
role_meta
=
[
=
]()
->
std
::
pair
<
WarpRole
,
int
>
{
if
(
sm_id
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
)
{
if
(
is_forwarder
)
{
if
(
warp_id
<
kNumDispatchRDMASenderWarps
)
{
if
(
warp_id
<
NUM_MAX_NVL_PEERS
)
{
return
{
WarpRole
::
kRDMASender
,
-
1
};
}
else
if
(
warp_id
==
kNumDispatchRDMASenderWarps
)
{
return
{
WarpRole
::
kRDMASenderCoordinator
,
-
1
};
}
}
else
if
(
sm_id
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
1
)
{
if
(
warp_id
<
NUM_MAX_NVL_PEERS
)
{
return
{
WarpRole
::
kRDMAAndNVLForwarder
,
(
warp_id
+
channel_id
)
%
NUM_MAX_NVL_PEERS
};
return
{
WarpRole
::
kRDMAAndNVLForwarder
,
(
warp_id
+
channel_id
)
%
NUM_MAX_NVL_PEERS
};
}
else
{
}
else
{
return
{
WarpRole
::
kForwarderCoordinator
,
warp_id
-
NUM_MAX_NVL_PEERS
};
return
{
WarpRole
::
kForwarderCoordinator
,
warp_id
-
NUM_MAX_NVL_PEERS
};
}
}
}
else
if
(
warp_id
<
kNumDispatchRDMASenderWarps
)
{
return
{
WarpRole
::
kRDMASender
,
-
1
};
}
else
if
(
warp_id
==
kNumDispatchRDMASenderWarps
)
{
return
{
WarpRole
::
kRDMASenderCoordinator
,
-
1
};
}
else
{
}
else
{
return
{
WarpRole
::
kNVLReceivers
,
(
warp_id
+
channel_id
+
1
)
%
NUM_MAX_NVL_PEERS
};
return
{
WarpRole
::
kNVLReceivers
,
(
warp_id
+
channel_id
-
kNumDispatchRDMASenderWarps
)
%
NUM_MAX_NVL_PEERS
};
}
}
}();
}();
auto
warp_role
=
role_meta
.
first
;
auto
warp_role
=
role_meta
.
first
;
auto
target_rank
=
role_meta
.
second
;
// Not applicable for RDMA senders
auto
target_rank
=
role_meta
.
second
;
// Not applicable for RDMA senders
EP_DEVICE_ASSERT
(
num_warps
==
kNumDispatchRDMASenderWarps
+
1
+
NUM_MAX_NVL_PEERS
);
// Data checks
EP_DEVICE_ASSERT
(
num_topk
<=
kWarpSize
);
// RDMA symmetric layout
// RDMA symmetric layout
auto
hidden_bytes
=
hidden_int4
*
sizeof
(
int4
);
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
*
sizeof
(
bool
)
==
sizeof
(
uint64_t
),
auto
num_bytes_per_rdma_token
=
get_num_bytes_per_rdma_token
(
hidden_int4
,
num_scales
,
num_topk
,
num_topk
);
"Invalid number of NVL peers"
);
auto
rdma_channel_data
=
SymBuffer
<
int8_t
>
(
rdma_buffer_ptr
,
num_max_rdma_chunked_recv_tokens
*
num_bytes_per_rdma_token
,
kNumRDMARanks
,
channel_id
,
num_channels
);
auto
hidden_bytes
=
hidden_int4
*
sizeof
(
int4
);
auto
rdma_channel_meta
=
SymBuffer
<
int
>
(
rdma_buffer_ptr
,
NUM_MAX_NVL_PEERS
*
2
+
2
,
kNumRDMARanks
,
channel_id
,
num_channels
);
auto
num_bytes_per_rdma_token
=
auto
rdma_channel_head
=
SymBuffer
<
uint64_t
,
false
>
(
rdma_buffer_ptr
,
1
,
kNumRDMARanks
,
channel_id
,
num_channels
);
get_num_bytes_per_rdma_token
(
hidden_int4
,
num_scales
,
num_topk
,
num_topk
);
auto
rdma_channel_tail
=
SymBuffer
<
uint64_t
,
false
>
(
rdma_buffer_ptr
,
1
,
kNumRDMARanks
,
channel_id
,
num_channels
);
auto
rdma_channel_data
=
SymBuffer
<
int8_t
>
(
rdma_buffer_ptr
,
num_max_rdma_chunked_recv_tokens
*
num_bytes_per_rdma_token
,
kNumRDMARanks
,
channel_id
,
num_channels
);
auto
rdma_channel_meta
=
SymBuffer
<
int
>
(
rdma_buffer_ptr
,
NUM_MAX_NVL_PEERS
*
2
+
2
,
kNumRDMARanks
,
channel_id
,
num_channels
);
auto
rdma_channel_head
=
SymBuffer
<
uint64_t
,
false
>
(
rdma_buffer_ptr
,
1
,
kNumRDMARanks
,
channel_id
,
num_channels
);
auto
rdma_channel_tail
=
SymBuffer
<
uint64_t
,
false
>
(
rdma_buffer_ptr
,
1
,
kNumRDMARanks
,
channel_id
,
num_channels
);
// NVL buffer layouts
// NVL buffer layouts
// NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers"
// NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr`
// means "Write for Senders, Read for Receivers"
void
*
rs_wr_buffer_ptr
=
nullptr
,
*
ws_rr_buffer_ptr
=
nullptr
;
void
*
rs_wr_buffer_ptr
=
nullptr
,
*
ws_rr_buffer_ptr
=
nullptr
;
int
rs_wr_rank
=
0
,
ws_rr_rank
=
0
;
int
rs_wr_rank
=
0
,
ws_rr_rank
=
0
;
if
(
warp_role
==
WarpRole
::
kRDMAAndNVLForwarder
)
if
(
warp_role
==
WarpRole
::
kRDMAAndNVLForwarder
)
rs_wr_buffer_ptr
=
buffer_ptrs
[
nvl_rank
],
ws_rr_buffer_ptr
=
buffer_ptrs
[
target_rank
],
rs_wr_rank
=
nvl_rank
,
ws_rr_rank
=
target_rank
;
rs_wr_buffer_ptr
=
buffer_ptrs
[
nvl_rank
],
ws_rr_buffer_ptr
=
buffer_ptrs
[
target_rank
],
rs_wr_rank
=
nvl_rank
,
ws_rr_rank
=
target_rank
;
if
(
warp_role
==
WarpRole
::
kNVLReceivers
)
if
(
warp_role
==
WarpRole
::
kNVLReceivers
)
rs_wr_buffer_ptr
=
buffer_ptrs
[
target_rank
],
ws_rr_buffer_ptr
=
buffer_ptrs
[
nvl_rank
],
rs_wr_rank
=
target_rank
,
ws_rr_rank
=
nvl_rank
;
rs_wr_buffer_ptr
=
buffer_ptrs
[
target_rank
],
ws_rr_buffer_ptr
=
buffer_ptrs
[
nvl_rank
],
rs_wr_rank
=
target_rank
,
ws_rr_rank
=
nvl_rank
;
// Allocate buffers
// Allocate buffers
auto
nvl_channel_x
=
AsymBuffer
<
int4
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
*
hidden_int4
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
).
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_x
=
auto
nvl_channel_src_meta
=
AsymBuffer
<
SourceMeta
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
).
advance_also
(
rs_wr_buffer_ptr
);
AsymBuffer
<
int4
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
*
hidden_int4
,
auto
nvl_channel_x_scales
=
AsymBuffer
<
float
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
*
num_scales
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
).
advance_also
(
rs_wr_buffer_ptr
);
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
)
auto
nvl_channel_topk_idx
=
AsymBuffer
<
int
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
*
num_topk
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
).
advance_also
(
rs_wr_buffer_ptr
);
.
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_topk_weights
=
AsymBuffer
<
float
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
*
num_topk
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
).
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_src_meta
=
auto
nvl_channel_prefix_start
=
AsymBuffer
<
int
>
(
ws_rr_buffer_ptr
,
kNumRDMARanks
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
).
advance_also
(
rs_wr_buffer_ptr
);
AsymBuffer
<
SourceMeta
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
,
NUM_MAX_NVL_PEERS
,
auto
nvl_channel_prefix_end
=
AsymBuffer
<
int
>
(
ws_rr_buffer_ptr
,
kNumRDMARanks
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
).
advance_also
(
rs_wr_buffer_ptr
);
channel_id
,
num_channels
,
rs_wr_rank
)
auto
nvl_channel_head
=
AsymBuffer
<
int
>
(
rs_wr_buffer_ptr
,
1
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
ws_rr_rank
).
advance_also
(
ws_rr_buffer_ptr
);
.
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_tail
=
AsymBuffer
<
int
>
(
ws_rr_buffer_ptr
,
1
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
).
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_x_scales
=
AsymBuffer
<
float
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
*
num_scales
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
)
.
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_topk_idx
=
AsymBuffer
<
int
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
*
num_topk
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
)
.
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_topk_weights
=
AsymBuffer
<
float
>
(
ws_rr_buffer_ptr
,
num_max_nvl_chunked_recv_tokens
*
num_topk
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
)
.
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_prefix_start
=
AsymBuffer
<
int
>
(
ws_rr_buffer_ptr
,
kNumRDMARanks
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
)
.
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_prefix_end
=
AsymBuffer
<
int
>
(
ws_rr_buffer_ptr
,
kNumRDMARanks
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
)
.
advance_also
(
rs_wr_buffer_ptr
);
auto
nvl_channel_head
=
AsymBuffer
<
int
>
(
rs_wr_buffer_ptr
,
1
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
ws_rr_rank
)
.
advance_also
(
ws_rr_buffer_ptr
);
auto
nvl_channel_tail
=
AsymBuffer
<
int
>
(
ws_rr_buffer_ptr
,
1
,
NUM_MAX_NVL_PEERS
,
channel_id
,
num_channels
,
rs_wr_rank
)
.
advance_also
(
rs_wr_buffer_ptr
);
// RDMA sender warp synchronization
// RDMA sender warp synchronization
__shared__
volatile
int
rdma_send_next_token_idx
;
__shared__
volatile
int
rdma_send_next_token_idx
;
__shared__
volatile
int
rdma_send_channel_tail
[
kNumRDMARanks
];
__shared__
volatile
int
rdma_send_channel_tail
[
kNumRDMARanks
];
__shared__
volatile
int
rdma_send_channel_next_tail
[
kNumRDMARanks
];
__shared__
volatile
int
rdma_send_channel_next_tail
[
kNumRDMARanks
];
__shared__
volatile
int
rdma_sender_counter
[
1
];
__shared__
volatile
int
rdma_forwarder_counter
[
1
];
if
(
threadIdx
.
x
==
0
)
{
rdma_sender_counter
[
0
]
=
0
;
rdma_forwarder_counter
[
0
]
=
0
;
}
__syncthreads
();
// NVL and RDMA coordinate Forward warp synchronization
auto
sync_rdma_sender_smem
=
[
&
]()
{
__shared__
volatile
int
forward_channel_head
[
NUM_MAX_NVL_PEERS
][
kNumRDMARanks
];
if
(
lane_id
==
0
)
{
__shared__
volatile
bool
forward_channel_retired
[
NUM_MAX_NVL_PEERS
];
volatile
int
ret
=
__hip_atomic_fetch_add
(
&
rdma_sender_counter
[
0
],
1
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
// Place the main logic of your kernel here, using the parameters above.
// volatile int ret = atomicAdd((int*)&rdma_sender_counter[0], 1);
if
(
warp_role
==
WarpRole
::
kRDMASender
)
{
}
/*
syncwarp
();
这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
while
(
rdma_sender_counter
[
0
]
<
(
kNumDispatchRDMASenderWarps
+
1
))
{
它首先获取当前通道的任务范围,然后清理共享内存,接着计算并发送本通道中的令牌数量。
}
然后,它遍历所有的令牌,读取每个令牌的RDMA秩的存在性,获取顺序锁,计算下一个尾部位置,存储RDMA头部,更新最后一个令牌尾部,释放顺序锁,并广播尾部位置。
};
最后,它复制相关的数据到对称发送缓冲区。
kRDMASender主要目的是将发送信息x, x_scale,source_meta, topk_idx, topk_weight等信息填充进入rdma发送缓存,
期间要同步warp直接对token的依序操作,以及和kForwarderCoordinator, kRDMASenderCoordinator内存同步。
同时在复制操作时, 使用ld.global.nc.L1::no_allocate.L2::256B, st.global.L1::no_allocate减少L1/L2缓存使用。
*/
// 获取任务范围
int
token_start_idx
,
token_end_idx
;
get_channel_task_range
(
num_tokens
,
num_channels
,
channel_id
,
token_start_idx
,
token_end_idx
);
// 清理共享内存
// Forward warp synchronization
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"无效的RDMA秩数量"
);
__shared__
volatile
int
forward_channel_head
[
NUM_MAX_NVL_PEERS
][
kNumRDMARanks
];
if
(
warp_id
==
0
&&
lane_id
==
0
)
{
__shared__
volatile
bool
forward_channel_retired
[
NUM_MAX_NVL_PEERS
];
rdma_send_next_token_idx
=
token_start_idx
;
// NOTE: Not sure that __syncthreads() is a suitable replacement
auto
sync_forwarder_smem
=
[
&
]()
{
if
(
lane_id
==
0
)
{
volatile
int
ret
=
__hip_atomic_fetch_add
(
&
rdma_forwarder_counter
[
0
],
1
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
// volatile int ret = atomicAdd((int*)&rdma_forwarder_counter[0], 1);
}
}
if
(
warp_id
==
0
&&
lane_id
<
kNumRDMARanks
)
{
syncwarp
();
rdma_send_channel_tail
[
lane_id
]
=
0
;
while
(
rdma_forwarder_counter
[
0
]
<
(
NUM_MAX_NVL_PEERS
+
1
))
{
rdma_send_channel_next_tail
[
lane_id
]
=
0
;
}
}
};
// 发送本通道中的令牌数量,通过 `-value - 1` 表示
if
(
warp_role
==
WarpRole
::
kRDMASender
)
{
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
*
2
+
2
<=
kWarpSize
,
"无效的NVL对等体数量"
);
// Get tasks
// 对于每个目标RDMA秩,以warp为单位进行迭代。计算发送缓冲区的值,并存储在rdma_channel_meta.send_buffer中
int
token_start_idx
,
token_end_idx
;
// 用于填充rdma_channel_meta.send_buffer本节点发送到远端rank, rdma_rank的起始index和结束index
get_channel_task_range
(
num_tokens
,
num_channels
,
channel_id
,
token_start_idx
,
for
(
int
dst_rdma_rank
=
warp_id
;
dst_rdma_rank
<
kNumRDMARanks
;
dst_rdma_rank
+=
kNumDispatchRDMASenderWarps
)
{
token_end_idx
);
auto
dst_ptr
=
dst_rdma_rank
==
rdma_rank
?
rdma_channel_meta
.
recv_buffer
(
dst_rdma_rank
)
:
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
);
// Clean shared memory
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"Invalid number of RDMA ranks"
);
(
warp_id
==
0
and
lane_id
==
0
)
?
(
rdma_send_next_token_idx
=
token_start_idx
)
:
0
;
(
warp_id
==
0
and
lane_id
<
kNumRDMARanks
)
?
(
rdma_send_channel_tail
[
lane_id
]
=
0
)
:
0
;
(
warp_id
==
0
and
lane_id
<
kNumRDMARanks
)
?
(
rdma_send_channel_next_tail
[
lane_id
]
=
0
)
:
0
;
// Send number of tokens in this channel by `-value - 1`
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
*
2
+
2
<=
kWarpSize
,
"Invalid number of NVL peers"
);
for
(
int
dst_rdma_rank
=
warp_id
;
dst_rdma_rank
<
kNumRDMARanks
;
dst_rdma_rank
+=
kNumDispatchRDMASenderWarps
)
{
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
dst_ptr
[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
)
*
num_channels
+
channel_id
-
1
])
-
1
;
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
)
*
num_channels
+
channel_id
-
1
])
-
1
;
}
else
if
(
lane_id
<
NUM_MAX_NVL_PEERS
*
2
)
{
}
else
if
(
lane_id
<
NUM_MAX_NVL_PEERS
*
2
)
{
dst_ptr
[
lane_id
]
=
-
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
-
NUM_MAX_NVL_PEERS
)
*
num_channels
+
channel_id
]
-
1
;
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)[
lane_id
]
=
-
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
-
NUM_MAX_NVL_PEERS
)
*
num_channels
+
channel_id
]
-
1
;
}
else
if
(
lane_id
==
NUM_MAX_NVL_PEERS
*
2
)
{
}
else
if
(
lane_id
==
NUM_MAX_NVL_PEERS
*
2
)
{
dst_ptr
[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
-
1
])
-
1
;
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
-
1
])
-
1
;
}
else
if
(
lane_id
==
NUM_MAX_NVL_PEERS
*
2
+
1
)
{
}
else
if
(
lane_id
==
NUM_MAX_NVL_PEERS
*
2
+
1
)
{
dst_ptr
[
lane_id
]
=
-
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
]
-
1
;
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)[
lane_id
]
=
-
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
]
-
1
;
}
}
syncwarp
();
rocshmem
::
rocshmem_ctx_int_put_nbi_wave
(
if
(
dst_rdma_rank
!=
rdma_rank
)
{
ctx
,
rdma_channel_meta
.
recv_buffer
(
rdma_rank
),
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_int_put_nbi_warp
(
ctx
,
#else
shmemx_int_put_nbi_warp
(
#endif
rdma_channel_meta
.
recv_buffer
(
rdma_rank
),
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
),
NUM_MAX_NVL_PEERS
*
2
+
2
,
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
),
NUM_MAX_NVL_PEERS
*
2
+
2
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
}
}
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
rocshmem
::
rocshmem_ctx_quiet
(
ctx
);
shmem_ctx_quiet
(
ctx
);
sync_rdma_sender_smem
();
#else
shmem_fence
();
#endif
// sync_rdma_sender_smem();
// Iterate over tokens and copy into buffer
__syncthreads
();
// 遍历令牌并复制到缓冲区
int64_t
token_idx
;
int64_t
token_idx
;
int
cached_rdma_channel_head
=
0
,
last_rdma_tail_idx
=
-
1
;
int
cached_rdma_channel_head
=
0
,
last_rdma_tail_idx
=
-
1
;
auto
send_buffer
=
lane_id
==
rdma_rank
?
rdma_channel_data
.
recv_buffer
(
lane_id
)
:
rdma_channel_data
.
send_buffer
(
lane_id
);
auto
send_buffer
=
lane_id
==
rdma_rank
?
rdma_channel_data
.
recv_buffer
(
lane_id
)
for
(
token_idx
=
token_start_idx
+
warp_id
;
token_idx
<
token_end_idx
;
token_idx
+=
kNumDispatchRDMASenderWarps
)
{
:
rdma_channel_data
.
send_buffer
(
lane_id
);
// 读取RDMA秩的存在性
for
(
token_idx
=
token_start_idx
+
warp_id
;
token_idx
<
token_end_idx
;
token_idx
+=
kNumDispatchRDMASenderWarps
)
{
// Read RDMA rank existence
uint64_t
is_token_in_rank_uint64
=
0
;
uint64_t
is_token_in_rank_uint64
=
0
;
if
(
lane_id
<
kNumRDMARanks
)
{
if
(
lane_id
<
kNumRDMARanks
)
is_token_in_rank_uint64
=
*
reinterpret_cast
<
const
uint64_t
*>
(
is_token_in_rank
+
token_idx
*
num_ranks
+
lane_id
*
NUM_MAX_NVL_PEERS
);
is_token_in_rank_uint64
=
*
reinterpret_cast
<
const
uint64_t
*>
(
}
is_token_in_rank
+
token_idx
*
num_ranks
+
lane_id
*
NUM_MAX_NVL_PEERS
);
// 获得处理数据的自旋锁,获得锁后才会处理一些数据信息
// Acquire sequential lock
while
(
lane_id
==
0
&&
rdma_send_next_token_idx
!=
token_idx
)
{
while
(
lane_id
==
0
and
rdma_send_next_token_idx
!=
token_idx
)
// 等待
;
}
syncwarp
();
syncwarp
();
//
获取下一个尾部位置
//
Acquire next tail
int
rdma_tail_idx
=
-
1
;
int
rdma_tail_idx
=
-
1
;
if
(
is_token_in_rank_uint64
!=
0
)
{
if
(
is_token_in_rank_uint64
!=
0
)
{
rdma_tail_idx
=
rdma_send_channel_next_tail
[
lane_id
]
++
;
rdma_tail_idx
=
rdma_send_channel_next_tail
[
lane_id
]
++
;
while
(
rdma_tail_idx
-
cached_rdma_channel_head
>=
num_max_rdma_chunked_recv_tokens
)
// 与kForwarderCoordinator相互配合,调节发送数据的频率
cached_rdma_channel_head
=
while
(
rdma_tail_idx
-
cached_rdma_channel_head
>=
num_max_rdma_chunked_recv_tokens
)
{
static_cast
<
int
>
(
ld_volatile_global
(
rdma_channel_head
.
buffer
(
lane_id
)));
cached_rdma_channel_head
=
static_cast
<
int
>
(
ld_volatile_global
(
rdma_channel_head
.
buffer
(
lane_id
)));
}
}
}
syncwarp
();
syncwarp
();
//
存储RDMA头部以供合并
//
Store RDMA head for combine
if
(
lane_id
<
kNumRDMARanks
&&
!
kCachedMode
)
{
if
(
lane_id
<
kNumRDMARanks
and
not
kCachedMode
)
send_rdma_head
[
token_idx
*
kNumRDMARanks
+
lane_id
]
=
rdma_tail_idx
;
send_rdma_head
[
token_idx
*
kNumRDMARanks
+
lane_id
]
=
rdma_tail_idx
;
}
//
更新最后一个令牌尾部
//
Update last token tail
if
(
last_rdma_tail_idx
>=
0
)
{
if
(
last_rdma_tail_idx
>=
0
)
st_release_cta
(
const_cast
<
int
*>
(
rdma_send_channel_tail
+
lane_id
),
last_rdma_tail_idx
+
1
);
st_release_cta
(
const_cast
<
const
int
*>
(
rdma_send_channel_tail
+
lane_id
),
}
last_rdma_tail_idx
+
1
);
last_rdma_tail_idx
=
rdma_tail_idx
;
last_rdma_tail_idx
=
rdma_tail_idx
;
// 释放顺序锁
// Release sequential lock
if
(
lane_id
==
0
)
{
lane_id
==
0
?
(
rdma_send_next_token_idx
+=
1
)
:
0
;
rdma_send_next_token_idx
+=
1
;
}
//
广播尾部位置
//
Broadcast tails
SourceMeta
src_meta
;
SourceMeta
src_meta
;
int
num_topk_ranks
=
0
,
topk_ranks
[
kNumTopkRDMARanks
];
int
num_topk_ranks
=
0
,
topk_ranks
[
kNumTopkRDMARanks
];
void
*
dst_send_buffers
[
kNumTopkRDMARanks
];
void
*
dst_send_buffers
[
kNumTopkRDMARanks
];
/*
#pragma unroll
该for循环主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作
for
(
int
i
=
0
,
slot_idx
;
i
<
kNumRDMARanks
;
++
i
)
*/
if
((
slot_idx
=
shfl_sync
(
rdma_tail_idx
,
i
))
>=
0
)
{
#pragma unroll
slot_idx
=
slot_idx
%
num_max_rdma_chunked_recv_tokens
;
for
(
int
i
=
0
,
slot_idx
;
i
<
kNumRDMARanks
;
++
i
)
{
topk_ranks
[
num_topk_ranks
]
=
i
;
// 使用__shfl_sync函数在warp内同步并广播rdma_tail_idx的值
if
((
slot_idx
=
shfl_sync
(
rdma_tail_idx
,
i
))
>=
0
)
{
// warp 所有线程参与,rdma_tail_idx默认为-1, 只有对应rdma rank需要发送时, rdma_tail_idx才会>=0
// 计算slot_idx在接收缓冲区中的位置
slot_idx
=
slot_idx
%
num_max_rdma_chunked_recv_tokens
;
// 存储当前RDMA秩到topk_ranks数组中
topk_ranks
[
num_topk_ranks
]
=
i
;
// 广播is_token_in_rank_uint64的值到所有线程,并解释为布尔数组
auto
recv_is_token_in_rank_uint64
=
broadcast
(
is_token_in_rank_uint64
,
i
);
auto
recv_is_token_in_rank_uint64
=
broadcast
(
is_token_in_rank_uint64
,
i
);
auto
recv_is_token_in_rank_values
=
reinterpret_cast
<
const
bool
*>
(
&
recv_is_token_in_rank_uint64
);
auto
recv_is_token_in_rank_values
=
reinterpret_cast
<
const
bool
*>
(
&
recv_is_token_in_rank_uint64
);
// 如果当前lane_id等于num_topk_ranks,则更新src_meta
if
(
lane_id
==
num_topk_ranks
)
if
(
lane_id
==
num_topk_ranks
)
{
src_meta
=
SourceMeta
(
rdma_rank
,
recv_is_token_in_rank_values
);
src_meta
=
SourceMeta
(
rdma_rank
,
recv_is_token_in_rank_values
);
}
dst_send_buffers
[
num_topk_ranks
++
]
=
reinterpret_cast
<
uint8_t
*>
(
broadcast
(
send_buffer
,
i
))
+
// 计算目标发送缓冲区的地址,并存储在dst_send_buffers数组中
slot_idx
*
num_bytes_per_rdma_token
;
// 获取到发送地址, num_topk_ranks-1 是需要发送的ranks数
dst_send_buffers
[
num_topk_ranks
++
]
=
reinterpret_cast
<
uint8_t
*>
(
broadcast
(
send_buffer
,
i
))
+
slot_idx
*
num_bytes_per_rdma_token
;
}
}
}
EP_DEVICE_ASSERT
(
num_topk_ranks
<=
kNumTopkRDMARanks
);
EP_DEVICE_ASSERT
(
num_topk_ranks
<=
kNumTopkRDMARanks
);
//////////////// 复制数据到发送缓冲区 ////////////////
// Copy `x` into symmetric send buffer
// 复制源元数据到对称发送缓冲区
auto
st_broadcast
=
[
=
](
const
int
key
,
const
int4
&
value
)
{
if
(
lane_id
<
num_topk_ranks
)
{
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
st_na_global
(
reinterpret_cast
<
SourceMeta
*>
(
dst_send_buffers
[
lane_id
]),
src_meta
);
st_na_global
(
reinterpret_cast
<
int4
*>
(
dst_send_buffers
[
j
])
+
key
,
value
);
}
};
UNROLLED_WARP_COPY
(
5
,
lane_id
,
hidden_int4
,
0
,
x
+
token_idx
*
hidden_int4
,
ld_nc_global
,
st_broadcast
);
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
dst_send_buffers
[
i
]
=
reinterpret_cast
<
int4
*>
(
dst_send_buffers
[
i
])
+
hidden_int4
;
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
{
// Copy source metadata into symmetric send buffer
dst_send_buffers
[
i
]
=
reinterpret_cast
<
SourceMeta
*>
(
dst_send_buffers
[
i
])
+
1
;
if
(
lane_id
<
num_topk_ranks
)
}
st_na_global
(
reinterpret_cast
<
SourceMeta
*>
(
dst_send_buffers
[
lane_id
]),
src_meta
);
// 复制 `x` 到对称发送缓冲区
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
auto
st_broadcast
=
[
=
](
const
int
key
,
const
int4
&
value
)
{
dst_send_buffers
[
i
]
=
reinterpret_cast
<
SourceMeta
*>
(
dst_send_buffers
[
i
])
+
1
;
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
{
st_na_global
(
reinterpret_cast
<
int4
*>
(
dst_send_buffers
[
j
])
+
key
,
value
);
}
};
UNROLLED_WARP_COPY
(
5
,
lane_id
,
hidden_int4
,
0
,
x
+
token_idx
*
hidden_int4
,
ld_nc_global
,
st_broadcast
);
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
{
dst_send_buffers
[
i
]
=
reinterpret_cast
<
int4
*>
(
dst_send_buffers
[
i
])
+
hidden_int4
;
}
//
复制
`x_scales`
到对称发送缓冲区
//
Copy
`x_scales`
into symmetric send buffer
for
(
int
i
=
lane_id
;
i
<
num_scales
;
i
+=
kWarpSize
)
{
for
(
int
i
=
lane_id
;
i
<
num_scales
;
i
+=
kWarpSize
)
{
auto
value
=
ld_nc_global
(
x_scales
+
token_idx
*
num_scales
+
i
);
auto
value
=
ld_nc_global
(
x_scales
+
token_idx
*
num_scales
+
i
);
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
{
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
st_na_global
(
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
j
])
+
i
,
value
);
st_na_global
(
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
j
])
+
i
,
value
);
}
}
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
{
dst_send_buffers
[
i
]
=
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
i
])
+
num_scales
;
}
}
for
(
int
i
=
0
;
i
<
num_topk_ranks
;
++
i
)
dst_send_buffers
[
i
]
=
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
i
])
+
num_scales
;
//
复制
`topk_idx`
和
`topk_weights`
到对称发送缓冲区
//
Copy
`topk_idx`
and
`topk_weights`
into symmetric send buffer
for
(
int
i
=
lane_id
;
i
<
num_topk
*
num_topk_ranks
;
i
+=
kWarpSize
)
{
for
(
int
i
=
lane_id
;
i
<
num_topk
*
num_topk_ranks
;
i
+=
kWarpSize
)
{
auto
rank_idx
=
i
/
num_topk
,
copy_idx
=
i
%
num_topk
;
auto
rank_idx
=
i
/
num_topk
,
copy_idx
=
i
%
num_topk
;
auto
idx_value
=
static_cast
<
int
>
(
ld_nc_global
(
topk_idx
+
token_idx
*
num_topk
+
copy_idx
));
auto
idx_value
=
static_cast
<
int
>
(
ld_nc_global
(
topk_idx
+
token_idx
*
num_topk
+
copy_idx
));
auto
weight_value
=
ld_nc_global
(
topk_weights
+
token_idx
*
num_topk
+
copy_idx
);
auto
weight_value
=
ld_nc_global
(
topk_weights
+
token_idx
*
num_topk
+
copy_idx
);
st_na_global
(
reinterpret_cast
<
int
*>
(
dst_send_buffers
[
rank_idx
])
+
copy_idx
,
idx_value
);
st_na_global
(
reinterpret_cast
<
int
*>
(
dst_send_buffers
[
rank_idx
])
+
copy_idx
,
st_na_global
(
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
rank_idx
])
+
num_topk
+
copy_idx
,
weight_value
);
idx_value
);
st_na_global
(
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
rank_idx
])
+
num_topk
+
copy_idx
,
weight_value
);
}
}
}
}
// 结尾部分
// Epilogue
// 获取顺序锁
// Acquire sequential lock
while
(
lane_id
==
0
&&
rdma_send_next_token_idx
!=
token_idx
)
{
while
(
lane_id
==
0
and
rdma_send_next_token_idx
!=
token_idx
)
// 等待
;
}
syncwarp
();
syncwarp
();
//
更新最后一个令牌尾部
//
Update last token tail
if
(
last_rdma_tail_idx
>=
0
)
{
if
(
last_rdma_tail_idx
>=
0
)
st_release_cta
(
const_cast
<
int
*>
(
rdma_send_channel_tail
+
lane_id
),
last_rdma_tail_idx
+
1
);
st_release_cta
(
const_cast
<
const
int
*>
(
rdma_send_channel_tail
+
lane_id
),
}
last_rdma_tail_idx
+
1
);
// 释放顺序锁
// Release sequential lock
if
(
lane_id
==
0
)
{
lane_id
==
0
?
(
rdma_send_next_token_idx
+=
1
)
:
0
;
rdma_send_next_token_idx
+=
1
;
}
else
if
(
warp_role
==
WarpRole
::
kRDMASenderCoordinator
)
{
}
// NOTES: in case of splitting the issued put at the end of the buffer
}
else
if
(
warp_role
==
WarpRole
::
kRDMASenderCoordinator
)
{
EP_DEVICE_ASSERT
(
/*
num_max_rdma_chunked_recv_tokens
%
num_max_rdma_chunked_send_tokens
==
0
);
这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
它首先计算每个RDMA秩需要发送的令牌数,然后在所有RDMA秩之间循环,检查是否有令牌需要发送。
如果有,它将计算本次需要发出的令牌数,并发出相应的RDMA发送请求。
最后,它更新相关的尾部位置,以便下次循环时可以正确地计算需要发送的令牌数。
kRDMASenderCoordinator使用了同sm内存一致性(ld.acquire.cta.s32),
dushmem内存一致性(dushmem_fence)和原子操作(dushmemx_signal_op),减少硬同步,提升整体效率。
*/
if
(
warp_id
>
kNumDispatchRDMASenderWarps
)
{
return
;
}
// 确保最大接收令牌数可以被最大发送令牌数整除,以避免缓冲区分割问题
EP_DEVICE_ASSERT
(
num_max_rdma_chunked_recv_tokens
%
num_max_rdma_chunked_send_tokens
==
0
);
// 同步共享内存,确保所有线程在继续之前都达到了这一点
// Synchronize shared memory
// sync_rdma_sender_smem();
sync_rdma_sender_smem
();
__syncthreads
();
//
计算当前通道需要发送的令牌数
//
Get number of tokens to send for each RDMA rank
int
num_tokens_to_send
=
0
;
int
num_tokens_to_send
=
0
;
if
(
lane_id
<
kNumRDMARanks
)
{
if
(
lane_id
<
kNumRDMARanks
)
{
num_tokens_to_send
=
rdma_channel_prefix_matrix
[
lane_id
*
num_channels
+
channel_id
];
num_tokens_to_send
=
rdma_channel_prefix_matrix
[
lane_id
*
num_channels
+
channel_id
];
if
(
channel_id
>
0
)
if
(
channel_id
>
0
)
num_tokens_to_send
-=
rdma_channel_prefix_matrix
[
lane_id
*
num_channels
+
channel_id
-
1
];
num_tokens_to_send
-=
rdma_channel_prefix_matrix
[
lane_id
*
num_channels
+
channel_id
-
1
];
}
}
//
记录上次发出的尾部位置
//
Iterate all RDMA ranks
int
last_issued_tail
=
0
;
int
last_issued_tail
=
0
;
// 当有任何RDMA秩需要发送令牌时,继续循环
while
(
__any_sync
(
kFullWarpMask
,
num_tokens_to_send
>
0
))
{
while
(
__any_sync
(
kFullWarpMask
,
num_tokens_to_send
>
0
))
{
for
(
int
i
=
0
,
synced_num_tokens_to_send
;
i
<
kNumRDMARanks
;
++
i
)
{
for
(
int
i
=
0
,
synced_num_tokens_to_send
;
i
<
kNumRDMARanks
;
++
i
)
{
int
dst_rdma_rank
=
(
i
+
channel_id
)
%
kNumRDMARanks
;
// 计算目标RDMA秩
int
dst_rdma_rank
=
(
i
+
channel_id
)
%
kNumRDMARanks
;
// 获取同步后的需要发送的令牌数
synced_num_tokens_to_send
=
shfl_sync
(
num_tokens_to_send
,
dst_rdma_rank
);
synced_num_tokens_to_send
=
shfl_sync
(
num_tokens_to_send
,
dst_rdma_rank
);
if
(
synced_num_tokens_to_send
==
0
)
continue
;
if
(
synced_num_tokens_to_send
==
0
)
// Read progress
continue
;
// 如果没有令牌需要发送,则跳过
// 读取进度
auto
synced_last_issued_tail
=
shfl_sync
(
last_issued_tail
,
dst_rdma_rank
);
auto
synced_last_issued_tail
=
shfl_sync
(
last_issued_tail
,
dst_rdma_rank
);
auto
processed_tail
=
ld_acquire_cta
(
const_cast
<
const
int
*>
(
rdma_send_channel_tail
+
dst_rdma_rank
));
auto
processed_tail
=
auto
num_tokens_processed
=
proces
sed_
tail
-
synced_last_issued_tail
;
ld_acquire_cta
(
const_cast
<
const
int
*>
(
rdma_
se
n
d_
channel_tail
+
dst_rdma_rank
))
;
auto
num_tokens_processed
=
processed_tail
-
synced_last_issued_tail
;
// 如果处理的令牌数不等于需要发送的令牌数,并且处理的令牌数小于最大发送令牌数,则跳过
if
(
num_tokens_processed
!=
synced_num_tokens_to_send
and
if
(
num_tokens_processed
!=
synced_num_tokens_to_send
&&
num_tokens_processed
<
num_max_rdma_chunked_send_tokens
)
num_tokens_processed
<
num_max_rdma_chunked_send_tokens
)
continue
;
continue
;
//
计算本次需要发出的令牌数
//
Issue RDMA send
auto
num_tokens_to_issue
=
min
(
num_tokens_processed
,
num_max_rdma_chunked_send_tokens
);
auto
num_tokens_to_issue
=
EP_DEVICE_ASSERT
(
num_tokens_to_issue
>=
0
&&
num_tokens_to_issue
<=
synced_num_tokens_to_send
);
min
(
num_tokens_processed
,
num_max_rdma_chunked_send_tokens
);
EP_DEVICE_ASSERT
(
num_tokens_to_issue
>=
0
and
// 发出RDMA发送请求
num_tokens_to_issue
<=
synced_num_tokens_to_send
);
if
(
dst_rdma_rank
!=
rdma_rank
)
{
if
(
dst_rdma_rank
!=
rdma_rank
)
{
auto
dst_slot_idx
=
synced_last_issued_tail
%
num_max_rdma_chunked_recv_tokens
;
auto
dst_slot_idx
=
synced_last_issued_tail
%
num_max_rdma_chunked_recv_tokens
;
EP_DEVICE_ASSERT
(
dst_slot_idx
+
num_tokens_to_issue
<=
num_max_rdma_chunked_recv_tokens
);
EP_DEVICE_ASSERT
(
dst_slot_idx
+
num_tokens_to_issue
<=
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
num_max_rdma_chunked_recv_tokens
);
shmem_ctx_schar_put_nbi_warp
(
ctx
,
rocshmem
::
rocshmem_ctx_schar_put_nbi_wave
(
#else
ctx
,
shmemx_int8_put_nbi_warp
(
#endif
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
,
dst_slot_idx
*
num_bytes_per_rdma_token
,
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
,
dst_slot_idx
*
num_bytes_per_rdma_token
,
num_bytes_per_rdma_token
*
num_tokens_to_issue
,
num_bytes_per_rdma_token
*
num_tokens_to_issue
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
rocshmem
::
rocshmem_ctx_quiet
(
ctx
);
shmem_ctx_quiet
(
ctx
);
#else
shmem_fence
();
#endif
}
else
{
}
else
{
//
对于本地RDMA秩,使用较轻的内存屏障
//
Lighter fence for local RDMA rank
memory_fence
();
memory_fence
();
}
}
//
更新尾部位置
//
Update tails
syncwarp
();
syncwarp
();
if
(
lane_id
==
dst_rdma_rank
)
{
if
(
lane_id
==
dst_rdma_rank
)
{
last_issued_tail
+=
num_tokens_to_issue
;
last_issued_tail
+=
num_tokens_to_issue
;
num_tokens_to_send
-=
num_tokens_to_issue
;
num_tokens_to_send
-=
num_tokens_to_issue
;
// 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
rocshmem
::
rocshmem_ctx_ulong_atomic_add
(
shmem_ctx_ulong_atomic_add
(
ctx
,
ctx
,
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_tokens_to_issue
,
#else
shmem_signal_op_add
(
#endif
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_tokens_to_issue
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
}
}
}
}
// while(__any(num_tokens_to_send > 0))
}
}
else
if
(
warp_role
==
WarpRole
::
kRDMAAndNVLForwarder
)
{
}
else
if
(
warp_role
==
WarpRole
::
kRDMAAndNVLForwarder
)
{
/*
// RDMA consumers and NVL producers
这段代码的主要功能是在一个CUDA内核中协调从RDMA消费者到NVL生产者的转发操作。
const
auto
dst_nvl_rank
=
target_rank
;
它首先计算目标NVL秩和目标秩,然后等待相关的计数器到达。
const
auto
dst_rank
=
rdma_rank
*
NUM_MAX_NVL_PEERS
+
dst_nvl_rank
;
接着,它检查目标队列是否为空,或者等待一个缓冲区被释放。
const
auto
dst_rank_expert_begin
=
dst_rank
*
(
num_experts
/
num_ranks
);
然后,它找到下一个源RDMA秩,并遍历RDMA缓冲区中的每一个令牌,复制相关的数据到NVL缓冲区。
const
auto
dst_rank_expert_end
=
dst_rank_expert_begin
+
(
num_experts
/
num_ranks
);
最后,它同步头部和尾部索引,并标记通道为退役状态。
*/
// Wait counters to arrive
// RDMA消费者和NVL生产者
const
auto
dst_nvl_rank
=
target_rank
;
// 目标NVL秩
const
auto
dst_rank
=
rdma_rank
*
NUM_MAX_NVL_PEERS
+
dst_nvl_rank
;
// 目标秩
const
auto
dst_rank_expert_begin
=
dst_rank
*
(
num_experts
/
num_ranks
);
// 目标秩专家开始
const
auto
dst_rank_expert_end
=
dst_rank_expert_begin
+
(
num_experts
/
num_ranks
);
// 目标秩专家结束
// 等待计数器到达
int
num_tokens_to_recv_from_rdma
=
0
,
src_rdma_channel_prefix
=
0
;
int
num_tokens_to_recv_from_rdma
=
0
,
src_rdma_channel_prefix
=
0
;
EP_DEVICE_ASSERT
(
kNumRDMARanks
<=
kWarpSize
);
EP_DEVICE_ASSERT
(
kNumRDMARanks
<=
kWarpSize
);
auto
start_time
=
wall_clock64
();
auto
start_time
=
wall_clock64
();
if
(
lane_id
<
kNumRDMARanks
)
{
if
(
lane_id
<
kNumRDMARanks
)
{
while
(
true
)
{
while
(
true
)
{
// 对应于kRDMASender中的数据写入
auto
meta_0
=
auto
meta_0
=
ld_volatile_global
(
rdma_channel_meta
.
recv_buffer
(
lane_id
)
+
dst_nvl_rank
);
// 是nvl节点的起始地址
ld_volatile_global
(
rdma_channel_meta
.
recv_buffer
(
lane_id
)
+
dst_nvl_rank
);
auto
meta_1
=
ld_volatile_global
(
rdma_channel_meta
.
recv_buffer
(
lane_id
)
+
NUM_MAX_NVL_PEERS
+
dst_nvl_rank
);
// nvl节点的结束地址
auto
meta_1
=
ld_volatile_global
(
rdma_channel_meta
.
recv_buffer
(
lane_id
)
+
auto
meta_2
=
ld_volatile_global
(
rdma_channel_meta
.
recv_buffer
(
lane_id
)
+
NUM_MAX_NVL_PEERS
*
2
);
// 本rdma节点的起始地址
NUM_MAX_NVL_PEERS
+
dst_nvl_rank
);
auto
meta_3
=
ld_volatile_global
(
rdma_channel_meta
.
recv_buffer
(
lane_id
)
+
NUM_MAX_NVL_PEERS
*
2
+
1
);
// 本节点的结束地址
auto
meta_2
=
ld_volatile_global
(
rdma_channel_meta
.
recv_buffer
(
lane_id
)
+
if
(
meta_0
<
0
&&
meta_1
<
0
&&
meta_2
<
0
&&
meta_3
<
0
)
{
NUM_MAX_NVL_PEERS
*
2
);
// 通知NVL秩
auto
meta_3
=
ld_volatile_global
(
rdma_channel_meta
.
recv_buffer
(
lane_id
)
+
NUM_MAX_NVL_PEERS
*
2
+
1
);
if
(
meta_0
<
0
and
meta_1
<
0
and
meta_2
<
0
and
meta_3
<
0
)
{
// Notify NVL ranks
int
start_sum
=
-
meta_0
-
1
,
end_sum
=
-
meta_1
-
1
;
int
start_sum
=
-
meta_0
-
1
,
end_sum
=
-
meta_1
-
1
;
EP_DEVICE_ASSERT
(
start_sum
>=
0
&&
end_sum
>=
0
&&
end_sum
>=
start_sum
);
EP_DEVICE_ASSERT
(
start_sum
>=
0
and
end_sum
>=
0
and
end_sum
>=
start_sum
);
st_relaxed_sys_global
(
nvl_channel_prefix_start
.
buffer
()
+
lane_id
,
-
start_sum
-
1
);
st_relaxed_sys_global
(
nvl_channel_prefix_start
.
buffer
()
+
lane_id
,
-
start_sum
-
1
);
st_relaxed_sys_global
(
nvl_channel_prefix_end
.
buffer
()
+
lane_id
,
-
end_sum
-
1
);
st_relaxed_sys_global
(
nvl_channel_prefix_end
.
buffer
()
+
lane_id
,
-
end_sum
-
1
);
//
保存从RDMA通道接收的令牌计数
//
Save RDMA channel received token count
src_rdma_channel_prefix
=
-
meta_2
-
1
;
src_rdma_channel_prefix
=
-
meta_2
-
1
;
auto
src_rdma_channel_prefix_1
=
-
meta_3
-
1
;
auto
src_rdma_channel_prefix_1
=
-
meta_3
-
1
;
num_tokens_to_recv_from_rdma
=
src_rdma_channel_prefix_1
-
src_rdma_channel_prefix
;
// 是远端 rdma_rank 会发送给当前节点的token数量
num_tokens_to_recv_from_rdma
=
if
(
!
kCachedMode
)
src_rdma_channel_prefix_1
-
src_rdma_channel_prefix
;
recv_rdma_channel_prefix_matrix
[
lane_id
*
num_channels
+
channel_id
]
=
src_rdma_channel_prefix_1
;
if
(
not
kCachedMode
)
recv_rdma_channel_prefix_matrix
[
lane_id
*
num_channels
+
channel_id
]
=
src_rdma_channel_prefix
+=
lane_id
==
0
?
0
:
recv_rdma_rank_prefix_sum
[
lane_id
-
1
];
// 对应的远端 rdma_rank 的起始index, 存在线程0之中
src_rdma_channel_prefix_1
;
src_rdma_channel_prefix
+=
lane_id
==
0
?
0
:
recv_rdma_rank_prefix_sum
[
lane_id
-
1
];
EP_DEVICE_ASSERT
(
num_tokens_to_recv_from_rdma
>=
0
);
EP_DEVICE_ASSERT
(
num_tokens_to_recv_from_rdma
>=
0
);
break
;
break
;
}
}
// 超时检查
// Timeout check
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
long
long
int
elapsed_time
=
printf
(
"DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d
\n
"
,
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
dst_nvl_rank
,
meta_0
,
meta_1
,
meta_2
,
meta_3
);
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, "
"nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
dst_nvl_rank
,
meta_0
,
meta_1
,
meta_2
,
meta_3
);
trap
();
trap
();
}
}
}
}
}
}
syncwarp
();
syncwarp
();
// Shift cached head
// 移动缓存的头部
send_nvl_head
+=
src_rdma_channel_prefix
*
NUM_MAX_NVL_PEERS
+
dst_nvl_rank
;
send_nvl_head
+=
src_rdma_channel_prefix
*
NUM_MAX_NVL_PEERS
+
dst_nvl_rank
;
// 等待共享内存被清理
// Wait shared memory to be cleaned
// sync_forwarder_smem();
sync_forwarder_smem
();
__syncthreads
();
// 开始准备处理接受数据,直到所有的数据接受完成。
// Forward tokens from RDMA buffer
// 转发从RDMA缓冲区的令牌
// NOTES: always start from the local rank
// 注意:总是从本地秩开始
int
src_rdma_rank
=
sm_id
%
kNumRDMARanks
;
int
src_rdma_rank
=
sm_id
%
kNumRDMARanks
;
int
cached_rdma_channel_head
=
0
,
cached_rdma_channel_tail
=
0
;
int
cached_rdma_channel_head
=
0
,
cached_rdma_channel_tail
=
0
;
int
cached_nvl_channel_head
=
0
,
cached_nvl_channel_tail
=
0
,
rdma_nvl_token_idx
=
0
;
int
cached_nvl_channel_head
=
0
,
cached_nvl_channel_tail
=
0
,
rdma_nvl_token_idx
=
0
;
while
(
__any_sync
(
kFullWarpMask
,
num_tokens_to_recv_from_rdma
>
0
))
{
while
(
__any_sync
(
kFullWarpMask
,
num_tokens_to_recv_from_rdma
>
0
))
{
//
检查nvl目标队列是否为空,或者等待一个缓冲区被释放
//
Check destination queue emptiness, or wait a buffer to be released
start_time
=
wall_clock64
();
start_time
=
wall_clock64
();
while
(
lane_id
==
0
)
{
// 用于给kNVLReceivers进行互动,控制数据的传输速度
while
(
lane_id
==
0
)
{
int
num_used_slots
=
cached_nvl_channel_tail
-
cached_nvl_channel_head
;
int
num_used_slots
=
cached_nvl_channel_tail
-
cached_nvl_channel_head
;
if
(
num_max_nvl_chunked_recv_tokens
-
num_used_slots
>=
num_max_nvl_chunked_send_tokens
)
if
(
num_max_nvl_chunked_recv_tokens
-
num_used_slots
>=
num_max_nvl_chunked_send_tokens
)
break
;
break
;
cached_nvl_channel_head
=
ld_volatile_global
(
nvl_channel_head
.
buffer
());
cached_nvl_channel_head
=
ld_volatile_global
(
nvl_channel_head
.
buffer
());
// 超时检查
// Timeout check
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
long
long
int
elapsed_time
=
printf
(
"DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d
\n
"
,
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
channel_id
,
rdma_rank
,
nvl_rank
,
dst_nvl_rank
,
ld_volatile_global
(
nvl_channel_head
.
buffer
()),
cached_nvl_channel_tail
);
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, "
"nvl: %d, dst NVL: %d, head: %d, tail: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
dst_nvl_rank
,
ld_volatile_global
(
nvl_channel_head
.
buffer
()),
cached_nvl_channel_tail
);
trap
();
trap
();
}
}
}
}
syncwarp
();
syncwarp
();
//
找到下一个源RDMA秩(轮询)
//
Find next source RDMA rank (round-robin)
start_time
=
wall_clock64
();
start_time
=
wall_clock64
();
while
(
true
)
{
while
(
true
)
{
src_rdma_rank
=
(
src_rdma_rank
+
1
)
%
kNumRDMARanks
;
src_rdma_rank
=
(
src_rdma_rank
+
1
)
%
kNumRDMARanks
;
if
(
shfl_sync
(
num_tokens_to_recv_from_rdma
,
src_rdma_rank
)
>
0
)
{
if
(
shfl_sync
(
num_tokens_to_recv_from_rdma
,
src_rdma_rank
)
>
0
)
{
if
(
lane_id
==
src_rdma_rank
&&
cached_rdma_channel_head
==
cached_rdma_channel_tail
)
if
(
lane_id
==
src_rdma_rank
and
cached_rdma_channel_tail
=
static_cast
<
int
>
(
ld_acquire_sys_global
(
rdma_channel_tail
.
buffer
(
src_rdma_rank
)));
cached_rdma_channel_head
==
cached_rdma_channel_tail
)
cached_rdma_channel_tail
=
static_cast
<
int
>
(
if
(
shfl_sync
(
cached_rdma_channel_tail
>
cached_rdma_channel_head
,
src_rdma_rank
))
{
ld_relaxed_sys_global
(
rdma_channel_tail
.
buffer
(
src_rdma_rank
)));
if
(
shfl_sync
(
cached_rdma_channel_tail
>
cached_rdma_channel_head
,
src_rdma_rank
))
break
;
break
;
}
}
}
// 超时检查
// Timeout check
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
and
lane_id
<
kNumRDMARanks
)
{
long
long
int
elapsed_time
=
printf
(
"DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: %d
\n
"
,
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
channel_id
,
rdma_rank
,
nvl_rank
,
dst_nvl_rank
,
lane_id
,
cached_rdma_channel_head
,
cached_rdma_channel_tail
,
num_tokens_to_recv_from_rdma
);
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
and
lane_id
<
kNumRDMARanks
)
{
printf
(
"DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, "
"nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: "
"%d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
dst_nvl_rank
,
lane_id
,
cached_rdma_channel_head
,
cached_rdma_channel_tail
,
num_tokens_to_recv_from_rdma
);
trap
();
trap
();
}
}
}
}
auto
src_rdma_head
=
shfl_sync
(
cached_rdma_channel_head
,
src_rdma_rank
);
auto
src_rdma_head
=
shfl_sync
(
cached_rdma_channel_head
,
src_rdma_rank
);
auto
src_rdma_tail
=
shfl_sync
(
cached_rdma_channel_tail
,
src_rdma_rank
);
auto
src_rdma_tail
=
shfl_sync
(
cached_rdma_channel_tail
,
src_rdma_rank
);
// 遍历RDMA缓冲区中的每一个令牌
// Iterate over every token from the RDMA buffer
for
(
int
i
=
src_rdma_head
,
num_tokens_sent
=
0
;
i
<
src_rdma_tail
;
++
i
)
{
for
(
int
i
=
src_rdma_head
,
num_tokens_sent
=
0
;
i
<
src_rdma_tail
;
++
i
)
{
auto
rdma_slot_idx
=
i
%
num_max_rdma_chunked_recv_tokens
;
auto
rdma_slot_idx
=
i
%
num_max_rdma_chunked_recv_tokens
;
// 首先读取SourceMeta,对应到kRDMASenderCoordinator中 kRDMASender 的数据远程写入
void
*
shifted
=
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
void
*
shifted
=
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
;
rdma_slot_idx
*
num_bytes_per_rdma_token
;
auto
src_meta
=
ld_nc_global
(
reinterpret_cast
<
SourceMeta
*>
(
reinterpret_cast
<
int8_t
*>
(
shifted
)));
auto
src_meta
=
ld_nc_global
(
reinterpret_cast
<
SourceMeta
*>
(
if
(
lane_id
==
src_rdma_rank
)
{
reinterpret_cast
<
int8_t
*>
(
shifted
)
+
hidden_bytes
));
num_tokens_to_recv_from_rdma
-=
1
;
lane_id
==
src_rdma_rank
?
(
num_tokens_to_recv_from_rdma
-=
1
)
:
0
;
}
bool
is_in_dst_nvl_rank
=
src_meta
.
is_token_in_nvl_rank
(
dst_nvl_rank
);
bool
is_in_dst_nvl_rank
=
src_meta
.
is_token_in_nvl_rank
(
dst_nvl_rank
);
if
(
lane_id
==
src_rdma_rank
)
{
if
(
lane_id
==
src_rdma_rank
)
{
auto
cached_head
=
is_in_dst_nvl_rank
?
rdma_nvl_token_idx
:
-
1
;
auto
cached_head
=
is_in_dst_nvl_rank
?
rdma_nvl_token_idx
:
-
1
;
rdma_nvl_token_idx
+=
is_in_dst_nvl_rank
;
rdma_nvl_token_idx
+=
is_in_dst_nvl_rank
;
if
(
!
kCachedMode
)
if
(
not
kCachedMode
)
send_nvl_head
[
i
*
NUM_MAX_NVL_PEERS
]
=
cached_head
;
send_nvl_head
[
i
*
NUM_MAX_NVL_PEERS
]
=
cached_head
;
}
}
if
(
not
is_in_dst_nvl_rank
)
if
(
!
is_in_dst_nvl_rank
)
continue
;
continue
;
//
获取一个空闲槽位
//
Get an empty slot
int
dst_slot_idx
=
(
cached_nvl_channel_tail
++
)
%
num_max_nvl_chunked_recv_tokens
;
int
dst_slot_idx
=
(
cached_nvl_channel_tail
++
)
%
num_max_nvl_chunked_recv_tokens
;
// 设置 src和dst 位置
// Copy data
auto
src_gpu_buffer_x
=
reinterpret_cast
<
int4
*>
(
reinterpret_cast
<
int8_t
*>
(
shifted
)
+
sizeof
(
SourceMeta
));
auto
src_gpu_buffer_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
int8_t
*>
(
src_gpu_buffer_x
)
+
hidden_bytes
);
auto
src_gpu_buffer_topk_idx
=
reinterpret_cast
<
int
*>
(
reinterpret_cast
<
int8_t
*>
(
src_gpu_buffer_scales
)
+
num_scales
*
sizeof
(
float
));
auto
src_gpu_buffer_topk_weights
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
int8_t
*>
(
src_gpu_buffer_topk_idx
)
+
num_topk
*
sizeof
(
int
));
auto
dst_gpu_buffer_x
=
nvl_channel_x
.
buffer
()
+
dst_slot_idx
*
hidden_int4
;
auto
dst_gpu_buffer_scales
=
nvl_channel_x_scales
.
buffer
()
+
dst_slot_idx
*
num_scales
;
auto
dst_gpu_buffer_topk_idx
=
nvl_channel_topk_idx
.
buffer
()
+
dst_slot_idx
*
num_topk
;
auto
dst_gpu_buffer_topk_weights
=
nvl_channel_topk_weights
.
buffer
()
+
dst_slot_idx
*
num_topk
;
if
(
lane_id
==
0
)
{
st_na_global
(
reinterpret_cast
<
int64_t
*>
(
nvl_channel_src_meta
.
buffer
()
+
dst_slot_idx
),
*
reinterpret_cast
<
int64_t
*>
(
&
src_meta
));
}
UNROLLED_WARP_COPY
(
5
,
lane_id
,
hidden_int4
,
UNROLLED_WARP_COPY
(
5
,
lane_id
,
hidden_int4
,
dst_gpu_buffer_x
,
nvl_channel_x
.
buffer
()
+
dst_slot_idx
*
hidden_int4
,
src_gpu_buffer_x
,
reinterpret_cast
<
int4
*>
(
shifted
),
ld_nc_global
,
st_na_global
);
ld_direct_global
,
st_na_global
);
shifted
=
reinterpret_cast
<
int4
*>
(
shifted
)
+
hidden_int4
;
// Copy source meta
if
(
lane_id
==
0
)
st_na_global
(
nvl_channel_src_meta
.
buffer
()
+
dst_slot_idx
,
src_meta
);
shifted
=
reinterpret_cast
<
SourceMeta
*>
(
shifted
)
+
1
;
// Copy `x_scales`
UNROLLED_WARP_COPY
(
1
,
lane_id
,
num_scales
,
UNROLLED_WARP_COPY
(
1
,
lane_id
,
num_scales
,
dst_gpu_buffer_scales
,
nvl_channel_x_scales
.
buffer
()
+
dst_slot_idx
*
num_scales
,
src_gpu_buffer_scales
,
reinterpret_cast
<
float
*>
(
shifted
),
ld_nc_global
,
st_na_global
);
ld_direct_global
,
st_na_global
);
shifted
=
reinterpret_cast
<
float
*>
(
shifted
)
+
num_scales
;
for
(
int
t
=
lane_id
;
t
<
num_topk
;
t
+=
kWarpSize
)
{
// Copy `topk_idx` and `topk_weights`
int
idx_val
=
ld_direct_global
(
reinterpret_cast
<
int
*>
(
src_gpu_buffer_topk_idx
)
+
t
);
// NOTES: do not use `shifted` after this `if`, because only several lanes are
float
w_val
=
ld_direct_global
(
reinterpret_cast
<
float
*>
(
src_gpu_buffer_topk_weights
)
+
t
);
// shifted
int
new_idx
=
(
idx_val
>=
dst_rank_expert_begin
&&
idx_val
<
dst_rank_expert_end
)
if
(
lane_id
<
num_topk
)
{
?
(
idx_val
-
dst_rank_expert_begin
)
:
-
1
;
// Read
float
new_w
=
(
new_idx
!=
-
1
)
?
w_val
:
0.0
f
;
auto
idx_value
=
ld_nc_global
(
reinterpret_cast
<
int
*>
(
shifted
)
+
lane_id
);
dst_gpu_buffer_topk_idx
[
t
]
=
new_idx
;
shifted
=
reinterpret_cast
<
int
*>
(
shifted
)
+
num_topk
;
dst_gpu_buffer_topk_weights
[
t
]
=
new_w
;
auto
weight_value
=
ld_nc_global
(
reinterpret_cast
<
float
*>
(
shifted
)
+
lane_id
);
// Transform and write
idx_value
=
(
idx_value
>=
dst_rank_expert_begin
and
idx_value
<
dst_rank_expert_end
)
?
idx_value
-
dst_rank_expert_begin
:
-
1
;
st_na_global
(
nvl_channel_topk_idx
.
buffer
()
+
dst_slot_idx
*
num_topk
+
lane_id
,
idx_value
);
weight_value
=
idx_value
>=
0
?
weight_value
:
0.0
f
;
st_na_global
(
nvl_channel_topk_weights
.
buffer
()
+
dst_slot_idx
*
num_topk
+
lane_id
,
weight_value
);
}
}
//
在NVL缓冲区不足的情况下,提前停止
//
In case of insufficient NVL buffers, early stopping
if
((
++
num_tokens_sent
)
==
num_max_nvl_chunked_send_tokens
)
if
((
++
num_tokens_sent
)
==
num_max_nvl_chunked_send_tokens
)
src_rdma_tail
=
i
+
1
;
src_rdma_tail
=
i
+
1
;
}
}
// 同步头部索引
// Sync head index
if
(
lane_id
==
src_rdma_rank
)
if
(
lane_id
==
src_rdma_rank
)
forward_channel_head
[
dst_nvl_rank
][
src_rdma_rank
]
=
(
cached_rdma_channel_head
=
src_rdma_tail
);
forward_channel_head
[
dst_nvl_rank
][
src_rdma_rank
]
=
(
cached_rdma_channel_head
=
src_rdma_tail
);
//
移动尾部索引,与kNVLReceivers互相通信使用
//
Move tail index
syncwarp
();
syncwarp
();
if
(
lane_id
==
0
)
{
if
(
lane_id
==
0
)
st_release_sys_global
(
nvl_channel_tail
.
buffer
(),
cached_nvl_channel_tail
);
st_relaxed_sys_global
(
nvl_channel_tail
.
buffer
(),
cached_nvl_channel_tail
);
}
}
}
// Retired
// Retired
syncwarp
();
syncwarp
();
if
(
lane_id
==
0
)
{
if
(
lane_id
==
0
)
forward_channel_retired
[
dst_nvl_rank
]
=
true
;
forward_channel_retired
[
dst_nvl_rank
]
=
true
;
}
}
else
if
(
warp_role
==
WarpRole
::
kForwarderCoordinator
)
{
}
else
if
(
warp_role
==
WarpRole
::
kForwarderCoordinator
)
{
/*
这段代码的主要功能是在一个CUDA内核中协调转发器的逻辑。
它首先检查当前warp是否是额外的转发器协调warp,如果是,则直接退出。
然后,它清理共享内存,并初始化转发通道的头部和退役状态。
接着,它进入一个无限循环,在循环中,它找到最小的头部,如果所有的通道都已退役,则退出循环。
否则,它更新远程头部,并进行纳秒级睡眠,以让其他warp工作。
*/
// Extra warps for forwarder coordinator should exit directly
// Extra warps for forwarder coordinator should exit directly
if
(
w
ar
p_id
>
NUM_MAX_NVL_PEERS
)
if
(
t
ar
get_rank
>
0
)
return
;
return
;
//
转发warp协调器
//
Forward warp coordinator
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"
无效的RDMA对等体数量
"
);
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"
Invalid number of RDMA peers
"
);
// 清理共享内存
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"无效的NVL对等体数量"
);
// Clean shared memory
#pragma unroll
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"Invalid number of NVL peers"
);
for
(
int
i
=
lane_id
;
i
<
kNumRDMARanks
*
NUM_MAX_NVL_PEERS
;
i
+=
kWarpSize
)
for
(
int
i
=
lane_id
;
i
<
kNumRDMARanks
*
NUM_MAX_NVL_PEERS
;
i
+=
kWarpSize
)
forward_channel_head
[
i
%
NUM_MAX_NVL_PEERS
][
i
/
NUM_MAX_NVL_PEERS
]
=
0
;
forward_channel_head
[
i
%
NUM_MAX_NVL_PEERS
][
i
/
NUM_MAX_NVL_PEERS
]
=
0
;
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
forward_channel_retired
[
lane_id
]
=
false
;
forward_channel_retired
[
lane_id
]
=
false
;
// sync_forwarder_smem();
sync_forwarder_smem
();
__syncthreads
();
int
last_head
=
0
,
target_rdma
=
lane_id
<
kNumRDMARanks
?
lane_id
:
0
;
int
last_head
=
0
,
target_rdma
=
lane_id
<
kNumRDMARanks
?
lane_id
:
0
;
while
(
true
)
{
while
(
true
)
{
// Find minimum head
// 找到最小的头部
int
min_head
=
std
::
numeric_limits
<
int
>::
max
();
int
min_head
=
std
::
numeric_limits
<
int
>::
max
();
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_MAX_NVL_PEERS
;
++
i
)
for
(
int
i
=
0
;
i
<
NUM_MAX_NVL_PEERS
;
++
i
)
if
(
!
forward_channel_retired
[
i
])
if
(
not
forward_channel_retired
[
i
])
min_head
=
min
(
min_head
,
forward_channel_head
[
i
][
target_rdma
]);
min_head
=
min
(
min_head
,
forward_channel_head
[
i
][
target_rdma
]);
if
(
__all_sync
(
kFullWarpMask
,
min_head
==
std
::
numeric_limits
<
int
>::
max
()))
if
(
__all_sync
(
kFullWarpMask
,
min_head
==
std
::
numeric_limits
<
int
>::
max
()))
{
break
;
break
;
}
// 更新远程头部
// Update remote head
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
&&
min_head
>=
last_head
+
num_max_rdma_chunked_send_tokens
&&
lane_id
<
kNumRDMARanks
){
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
min_head
>=
last_head
+
num_max_rdma_chunked_send_tokens
and
shmem_ctx_ulong_atomic_add
(
ctx
,
lane_id
<
kNumRDMARanks
)
{
#else
rocshmem
::
rocshmem_ctx_ulong_atomic_add
(
shmem_signal_op_add
(
ctx
,
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_head
,
#endif
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_head
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
lane_id
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
lane_id
,
nvl_rank
));
last_head
=
min_head
;
last_head
=
min_head
;
}
}
//
纳秒级睡眠并让其他warp工作 //
Nanosleep and let other warps work
// Nanosleep and let other warps work
__builtin_amdgcn_s_sleep
(
NUM_WAIT_CYCLES_TIMES_64
);
__builtin_amdgcn_s_sleep
(
NUM_WAIT_CYCLES_TIMES_64
);
}
}
}
else
if
(
warp_role
==
WarpRole
::
kNVLReceivers
)
{
}
else
{
if
(
warp_id
>=
NUM_MAX_NVL_PEERS
)
{
// NVL consumers
return
;
// Retrieve rank offset from barrier results (each lane's register stores an RDMA rank)
}
// Place the main logic of your kernel here, using the parameters above.
// NVL消费者
// 从屏障结果中检索秩偏移(每个通道的寄存器存储一个RDMA秩)
int
src_nvl_rank
=
target_rank
,
total_offset
=
0
;
int
src_nvl_rank
=
target_rank
,
total_offset
=
0
;
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"
无效的RDMA对等体数量
"
);
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"
Invalid number of RDMA peers
"
);
if
(
lane_id
<
kNumRDMARanks
&&
lane_id
*
NUM_MAX_NVL_PEERS
+
src_nvl_rank
>
0
)
if
(
lane_id
<
kNumRDMARanks
and
lane_id
*
NUM_MAX_NVL_PEERS
+
src_nvl_rank
>
0
)
total_offset
=
recv_gbl_rank_prefix_sum
[
lane_id
*
NUM_MAX_NVL_PEERS
+
src_nvl_rank
-
1
];
total_offset
=
recv_gbl_rank_prefix_sum
[
lane_id
*
NUM_MAX_NVL_PEERS
+
src_nvl_rank
-
1
];
//
接收通道偏移
//
Receive channel offsets
int
start_offset
=
0
,
end_offset
=
0
,
num_tokens_to_recv
;
int
start_offset
=
0
,
end_offset
=
0
,
num_tokens_to_recv
;
auto
start_time
=
wall_clock64
();
auto
start_time
=
wall_clock64
();
while
(
lane_id
<
kNumRDMARanks
)
{
while
(
lane_id
<
kNumRDMARanks
)
{
start_offset
=
ld_volatile_global
(
nvl_channel_prefix_start
.
buffer
()
+
lane_id
);
start_offset
=
ld_volatile_global
(
nvl_channel_prefix_start
.
buffer
()
+
lane_id
);
end_offset
=
ld_volatile_global
(
nvl_channel_prefix_end
.
buffer
()
+
lane_id
);
end_offset
=
ld_volatile_global
(
nvl_channel_prefix_end
.
buffer
()
+
lane_id
);
if
(
start_offset
<
0
&&
end_offset
<
0
)
{
if
(
start_offset
<
0
and
end_offset
<
0
)
{
start_offset
=
-
start_offset
-
1
,
end_offset
=
-
end_offset
-
1
;
start_offset
=
-
start_offset
-
1
,
end_offset
=
-
end_offset
-
1
;
total_offset
+=
start_offset
;
total_offset
+=
start_offset
;
break
;
break
;
}
}
// 超时检查
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
// Timeout check
printf
(
"DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d
\n
"
,
long
long
int
elapsed_time
=
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
src_nvl_rank
,
start_offset
,
end_offset
);
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src "
"RDMA: %d, src nvl: %d, start: %d, end: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
src_nvl_rank
,
start_offset
,
end_offset
);
trap
();
trap
();
}
}
}
}
num_tokens_to_recv
=
warp_reduce_sum
(
end_offset
-
start_offset
);
num_tokens_to_recv
=
warp_reduce_sum
(
end_offset
-
start_offset
);
// 保存以供合并使用
// Save for combine usage
if
(
lane_id
<
kNumRDMARanks
&&
!
kCachedMode
)
if
(
lane_id
<
kNumRDMARanks
and
not
kCachedMode
)
recv_gbl_channel_prefix_matrix
[(
lane_id
*
NUM_MAX_NVL_PEERS
+
src_nvl_rank
)
*
num_channels
+
channel_id
]
=
total_offset
;
recv_gbl_channel_prefix_matrix
[(
lane_id
*
NUM_MAX_NVL_PEERS
+
src_nvl_rank
)
*
num_channels
+
channel_id
]
=
total_offset
;
syncwarp
();
syncwarp
();
int
cached_channel_head_idx
=
0
,
cached_channel_tail_idx
=
0
;
int
cached_channel_head_idx
=
0
,
cached_channel_tail_idx
=
0
;
while
(
num_tokens_to_recv
>
0
)
{
while
(
num_tokens_to_recv
>
0
)
{
//
通过通道0检查通道状态
//
Check channel status by lane 0
start_time
=
wall_clock64
();
start_time
=
wall_clock64
();
while
(
lane_id
==
0
)
{
while
(
lane_id
==
0
)
{
//
准备复制
//
Ready to copy
if
(
cached_channel_head_idx
!=
cached_channel_tail_idx
)
if
(
cached_channel_head_idx
!=
cached_channel_tail_idx
)
break
;
break
;
cached_channel_tail_idx
=
ld_acquire_sys_global
(
nvl_channel_tail
.
buffer
());
cached_channel_tail_idx
=
ld_relaxed_sys_global
(
nvl_channel_tail
.
buffer
());
// 超时检查
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
// Timeout check
printf
(
"DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d
\n
"
,
long
long
int
elapsed_time
=
channel_id
,
rdma_rank
,
nvl_rank
,
src_nvl_rank
,
cached_channel_head_idx
,
cached_channel_tail_idx
);
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, "
"src NVL: %d, head: %d, tail: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
src_nvl_rank
,
cached_channel_head_idx
,
cached_channel_tail_idx
);
trap
();
trap
();
}
}
}
}
//
同步队列尾部
//
Sync queue tail
cached_channel_tail_idx
=
shfl_sync
(
cached_channel_tail_idx
,
0
);
cached_channel_tail_idx
=
shfl_sync
(
cached_channel_tail_idx
,
0
);
//
复制数据
//
Copy data
int
num_recv_tokens
=
cached_channel_tail_idx
-
cached_channel_head_idx
;
int
num_recv_tokens
=
cached_channel_tail_idx
-
cached_channel_head_idx
;
for
(
int
chunk_idx
=
0
;
chunk_idx
<
num_recv_tokens
;
++
chunk_idx
,
--
num_tokens_to_recv
)
{
for
(
int
chunk_idx
=
0
;
chunk_idx
<
num_recv_tokens
;
int
token_idx_in_buffer
=
(
cached_channel_head_idx
++
)
%
num_max_nvl_chunked_recv_tokens
;
++
chunk_idx
,
--
num_tokens_to_recv
)
{
auto
meta
=
ld_nc_global
(
nvl_channel_src_meta
.
buffer
()
+
token_idx_in_buffer
);
int
token_idx_in_buffer
=
int64_t
recv_token_idx
=
shfl_sync
(
total_offset
,
meta
.
src_rdma_rank
);
(
cached_channel_head_idx
++
)
%
num_max_nvl_chunked_recv_tokens
;
auto
meta
=
ld_nc_global
(
nvl_channel_src_meta
.
buffer
()
+
token_idx_in_buffer
);
int64_t
recv_token_idx
=
shfl_sync
(
total_offset
,
meta
.
src_rdma_rank
);
(
lane_id
==
meta
.
src_rdma_rank
)
?
(
total_offset
+=
1
)
:
0
;
(
lane_id
==
meta
.
src_rdma_rank
)
?
(
total_offset
+=
1
)
:
0
;
// 复制数据
// Copy data
UNROLLED_WARP_COPY
(
5
,
UNROLLED_WARP_COPY
(
5
,
lane_id
,
hidden_int4
,
recv_x
+
recv_token_idx
*
hidden_int4
,
lane_id
,
nvl_channel_x
.
buffer
()
+
token_idx_in_buffer
*
hidden_int4
,
hidden_int4
,
ld_nc_global
,
st_na_global
);
recv_x
+
recv_token_idx
*
hidden_int4
,
nvl_channel_x
.
buffer
()
+
token_idx_in_buffer
*
hidden_int4
,
// Copy source meta
ld_nc_global
,
if
(
lane_id
==
0
and
not
kCachedMode
)
st_na_global
);
// 复制源元数据
if
(
lane_id
==
0
&&
!
kCachedMode
)
st_na_global
(
recv_src_meta
+
recv_token_idx
,
meta
);
st_na_global
(
recv_src_meta
+
recv_token_idx
,
meta
);
// 复制比例
// Copy scales
UNROLLED_WARP_COPY
(
1
,
UNROLLED_WARP_COPY
(
1
,
lane_id
,
num_scales
,
lane_id
,
recv_x_scales
+
recv_token_idx
*
num_scales
,
num_scales
,
nvl_channel_x_scales
.
buffer
()
+
token_idx_in_buffer
*
num_scales
,
recv_x_scales
+
recv_token_idx
*
num_scales
,
ld_nc_global
,
st_na_global
);
nvl_channel_x_scales
.
buffer
()
+
token_idx_in_buffer
*
num_scales
,
ld_nc_global
,
// Copy `topk_idx` and `topk_weights`
st_na_global
);
if
(
lane_id
<
num_topk
)
{
// 复制 `topk_idx` 和 `topk_weights`
if
(
lane_id
<
num_topk
)
{
auto
recv_idx
=
recv_token_idx
*
num_topk
+
lane_id
;
auto
recv_idx
=
recv_token_idx
*
num_topk
+
lane_id
;
auto
buffer_idx
=
token_idx_in_buffer
*
num_topk
+
lane_id
;
auto
buffer_idx
=
token_idx_in_buffer
*
num_topk
+
lane_id
;
st_na_global
(
recv_topk_idx
+
recv_idx
,
static_cast
<
int64_t
>
(
ld_nc_global
(
nvl_channel_topk_idx
.
buffer
()
+
buffer_idx
)));
st_na_global
(
recv_topk_idx
+
recv_idx
,
st_na_global
(
recv_topk_weights
+
recv_idx
,
ld_nc_global
(
nvl_channel_topk_weights
.
buffer
()
+
buffer_idx
));
static_cast
<
int64_t
>
(
ld_nc_global
(
nvl_channel_topk_idx
.
buffer
()
+
buffer_idx
)));
st_na_global
(
recv_topk_weights
+
recv_idx
,
ld_nc_global
(
nvl_channel_topk_weights
.
buffer
()
+
buffer_idx
));
}
}
}
}
//
移动队列
//
Move queue
syncwarp
();
syncwarp
();
if
(
lane_id
==
0
)
{
if
(
lane_id
==
0
)
st_relaxed_sys_global
(
nvl_channel_head
.
buffer
(),
cached_channel_head_idx
);
st_relaxed_sys_global
(
nvl_channel_head
.
buffer
(),
cached_channel_head_idx
);
}
}
}
// while(num_tokens_to_recv > 0)
}
}
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
rocshmem
::
rocshmem_wg_ctx_destroy
(
&
ctx
);
shmem_wg_ctx_destroy
(
&
ctx
);
#endif
}
}
void
dispatch
(
void
*
recv_x
,
float
*
recv_x_scales
,
int64_t
*
recv_topk_idx
,
float
*
recv_topk_weights
,
void
dispatch
(
void
*
recv_x
,
float
*
recv_x_scales
,
int64_t
*
recv_topk_idx
,
float
*
recv_topk_weights
,
...
@@ -1152,8 +1135,6 @@ void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float
...
@@ -1152,8 +1135,6 @@ void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float
int
num_ranks
,
bool
is_cached_dispatch
,
hipStream_t
stream
,
int
num_channels
,
int
num_ranks
,
bool
is_cached_dispatch
,
hipStream_t
stream
,
int
num_channels
,
bool
low_latency_mode
)
{
bool
low_latency_mode
)
{
constexpr
int
kNumDispatchRDMASenderWarps
=
7
;
constexpr
int
kNumDispatchRDMASenderWarps
=
7
;
// Make sure never OOB
EP_HOST_ASSERT
(
static_cast
<
int64_t
>
(
num_scales
)
*
scale_hidden_stride
<
std
::
numeric_limits
<
int
>::
max
());
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) \
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) \
{ \
{ \
...
@@ -1181,8 +1162,8 @@ void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float
...
@@ -1181,8 +1162,8 @@ void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float
EP_HOST_ASSERT
((
topk_idx
==
nullptr
)
==
(
topk_weights
==
nullptr
));
EP_HOST_ASSERT
((
topk_idx
==
nullptr
)
==
(
topk_weights
==
nullptr
));
EP_HOST_ASSERT
((
recv_topk_idx
==
nullptr
)
==
(
recv_topk_weights
==
nullptr
));
EP_HOST_ASSERT
((
recv_topk_idx
==
nullptr
)
==
(
recv_topk_weights
==
nullptr
));
SETUP_LAUNCH_CONFIG
(
num_channels
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
,
SETUP_LAUNCH_CONFIG
(
num_channels
*
2
,
(
1
+
NUM_MAX_NVL_PEERS
)
*
kWarpSize
,
stream
);
(
kNumDispatchRDMASenderWarps
+
1
+
NUM_MAX_NVL_PEERS
)
*
kWarpSize
,
stream
);
SWITCH_RDMA_RANKS
(
DISPATCH_LAUNCH_CASE
);
SWITCH_RDMA_RANKS
(
DISPATCH_LAUNCH_CASE
);
#undef DISPATCH_LAUNCH_CASE
#undef DISPATCH_LAUNCH_CASE
}
}
...
@@ -1209,7 +1190,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
...
@@ -1209,7 +1190,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
if
(
sm_id
==
0
)
{
if
(
sm_id
==
0
)
{
// Barrier for RDMA
// Barrier for RDMA
if
(
thread_id
==
kWarpSize
)
if
(
thread_id
==
kWarpSize
)
du
shmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
shmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
// Barrier for NVL
// Barrier for NVL
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
...
@@ -1228,7 +1209,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
...
@@ -1228,7 +1209,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
// Barrier again
// Barrier again
if
(
thread_id
==
kWarpSize
)
if
(
thread_id
==
kWarpSize
)
du
shmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
shmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
// Barrier again
// Barrier again
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
...
@@ -1236,25 +1217,24 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
...
@@ -1236,25 +1217,24 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
if
(
is_cached_dispatch
)
if
(
is_cached_dispatch
)
return
;
return
;
EP_DEVICE_ASSERT
(
num_warps
>=
num_channels
);
EP_DEVICE_ASSERT
(
num_rdma_ranks
<=
kWarpSize
);
EP_DEVICE_ASSERT
(
num_rdma_ranks
<=
kWarpSize
);
// Iterate in reverse order
// Iterate in reverse order
for
(
int
channel_id
=
warp_id
;
channel_id
<
num_channels
;
channel_id
+=
num_warps
)
{
if
(
lane_id
<
num_rdma_ranks
and
warp_id
<
num_channels
)
{
if
(
lane_id
<
num_rdma_ranks
)
{
int
token_start_idx
,
token_end_idx
;
int
token_start_idx
,
token_end_idx
;
get_channel_task_range
(
num_combined_tokens
,
num_channels
,
warp_id
,
token_start_idx
,
get_channel_task_range
(
num_combined_tokens
,
num_channels
,
channel_id
,
token_start_idx
,
token_end_idx
);
token_end_idx
);
// NOTES: `1 << 25` is a heuristic large number
// NOTES: `1 << 25` is a heuristic large number
int
last_head
=
1
<<
25
;
int
last_head
=
1
<<
25
;
for
(
int
token_idx
=
token_end_idx
-
1
;
token_idx
>=
token_start_idx
;
--
token_idx
)
{
for
(
int
token_idx
=
token_end_idx
-
1
;
token_idx
>=
token_start_idx
;
--
token_idx
)
{
auto
current_head
=
auto
current_head
=
__ldg
(
combined_rdma_head
+
token_idx
*
num_rdma_ranks
+
lane_id
);
__ldg
(
combined_rdma_head
+
token_idx
*
num_rdma_ranks
+
lane_id
);
if
(
current_head
<
0
)
{
if
(
current_head
<
0
)
{
combined_rdma_head
[
token_idx
*
num_rdma_ranks
+
lane_id
]
=
-
last_head
-
1
;
combined_rdma_head
[
token_idx
*
num_rdma_ranks
+
lane_id
]
=
-
last_head
-
1
;
}
else
{
}
else
{
last_head
=
current_head
;
last_head
=
current_head
;
}
}
}
}
}
}
}
...
@@ -1262,34 +1242,34 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
...
@@ -1262,34 +1242,34 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
if
(
is_cached_dispatch
)
if
(
is_cached_dispatch
)
return
;
return
;
EP_DEVICE_ASSERT
(
rdma_channel_prefix_matrix
!=
nullptr
and
rdma_rank_prefix_sum
!=
nullptr
);
EP_DEVICE_ASSERT
(
num_warps
>=
num_channels
);
EP_DEVICE_ASSERT
(
rdma_channel_prefix_matrix
!=
nullptr
and
rdma_rank_prefix_sum
!=
nullptr
);
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"Too many NVL peers"
);
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"Too many NVL peers"
);
constexpr
int
num_clean_sms
=
2
;
constexpr
int
num_clean_sms
=
2
;
for
(
int
channel_id
=
warp_id
;
channel_id
<
num_channels
;
channel_id
+=
num_warps
)
{
if
(
lane_id
<
NUM_MAX_NVL_PEERS
and
warp_id
<
num_channels
)
{
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
for
(
int
dst_rdma_rank
=
sm_id
-
num_clean_sms
;
dst_rdma_rank
<
num_rdma_ranks
;
for
(
int
dst_rdma_rank
=
sm_id
-
num_clean_sms
;
dst_rdma_rank
<
num_rdma_ranks
;
dst_rdma_rank
+=
num_channels
*
2
-
num_clean_sms
)
{
dst_rdma_rank
+=
num_channels
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
-
num_clean_sms
)
{
// Iterate in reverse order
// Iterate in reverse order
int
token_start_idx
=
int
token_start_idx
=
warp_id
==
0
channel_id
==
0
?
0
?
0
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
warp_id
-
1
];
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
-
1
];
int
token_end_idx
=
int
token_end_idx
=
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
warp_id
];
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
];
int
shift
=
dst_rdma_rank
==
0
?
0
:
rdma_rank_prefix_sum
[
dst_rdma_rank
-
1
];
int
shift
=
dst_rdma_rank
==
0
?
0
:
rdma_rank_prefix_sum
[
dst_rdma_rank
-
1
];
token_start_idx
+=
shift
,
token_end_idx
+=
shift
;
token_start_idx
+=
shift
,
token_end_idx
+=
shift
;
// NOTES: `1 << 25` is a heuristic large number
// NOTES: `1 << 25` is a heuristic large number
int
last_head
=
1
<<
25
;
int
last_head
=
1
<<
25
;
for
(
int
token_idx
=
token_end_idx
-
1
;
token_idx
>=
token_start_idx
;
--
token_idx
)
{
for
(
int
token_idx
=
token_end_idx
-
1
;
token_idx
>=
token_start_idx
;
--
token_idx
)
{
auto
current_head
=
auto
current_head
=
__ldg
(
combined_nvl_head
+
token_idx
*
NUM_MAX_NVL_PEERS
+
lane_id
);
__ldg
(
combined_nvl_head
+
token_idx
*
NUM_MAX_NVL_PEERS
+
lane_id
);
if
(
current_head
<
0
)
{
if
(
current_head
<
0
)
{
combined_nvl_head
[
token_idx
*
NUM_MAX_NVL_PEERS
+
lane_id
]
=
-
last_head
-
1
;
combined_nvl_head
[
token_idx
*
NUM_MAX_NVL_PEERS
+
lane_id
]
=
-
last_head
-
1
;
}
else
{
}
else
{
last_head
=
current_head
;
last_head
=
current_head
;
}
}
}
}
}
}
}
...
@@ -1305,7 +1285,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
...
@@ -1305,7 +1285,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
int
num_max_nvl_chunked_recv_tokens
,
int
**
barrier_signal_ptrs
,
int
rank
,
int
num_max_nvl_chunked_recv_tokens
,
int
**
barrier_signal_ptrs
,
int
rank
,
hipStream_t
stream
,
int64_t
num_rdma_bytes
,
int64_t
num_nvl_bytes
,
hipStream_t
stream
,
int64_t
num_rdma_bytes
,
int64_t
num_nvl_bytes
,
bool
is_cached_dispatch
,
bool
low_latency_mode
)
{
bool
is_cached_dispatch
,
bool
low_latency_mode
)
{
const
int
num_threads
=
::
min
(
1024
,
::
max
(
128
,
kWarpSize
*
num_channels
)
)
;
const
int
num_threads
=
::
max
(
128
,
kWarpSize
*
num_channels
);
const
auto
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
const
auto
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
// Get clean meta
// Get clean meta
...
@@ -1321,11 +1301,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
...
@@ -1321,11 +1301,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
num_nvl_bytes
);
num_nvl_bytes
);
EP_HOST_ASSERT
(
num_rdma_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_rdma_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_nvl_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_nvl_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_channels
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
>
2
);
EP_HOST_ASSERT
(
num_channels
*
2
>
2
);
// Launch kernel
// Launch kernel
auto
cached_notify_func
=
low_latency_mode
?
cached_notify
<
true
>
:
cached_notify
<
false
>
;
auto
cached_notify_func
=
low_latency_mode
?
cached_notify
<
true
>
:
cached_notify
<
false
>
;
SETUP_LAUNCH_CONFIG
(
num_channels
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
,
num_threads
,
stream
);
SETUP_LAUNCH_CONFIG
(
num_channels
*
2
,
num_threads
,
stream
);
LAUNCH_KERNEL_NON_COOPERATIVE
(
LAUNCH_KERNEL_NON_COOPERATIVE
(
&
cfg
,
cached_notify_func
,
rdma_clean_meta
.
first
,
rdma_clean_meta
.
second
,
&
cfg
,
cached_notify_func
,
rdma_clean_meta
.
first
,
rdma_clean_meta
.
second
,
nvl_clean_meta
.
first
,
nvl_clean_meta
.
second
,
combined_rdma_head
,
num_combined_tokens
,
nvl_clean_meta
.
first
,
nvl_clean_meta
.
second
,
combined_rdma_head
,
num_combined_tokens
,
...
@@ -1334,45 +1314,49 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
...
@@ -1334,45 +1314,49 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
cpu_rdma_team
);
cpu_rdma_team
);
}
}
template
<
int
kNumRanks
,
typename
dtype_t
,
int
kMaxNumRanks
,
bool
kUseMLS
,
typename
GetAddr
Fn
,
typename
ReceiveTWFn
>
template
<
int
kNumRanks
,
typename
dtype_t
,
int
kMaxNumRanks
,
int
kWidth
,
typename
Receive
Fn
,
typename
ReceiveTWFn
>
__device__
int
combine_token
(
bool
is_token_in_rank
,
int
head_idx
,
__device__
int
combine_token
(
bool
is_token_in_rank
,
int
head_idx
,
int
lane_id
,
int
hidden_int4
,
int
num_topk
,
int
lane_id
,
int
hidden_int4
,
int
num_topk
,
int4
*
combined_row
,
float
*
combined_topk_weights
,
int4
*
combined_row
,
float
*
combined_topk_weights
,
int
num_max_recv_tokens
,
int
num_max_recv_tokens
,
const
ReceiveFn
&
recv_fn
,
const
ReceiveTWFn
&
recv_tw_fn
)
{
const
GetAddrFn
&
get_addr_fn
,
const
ReceiveTWFn
&
recv_tw_fn
)
{
constexpr
auto
kDtypePerInt4
=
sizeof
(
int4
)
/
sizeof
(
dtype_t
);
constexpr
auto
kDtypePerInt4
=
sizeof
(
int4
)
/
sizeof
(
dtype_t
);
// Broadcast current heads
// Broadcast current heads
// Lane `i` holds the head of rank `i` and `is_token_in_rank`
// Lane `i` holds the head of rank `i` and `is_token_in_rank`
EP_STATIC_ASSERT
(
kMaxNumRanks
<=
kW
arpSize
,
"Too many ranks"
);
EP_STATIC_ASSERT
(
kMaxNumRanks
<=
kW
idth
,
"Too many ranks"
);
int
num_topk_ranks
=
0
,
topk_ranks
[
kMaxNumRanks
],
slot_indices
[
kMaxNumRanks
];
int
num_topk_ranks
=
0
,
topk_ranks
[
kMaxNumRanks
],
slot_indices
[
kMaxNumRanks
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
kNumRanks
;
++
i
)
if
(
shfl_sync
(
is_token_in_rank
,
i
))
{
for
(
int
i
=
0
;
i
<
kNumRanks
;
++
i
)
if
(
shfl_sync
(
is_token_in_rank
,
i
,
kWidth
))
{
slot_indices
[
num_topk_ranks
]
=
shfl_sync
(
head_idx
,
i
)
%
num_max_recv_tokens
;
slot_indices
[
num_topk_ranks
]
=
shfl_sync
(
head_idx
,
i
,
kWidth
)
%
num_max_recv_tokens
;
topk_ranks
[
num_topk_ranks
++
]
=
i
;
topk_ranks
[
num_topk_ranks
++
]
=
i
;
}
}
EP_DEVICE_ASSERT
(
num_topk_ranks
<=
kMaxNumRanks
);
EP_DEVICE_ASSERT
(
num_topk_ranks
<=
kMaxNumRanks
);
// Reduce data
// Reduce data
#pragma unroll
#pragma unroll
for
(
int
i
=
lane_id
;
i
<
hidden_int4
;
i
+=
kWarpSize
)
{
for
(
int
i
=
lane_id
;
i
<
hidden_int4
;
i
+=
kWidth
)
{
// Read buffers
float
values
[
kDtypePerInt4
]
=
{
0
};
float
values
[
kDtypePerInt4
]
=
{
0
};
// 8 × 4B = 32B
// Temporary buffer
int4
temp
;
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
{
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
{
int4
recv_value
=
ld_nc_global
(
get_addr_fn
(
topk_ranks
[
j
],
slot_indices
[
j
],
i
));
temp
=
recv_fn
(
topk_ranks
[
j
],
slot_indices
[
j
],
i
);
auto
recv_dtypes
=
reinterpret_cast
<
const
dtype_t
*>
(
&
recv_value
);
const
dtype_t
*
d
=
reinterpret_cast
<
const
dtype_t
*>
(
&
temp
);
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
kDtypePerInt4
;
++
k
)
for
(
int
k
=
0
;
k
<
kDtypePerInt4
;
++
k
)
values
[
k
]
+=
static_cast
<
float
>
(
recv_dtypes
[
k
]);
values
[
k
]
+=
static_cast
<
float
>
(
d
[
k
]);
}
}
// Cast back to `dtype_t` and write
int4
out_int4
;
int4
out_int4
;
auto
out_dtypes
=
reinterpret_cast
<
dtype_t
*>
(
&
out_int4
);
dtype_t
*
out_dtypes
=
reinterpret_cast
<
dtype_t
*>
(
&
out_int4
);
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
kDtypePerInt4
;
++
j
)
for
(
int
j
=
0
;
j
<
kDtypePerInt4
;
++
j
)
out_dtypes
[
j
]
=
static_cast
<
dtype_t
>
(
values
[
j
]);
out_dtypes
[
j
]
=
static_cast
<
dtype_t
>
(
values
[
j
]);
st_na_global
(
combined_row
+
i
,
out_int4
);
st_na_global
(
combined_row
+
i
,
out_int4
);
}
}
...
@@ -1389,87 +1373,98 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
...
@@ -1389,87 +1373,98 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
return
topk_ranks
[
0
];
return
topk_ranks
[
0
];
}
}
template
<
bool
kLowLatencyMode
,
template
<
bool
kLowLatencyMode
,
int
kNumRDMARanks
,
int
kNumRDMARanks
,
typename
dtype_t
,
typename
dtype_t
,
int
kNumCombineForwarderWarps
,
int
kNum
CombineForwarderWarps
,
int
kNum
TopkRDMARanks
=
get_num_topk_rdma_ranks
(
kNumRDMARanks
)
,
int
kNum
TopkRDMARanks
=
get_num_topk_rdma_ranks
(
kNumRDMARanks
)
,
int
kNum
WarpsPerForwarder
=
(
kNumCombineForwarderWarps
/
kNumRDMARanks
>
0
)
?
kNumCombineForwarderWarps
/
kNumRDMARanks
:
1
,
int
kNum
WarpsPer
Forwarder
=
(
kNumCombineForwarderWarp
s
/
kNumRDMARanks
>
0
)
?
kNumCombineForwarderWarps
/
kNumRDMARanks
:
1
,
int
kNumForwarders
=
kNumRDMARanks
*
kNumWarpsPerForwarder
,
int
kNum
Forwarders
=
kNumRDMARanks
*
kNumWarpsPerForwarder
,
int
kNum
RDMAReceivers
=
kNumRDMARanks
<=
8
?
kNumForwarders
+
NUM_MAX_NVL_PEERS
/
2
:
kNumForwarders
+
NUM_MAX_NVL_PEERS
,
int
k
NumRDMAReceivers
=
kNumForwarders
>
int
k
BlockThreads
=
(
kNumRDMARanks
>
8
)
?
((
NUM_MAX_NVL_PEERS
+
kNumForwarders
)
*
kEmulatedWarpSize
+
kWarpSize
)
:
((
NUM_MAX_NVL_PEERS
/
2
+
1
+
kNumForwarders
)
*
kWarpSize
)
>
__global__
void
__launch_bounds__
(
(
1
+
NUM_MAX_NVL_PEERS
)
*
kWarpSize
,
1
)
__global__
void
__launch_bounds__
(
kBlockThreads
,
1
)
combine
(
int4
*
combined_x
,
float
*
combined_topk_weights
,
const
bool
*
is_combined_token_in_rank
,
combine
(
int4
*
combined_x
,
float
*
combined_topk_weights
,
const
int4
*
x
,
const
float
*
topk_weights
,
const
int4
*
bias_0
,
const
int4
*
bias_1
,
const
bool
*
is_combined_token_in_rank
,
const
int
*
combined_rdma_head
,
const
int
*
combined_nvl_head
,
const
SourceMeta
*
src_meta
,
const
int
4
*
x
,
const
float
*
topk_weights
,
const
int
4
*
bias_0
,
const
int4
*
bias_1
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
rdma_rank_prefix_sum
,
const
int
*
combined_rdma_head
,
const
int
*
combined_nvl_head
,
const
int
*
gbl
_channel_prefix_matrix
,
int
num_toke
ns
,
int
num_combined_tokens
,
const
SourceMeta
*
src_meta
,
const
int
*
rdma
_channel_prefix_matrix
,
const
int
*
rdma_rank_prefix_sum
,
co
ns
t
int
*
gbl_channel_prefix_matrix
,
int
hidd
en
,
int
num_
topk
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_send_tokens
,
int
num_tok
en
s
,
int
num_
combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_max_rdma_chunked_
recv
_tokens
,
void
**
buffer_ptr
s
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_
send
_tokens
,
int
num_max_rdma_chunked_recv_token
s
,
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
int
rank
,
void
**
buffer_ptrs
,
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
int
num_ranks
)
{
int
rank
,
int
num_ranks
)
{
enum
class
WarpRole
{
enum
class
WarpRole
{
kNVLSender
,
kNVLSender
,
kNVLAndRDMAForwarder
,
kNVLAndRDMAForwarder
,
kRDMAReceiver
,
kRDMAReceiver
,
kRDMACoordinator
,
kCoordinator
kNVLCoordinator
};
};
constexpr
auto
kNVLPeersHyb
=
(
kNumRDMARanks
>
8
)
?
NUM_MAX_NVL_PEERS
:
NUM_MAX_NVL_PEERS
/
2
;
constexpr
auto
kWarpHyb
=
kNumRDMARanks
>
8
?
kEmulatedWarpSize
:
kWarpSize
;
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
);
const
int
num_warps
=
kNumRDMARanks
>
8
?
(
num_threads
/
kEmulatedWarpSize
-
1
)
:
(
num_threads
/
kWarpSize
);
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
num_channels
=
static_cast
<
int
>
(
gridDim
.
x
)
/
2
,
channel_id
=
sm_id
/
2
;
const
bool
is_rdma_receiver_sm
=
sm_id
%
2
==
1
;
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
__shared__
shmem_ctx_t
ctx
;
__shared__
shmem_ctx_t
ctx
;
shmem_wg_ctx_create
(
&
ctx
);
shmem_wg_ctx_create
(
&
ctx
);
#endif
EP_STATIC_ASSERT
(
kNumCombineForwarderWarps
<=
kWarpSize
,
"Invalid number of forwarder warps"
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
#endif
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
),
num_warps
=
num_threads
/
kWarpSize
;
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
),
warp_id
=
thread_id
/
kWarpSize
,
lane_id
=
get_lane_id
();
const
auto
num_channels
=
static_cast
<
int
>
(
gridDim
.
x
)
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
,
channel_id
=
sm_id
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
EP_DEVICE_ASSERT
(
num_topk
<=
kEmulatedWarpSize
);
EP_DEVICE_ASSERT
(
hidden
%
(
sizeof
(
int4
)
/
sizeof
(
dtype_t
))
==
0
);
const
auto
hidden_int4
=
hidden
/
(
sizeof
(
int4
)
/
sizeof
(
dtype_t
));
const
auto
hidden_int4
=
hidden
/
(
sizeof
(
int4
)
/
sizeof
(
dtype_t
));
// NOTES: we decouple a channel into 2 SMs
// NOTES: we decouple a channel into 2 SMs
const
auto
rdma_rank
=
rank
/
NUM_MAX_NVL_PEERS
,
nvl_rank
=
rank
%
NUM_MAX_NVL_PEERS
;
const
auto
rdma_rank
=
rank
/
NUM_MAX_NVL_PEERS
,
nvl_rank
=
rank
%
NUM_MAX_NVL_PEERS
;
auto
role_meta
=
[
=
]()
->
std
::
pair
<
WarpRole
,
int
>
{
const
auto
role_meta
=
[
=
]()
->
std
::
pair
<
WarpRole
,
int
>
{
auto
warp_id
=
thread_id
/
kWarpHyb
;
if
(
sm_id
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
1
)
{
if
(
not
is_rdma_receiver_sm
)
{
return
{
WarpRole
::
kNVLSender
,
(
warp_id
+
channel_id
)
%
NUM_MAX_NVL_PEERS
};
if
(
warp_id
<
kNVLPeersHyb
)
{
}
else
if
(
sm_id
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
)
{
auto
shuffled_warp_id
=
warp_id
;
if
(
warp_id
<
kNumForwarders
)
{
shuffled_warp_id
=
(
shuffled_warp_id
+
channel_id
)
%
kNVLPeersHyb
;
return
{
WarpRole
::
kNVLAndRDMAForwarder
,
(
warp_id
+
channel_id
)
%
kNumForwarders
};
return
{
WarpRole
::
kNVLSender
,
shuffled_warp_id
};
}
else
if
(
warp_id
<
kNVLPeersHyb
+
kNumForwarders
)
{
auto
shuffled_warp_id
=
warp_id
-
kNVLPeersHyb
;
shuffled_warp_id
=
(
shuffled_warp_id
+
channel_id
)
%
kNumForwarders
;
return
{
WarpRole
::
kNVLAndRDMAForwarder
,
shuffled_warp_id
};
}
else
{
}
else
{
return
{
WarpRole
::
k
RDMA
Coordinator
,
0
};
return
{
WarpRole
::
kCoordinator
,
0
};
}
}
}
else
{
}
else
{
if
(
warp_id
<
kNumForwarders
)
{
if
(
warp_id
<
kNVLPeersHyb
+
kNumForwarders
)
{
return
{
WarpRole
::
kRDMAReceiver
,
warp_id
};
return
{
WarpRole
::
kRDMAReceiver
,
warp_id
};
}
else
{
}
else
{
return
{
WarpRole
::
k
NVL
Coordinator
,
0
};
return
{
WarpRole
::
kCoordinator
,
0
};
}
}
}
}
}();
}();
auto
warp_role
=
role_meta
.
first
;
auto
warp_role
=
role_meta
.
first
;
auto
t
ar
get_rank
=
role_meta
.
second
;
// Not applicable for RDMA senders
auto
w
ar
p_id
=
role_meta
.
second
;
EP_DEVICE_ASSERT
(
num_warps
==
NUM_MAX_NVL_PEERS
+
1
);
auto
num_max_nvl_chunked_recv_tokens_per_rdma
=
num_max_nvl_chunked_recv_tokens
/
kNumRDMARanks
;
auto
num_max_nvl_chunked_recv_tokens_per_rdma
=
num_max_nvl_chunked_recv_tokens
/
kNumRDMARanks
;
// This approach is designed to sync multiple warps in a loop
// This approach is designed to sync multiple warps in a loop
constexpr
int
num_sync_large_iteration
=
64
;
constexpr
int
num_sync_large_iteration
=
64
;
constexpr
int
rdma_warp_counters
=
kNumRDMARanks
*
num_sync_large_iteration
;
__shared__
volatile
int
rdma_receiver_counter
[
1
];
__shared__
volatile
int
sync_large_warp_counters
[
2
*
rdma_warp_counters
];
__shared__
volatile
int
rdma_forwarder_counter
[
1
];
for
(
int
i
=
thread_id
;
i
<
2
*
rdma_warp_counters
;
i
+=
num_threads
)
{
__shared__
volatile
uint8_t
sync_large_warp_counters
[
2
*
kNumRDMARanks
*
num_sync_large_iteration
];
if
(
threadIdx
.
x
==
0
){
rdma_receiver_counter
[
0
]
=
0
;
rdma_forwarder_counter
[
0
]
=
0
;
}
for
(
int
i
=
thread_id
;
i
<
2
*
kNumRDMARanks
*
num_sync_large_iteration
;
i
+=
num_threads
)
{
sync_large_warp_counters
[
i
]
=
0
;
sync_large_warp_counters
[
i
]
=
0
;
}
}
__syncthreads
();
__syncthreads
();
if
(
warp_role
==
WarpRole
::
kNVLSender
)
{
if
(
warp_role
==
WarpRole
::
kNVLSender
)
{
if
(
warp_id
>=
NUM_MAX_NVL_PEERS
)
{
// NVL producers
return
;
const
int
dst_nvl_rank
=
kNumRDMARanks
<=
8
?
(
warp_id
*
2
+
(
thread_id
%
kWarpSize
)
/
kEmulatedWarpSize
)
:
warp_id
;
}
auto
lane_id
=
get_lane_id
()
%
kEmulatedWarpSize
;
const
auto
dst_nvl_rank
=
target_rank
;
// NVL layouts
// NVL layouts
// NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources
// NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources
auto
dst_buffer_ptr
=
buffer_ptrs
[
dst_nvl_rank
],
local_buffer_ptr
=
buffer_ptrs
[
nvl_rank
];
auto
dst_buffer_ptr
=
buffer_ptrs
[
dst_nvl_rank
],
local_buffer_ptr
=
buffer_ptrs
[
nvl_rank
];
...
@@ -1481,103 +1476,90 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1481,103 +1476,90 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Get tasks for each RDMA lane
// Get tasks for each RDMA lane
int
token_start_idx
=
0
,
token_end_idx
=
0
;
int
token_start_idx
=
0
,
token_end_idx
=
0
;
if
(
lane_id
<
kNumRDMARanks
)
{
if
(
lane_id
<
kNumRDMARanks
)
{
int
prefix_idx
=
(
lane_id
*
NUM_MAX_NVL_PEERS
+
dst_nvl_rank
)
*
num_channels
+
channel_id
;
int
prefix_idx
=
(
lane_id
*
NUM_MAX_NVL_PEERS
+
dst_nvl_rank
)
*
num_channels
+
channel_id
;
token_start_idx
=
gbl_channel_prefix_matrix
[
prefix_idx
];
token_start_idx
=
gbl_channel_prefix_matrix
[
prefix_idx
];
token_end_idx
=
(
prefix_idx
==
num_channels
*
num_ranks
-
1
)
?
num_tokens
:
gbl_channel_prefix_matrix
[
prefix_idx
+
1
];
token_end_idx
=
(
prefix_idx
==
num_channels
*
num_ranks
-
1
)
?
num_tokens
:
gbl_channel_prefix_matrix
[
prefix_idx
+
1
];
}
}
syncwarp
();
syncwarp
();
// NOTES: here the cached value of each lane is only responsible for a single RDMA buffer
// NOTES: here the cached value of each lane is only responsible for a single RDMA buffer
int
cached_channel_head_idx
=
0
,
cached_channel_tail_idx
=
0
;
int
cached_channel_head_idx
=
0
,
cached_channel_tail_idx
=
0
;
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"Invalid number of RDMA peers"
);
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
k
Emulated
WarpSize
,
"Invalid number of RDMA peers"
);
// Iterate over all tokens and send by chunks
// Iterate over all tokens and send by chunks
while
(
true
)
{
while
(
true
)
{
// Exit if possible
// Exit if possible
if
(
__all_sync
(
kFullWarpMask
,
token_start_idx
>=
token_end_idx
))
if
(
__all_sync
(
kFullWarpMask
,
token_start_idx
>=
token_end_idx
))
break
;
break
;
// Decide next RDMA buffer to send
// Decide next RDMA buffer to send
bool
is_lane_ready
=
false
;
bool
is_lane_ready
=
false
;
auto
start_time
=
wall_clock64
();
auto
start_time
=
wall_clock64
();
while
(
true
)
{
while
(
true
)
{
int
num_used_slots
=
cached_channel_tail_idx
-
cached_channel_head_idx
;
int
num_used_slots
=
cached_channel_tail_idx
-
cached_channel_head_idx
;
is_lane_ready
=
lane_id
<
kNumRDMARanks
and
token_start_idx
<
token_end_idx
and
is_lane_ready
=
lane_id
<
kNumRDMARanks
and
token_start_idx
<
token_end_idx
and
num_max_nvl_chunked_recv_tokens_per_rdma
-
num_used_slots
>=
num_max_nvl_chunked_send_tokens
;
num_max_nvl_chunked_recv_tokens_per_rdma
-
num_used_slots
>=
num_max_nvl_chunked_send_tokens
;
if
(
__any_sync
(
kFirstHalfMask
,
is_lane_ready
))
break
;
if
(
__any_sync
(
k
FullWarp
Mask
,
is_lane_ready
))
if
(
__any_sync
(
k
SecondHalf
Mask
,
is_lane_ready
))
break
;
break
;
// Retry
// Retry
if
(
lane_id
<
kNumRDMARanks
and
token_start_idx
<
token_end_idx
)
if
(
lane_id
<
kNumRDMARanks
and
token_start_idx
<
token_end_idx
)
cached_channel_head_idx
=
ld_volatile_global
(
nvl_channel_head
.
buffer
()
+
lane_id
);
cached_channel_head_idx
=
ld_volatile_global
(
nvl_channel_head
.
buffer
()
+
lane_id
);
// Timeout check
// Timeout check
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
and
lane_id
<
kNumRDMARanks
)
{
long
long
int
elapsed_time
=
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
printf
(
"DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, "
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
and
lane_id
<
kNumRDMARanks
)
{
"RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d
\n
"
,
printf
(
"DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d
\n
"
,
channel_id
,
channel_id
,
rdma_rank
,
nvl_rank
,
dst_nvl_rank
,
lane_id
,
ld_volatile_global
(
nvl_channel_head
.
buffer
()
+
lane_id
),
cached_channel_tail_idx
,
rdma_rank
,
token_start_idx
,
token_end_idx
);
nvl_rank
,
dst_nvl_rank
,
lane_id
,
ld_volatile_global
(
nvl_channel_head
.
buffer
()
+
lane_id
),
cached_channel_tail_idx
,
token_start_idx
,
token_end_idx
);
trap
();
trap
();
}
}
__builtin_amdgcn_s_sleep
(
1
);
}
}
// Sync token start index and count
// Sync token start index and count
for
(
int
current_rdma_idx
=
0
;
current_rdma_idx
<
kNumRDMARanks
;
++
current_rdma_idx
)
{
for
(
int
current_rdma_idx
=
0
;
current_rdma_idx
<
kNumRDMARanks
;
++
current_rdma_idx
)
{
if
(
shfl_sync
((
token_start_idx
>=
token_end_idx
)
or
(
not
is_lane_ready
),
current_rdma_idx
))
if
(
shfl_sync
((
token_start_idx
>=
token_end_idx
)
or
(
not
is_lane_ready
),
current_rdma_idx
,
kEmulatedWarpSize
))
continue
;
continue
;
// Sync token start index
// Sync token start index
auto
token_idx
=
static_cast
<
int64_t
>
(
shfl_sync
(
token_start_idx
,
current_rdma_idx
));
auto
token_idx
=
static_cast
<
int64_t
>
(
shfl_sync
(
token_start_idx
,
current_rdma_idx
,
kEmulatedWarpSize
));
int
num_tokens_in_chunk
=
shfl_sync
(
min
(
num_max_nvl_chunked_send_tokens
,
token_end_idx
-
token_start_idx
),
current_rdma_idx
);
int
num_tokens_in_chunk
=
shfl_sync
(
min
(
num_max_nvl_chunked_send_tokens
,
token_end_idx
-
token_start_idx
),
current_rdma_idx
,
kEmulatedWarpSize
);
// Send by chunk
// Send by chunk
for
(
int
chunk_idx
=
0
;
chunk_idx
<
num_tokens_in_chunk
;
++
chunk_idx
,
++
token_idx
)
{
for
(
int
chunk_idx
=
0
;
chunk_idx
<
num_tokens_in_chunk
;
++
chunk_idx
,
++
token_idx
)
{
// Get an empty slot
// Get an empty slot
int
dst_slot_idx
=
0
;
int
dst_slot_idx
=
0
;
if
(
lane_id
==
current_rdma_idx
)
{
if
(
lane_id
==
current_rdma_idx
)
{
dst_slot_idx
=
(
cached_channel_tail_idx
++
)
%
num_max_nvl_chunked_recv_tokens_per_rdma
;
dst_slot_idx
=
(
cached_channel_tail_idx
++
)
%
num_max_nvl_chunked_recv_tokens_per_rdma
;
dst_slot_idx
=
current_rdma_idx
*
num_max_nvl_chunked_recv_tokens_per_rdma
+
dst_slot_idx
;
dst_slot_idx
=
current_rdma_idx
*
num_max_nvl_chunked_recv_tokens_per_rdma
+
dst_slot_idx
;
}
}
dst_slot_idx
=
shfl_sync
(
dst_slot_idx
,
current_rdma_idx
);
dst_slot_idx
=
shfl_sync
(
dst_slot_idx
,
current_rdma_idx
,
kEmulatedWarpSize
);
// Copy data
// Copy data
auto
shifted_x_buffers
=
nvl_channel_x
.
buffer
()
+
dst_slot_idx
*
hidden_int4
;
auto
shifted_x_buffers
=
nvl_channel_x
.
buffer
()
+
dst_slot_idx
*
hidden_int4
;
auto
shifted_x
=
x
+
token_idx
*
hidden_int4
;
auto
shifted_x
=
x
+
token_idx
*
hidden_int4
;
UNROLLED_WARP_COPY
(
5
,
lane_id
,
hidden_int4
,
shifted_x_buffers
,
shifted_x
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY
_EMULATED
(
5
,
lane_id
,
hidden_int4
,
shifted_x_buffers
,
shifted_x
,
ld_nc_global
,
st_na_global
);
// Copy source meta
// Copy source meta
if
(
lane_id
==
0
)
if
(
lane_id
==
num_topk
)
st_na_global
(
nvl_channel_src_meta
.
buffer
()
+
dst_slot_idx
,
ld_nc_global
(
src_meta
+
token_idx
));
st_na_global
(
nvl_channel_src_meta
.
buffer
()
+
dst_slot_idx
,
ld_nc_global
(
src_meta
+
token_idx
));
// Copy `topk_weights`
// Copy `topk_weights`
if
(
lane_id
<
num_topk
)
if
(
lane_id
<
num_topk
)
st_na_global
(
nvl_channel_topk_weights
.
buffer
()
+
dst_slot_idx
*
num_topk
+
lane_id
,
st_na_global
(
nvl_channel_topk_weights
.
buffer
()
+
dst_slot_idx
*
num_topk
+
lane_id
,
ld_nc_global
(
topk_weights
+
token_idx
*
num_topk
+
lane_id
));
ld_nc_global
(
topk_weights
+
token_idx
*
num_topk
+
lane_id
));
}
}
lane_id
==
current_rdma_idx
?
(
token_start_idx
=
static_cast
<
int
>
(
token_idx
))
:
0
;
lane_id
==
current_rdma_idx
?
(
token_start_idx
=
static_cast
<
int
>
(
token_idx
))
:
0
;
}
}
// Move queue tail
// Move queue tail
syncwarp
();
syncwarp
();
if
(
lane_id
<
kNumRDMARanks
and
is_lane_ready
)
{
if
(
lane_id
<
kNumRDMARanks
and
is_lane_ready
)
st_release_sys_global
(
nvl_channel_tail
.
buffer
()
+
lane_id
,
cached_channel_tail_idx
);
st_relaxed_sys_global
(
nvl_channel_tail
.
buffer
()
+
lane_id
,
cached_channel_tail_idx
);
}
}
}
}
else
{
}
else
{
if
(
warp_id
>
kNumForwarders
)
{
auto
lane_id
=
get_lane_id
()
%
kWarpHyb
;
return
;
}
// Combiners and coordinators
// Combiners and coordinators
// RDMA symmetric layout
// RDMA symmetric layout
auto
hidden_bytes
=
hidden_int4
*
sizeof
(
int4
);
auto
hidden_bytes
=
hidden_int4
*
sizeof
(
int4
);
...
@@ -1604,53 +1586,65 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1604,53 +1586,65 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
__shared__
volatile
int
rdma_receiver_rdma_head
[
kNumRDMAReceivers
][
kNumRDMARanks
];
__shared__
volatile
int
rdma_receiver_rdma_head
[
kNumRDMAReceivers
][
kNumRDMARanks
];
__shared__
volatile
bool
rdma_receiver_retired
[
kNumRDMAReceivers
];
__shared__
volatile
bool
rdma_receiver_retired
[
kNumRDMAReceivers
];
auto
sync_forwarder_smem
=
[
&
]()
{
if
(
lane_id
==
0
)
{
volatile
int
ret
=
__hip_atomic_fetch_add
(
&
rdma_forwarder_counter
[
0
],
1
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
}
syncwarp
();
while
(
rdma_forwarder_counter
[
0
]
<
(
kNumForwarders
+
1
)){}
};
auto
sync_rdma_receiver_smem
=
[
&
]()
{
if
(
lane_id
==
0
)
{
volatile
int
ret
=
__hip_atomic_fetch_add
(
&
rdma_receiver_counter
[
0
],
1
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
}
syncwarp
();
while
(
rdma_receiver_counter
[
0
]
<
(
kNumRDMAReceivers
+
1
)){}
};
if
(
warp_role
==
WarpRole
::
kNVLAndRDMAForwarder
)
{
if
(
warp_role
==
WarpRole
::
kNVLAndRDMAForwarder
)
{
// Receive from NVL ranks and forward to RDMA ranks
// Receive from NVL ranks and forward to RDMA ranks
// NOTES: this part is using "large warps" for each RDMA ranks
// NOTES: this part is using "large warps" for each RDMA ranks
const
auto
dst_rdma_rank
=
target_rank
/
kNumWarpsPerForwarder
;
const
auto
dst_rdma_rank
=
warp_id
/
kNumWarpsPerForwarder
;
const
auto
sub_warp_id
=
target_rank
%
kNumWarpsPerForwarder
;
const
auto
sub_warp_id
=
warp_id
%
kNumWarpsPerForwarder
;
auto
send_buffer
=
dst_rdma_rank
==
rdma_rank
?
rdma_channel_data
.
recv_buffer
(
dst_rdma_rank
)
:
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
);
auto
send_buffer
=
dst_rdma_rank
==
rdma_rank
?
rdma_channel_data
.
recv_buffer
(
dst_rdma_rank
)
:
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
);
// auto sync_large_warp = [=]() {
// if(kNumWarpsPerForwarder == 1) {
// syncwarp();
// } else {
// // asm volatile("bar.sync %0, %1;" ::"r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * kWarpSize));
// // __syncthreads();
// syncwarp();
// }
// };
auto
sync_large_warp
=
[
=
](
const
int
iter
,
const
int
mode
)
{
auto
sync_large_warp
=
[
=
](
const
int
iter
,
const
int
mode
)
{
if
(
kNumWarpsPerForwarder
==
1
)
{
if
(
kNumWarpsPerForwarder
==
1
)
{
syncwarp
();
syncwarp
();
}
else
{
}
else
{
// LDS index to store for sync
// LDS index to store for sync
int
lds_dst_rdma_rank
=
dst_rdma_rank
+
(
iter
%
num_sync_large_iteration
)
*
kNumRDMARanks
+
mode
*
rdma_warp_counters
;
int
lds_dst_rdma_rank
=
dst_rdma_rank
+
(
iter
%
num_sync_large_iteration
)
*
kNumRDMARanks
+
mode
*
kNumRDMARanks
*
num_sync_large_iteration
;
//reset index in the LDS to avoid race condition due to warp scheduling
//reset index in the LDS to avoid race condition due to warp scheduling
int
reset_idx
=
dst_rdma_rank
+
((
iter
+
num_sync_large_iteration
/
2
)
%
num_sync_large_iteration
)
*
kNumRDMARanks
+
mode
*
rdma_warp_counters
;
int
reset_idx
=
dst_rdma_rank
+
((
iter
+
num_sync_large_iteration
/
2
)
%
num_sync_large_iteration
)
*
kNumRDMARanks
+
mode
*
kNumRDMARanks
*
num_sync_large_iteration
;
auto
start_time
=
wall_clock64
();
// if (lane_id==0)
if
(
lane_id
==
0
){
// printf("rank %d dst_rdma_rank %d iter %d warp_id %d val %d\n", rank, dst_rdma_rank, iter, warp_id, sync_large_warp_counters[lds_dst_rdma_rank]);
volatile
int
ret
=
atomicAdd
((
int
*
)
&
sync_large_warp_counters
[
lds_dst_rdma_rank
],
1
);
auto
start_time
=
clock64
();
}
if
(
lane_id
==
0
){
syncwarp
();
volatile
int
ret
=
__hip_atomic_fetch_add
(
//The while(...) loop polls the counter until all warps have arrived
&
sync_large_warp_counters
[
lds_dst_rdma_rank
],
1
,
if
(
lane_id
==
0
){
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
while
(
sync_large_warp_counters
[
lds_dst_rdma_rank
]
<
(
kNumWarpsPerForwarder
)){
}
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
syncwarp
();
printf
(
"DeepEP combine sync timeout. current num_sync_large_iteration %d. double it.
\n
"
,
num_sync_large_iteration
);
//The while(...) loop polls the counter until all warps have arrived
trap
();
if
(
lane_id
==
0
){
}
while
(
sync_large_warp_counters
[
lds_dst_rdma_rank
]
<
(
kNumWarpsPerForwarder
)){
if
(
clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP combine sync timeout. current num_sync_large_iteration %d. double it.
\n
"
,
num_sync_large_iteration
);
trap
();
}
}
}
}
syncwarp
();
}
if
(
lane_id
==
0
&&
sync_large_warp_counters
[
reset_idx
]
==
kNumWarpsPerForwarder
){
syncwarp
();
sync_large_warp_counters
[
reset_idx
]
=
0
;
if
(
lane_id
==
0
&&
sync_large_warp_counters
[
reset_idx
]
==
kNumWarpsPerForwarder
){
}
sync_large_warp_counters
[
reset_idx
]
=
0
;
syncwarp
();
}
syncwarp
();
}
}
};
};
EP_STATIC_ASSERT
(
kNumWarpsPerForwarder
==
1
or
kNumRDMARanks
+
2
<=
kNumCombineForwarderWarps
,
"Barriers are not enough"
);
EP_STATIC_ASSERT
(
kNumWarpsPerForwarder
==
1
or
kNumRDMARanks
+
2
<=
16
,
"Barriers are not enough"
);
// Advance to the corresponding NVL buffer, 基于原本指针进行的地址偏移
// In case of running less than 8 nodes
constexpr
bool
kUseWave
=
(
kNumRDMARanks
<=
8
);
// Advance to the corresponding NVL buffer
nvl_channel_x
.
advance
(
dst_rdma_rank
*
num_max_nvl_chunked_recv_tokens_per_rdma
*
hidden_int4
);
nvl_channel_x
.
advance
(
dst_rdma_rank
*
num_max_nvl_chunked_recv_tokens_per_rdma
*
hidden_int4
);
nvl_channel_src_meta
.
advance
(
dst_rdma_rank
*
num_max_nvl_chunked_recv_tokens_per_rdma
);
nvl_channel_src_meta
.
advance
(
dst_rdma_rank
*
num_max_nvl_chunked_recv_tokens_per_rdma
);
nvl_channel_topk_weights
.
advance
(
dst_rdma_rank
*
num_max_nvl_chunked_recv_tokens_per_rdma
*
num_topk
);
nvl_channel_topk_weights
.
advance
(
dst_rdma_rank
*
num_max_nvl_chunked_recv_tokens_per_rdma
*
num_topk
);
...
@@ -1659,10 +1653,9 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1659,10 +1653,9 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Clean shared memory and sync
// Clean shared memory and sync
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"Invalid number of NVL peers"
);
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"Invalid number of NVL peers"
);
lane_id
<
NUM_MAX_NVL_PEERS
?
(
forwarder_nvl_head
[
target_rank
][
lane_id
]
=
0
)
:
0
;
lane_id
<
NUM_MAX_NVL_PEERS
?
(
forwarder_nvl_head
[
warp_id
][
lane_id
]
=
0
)
:
0
;
lane_id
==
0
?
(
forwarder_retired
[
target_rank
]
=
false
)
:
false
;
lane_id
==
0
?
(
forwarder_retired
[
warp_id
]
=
false
)
:
false
;
// sync_forwarder_smem();
sync_forwarder_smem
();
__syncthreads
();
// Get count and cached head
// Get count and cached head
int
cached_nvl_channel_tail_idx
=
0
;
int
cached_nvl_channel_tail_idx
=
0
;
...
@@ -1673,89 +1666,106 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1673,89 +1666,106 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
combined_nvl_head
+=
num_tokens_prefix
*
NUM_MAX_NVL_PEERS
;
combined_nvl_head
+=
num_tokens_prefix
*
NUM_MAX_NVL_PEERS
;
// Iterate over all tokens and combine by chunks
// Iterate over all tokens and combine by chunks
for
(
int
token_start_idx
=
0
;
token_start_idx
<
num_tokens_to_combine
;
token_start_idx
+=
num_max_rdma_chunked_send_tokens
)
{
for
(
int
token_start_idx
=
0
;
token_start_idx
<
num_tokens_to_combine
;
token_start_idx
+=
num_max_rdma_chunked_send_tokens
)
{
// Check destination queue emptiness, or wait a buffer to be released
// Check destination queue emptiness, or wait a buffer to be released
auto
token_end_idx
=
min
(
token_start_idx
+
num_max_rdma_chunked_send_tokens
,
num_tokens_to_combine
);
auto
token_end_idx
=
min
(
token_start_idx
+
num_max_rdma_chunked_send_tokens
,
num_tokens_to_combine
);
auto
num_chunked_tokens
=
token_end_idx
-
token_start_idx
;
auto
num_chunked_tokens
=
token_end_idx
-
token_start_idx
;
auto
start_time
=
wall_clock64
();
auto
start_time
=
wall_clock64
();
while
(
sub_warp_id
==
0
and
lane_id
==
0
)
{
while
(
sub_warp_id
==
0
and
lane_id
==
0
)
{
// Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
// Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
// Here, `token_start_idx` is the actual tail
// Here, `token_start_idx` is the actual tail
int
num_used_slots
=
token_start_idx
-
ld_volatile_global
(
rdma_channel_head
.
buffer
(
dst_rdma_rank
));
int
num_used_slots
=
token_start_idx
-
ld_volatile_global
(
rdma_channel_head
.
buffer
(
dst_rdma_rank
));
if
(
num_max_rdma_chunked_recv_tokens
-
num_used_slots
>=
num_chunked_tokens
)
if
(
num_max_rdma_chunked_recv_tokens
-
num_used_slots
>=
num_chunked_tokens
)
break
;
break
;
// Timeout check
// Timeout check
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
long
long
int
elapsed_time
=
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d
\n
"
,
printf
(
"DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
dst_rdma_rank
,
ld_volatile_global
(
rdma_channel_head
.
buffer
(
dst_rdma_rank
)),
token_start_idx
,
num_chunked_tokens
);
channel_id
,
rdma_rank
,
nvl_rank
,
dst_rdma_rank
,
ld_volatile_global
(
rdma_channel_head
.
buffer
(
dst_rdma_rank
)),
token_start_idx
,
num_chunked_tokens
);
trap
();
trap
();
}
}
}
}
// sync_large_warp();
sync_large_warp
(
token_start_idx
,
0
);
sync_large_warp
(
token_start_idx
,
0
);
// Combine and write to the RDMA buffer
// Combine and write to the RDMA buffer
for
(
int
token_idx
=
token_start_idx
+
sub_warp_id
;
token_idx
<
token_end_idx
;
token_idx
+=
kNumWarpsPerForwarder
)
{
for
(
int
token_idx
=
token_start_idx
+
sub_warp_id
;
token_idx
<
token_end_idx
;
token_idx
+=
kNumWarpsPerForwarder
)
{
// Read expected head
// Read expected head
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"Invalid number of RDMA peers"
);
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"Invalid number of RDMA peers"
);
int
expected_head
=
-
1
;
int
expected_head
=
-
1
;
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
expected_head
=
ld_nc_global
(
combined_nvl_head
+
token_idx
*
NUM_MAX_NVL_PEERS
+
lane_id
);
expected_head
=
ld_nc_global
(
combined_nvl_head
+
token_idx
*
NUM_MAX_NVL_PEERS
+
lane_id
);
// Wait lanes to be ready
// Wait lanes to be ready
start_time
=
wall_clock64
();
start_time
=
wall_clock64
();
while
(
cached_nvl_channel_tail_idx
<=
expected_head
)
{
while
(
cached_nvl_channel_tail_idx
<=
expected_head
)
{
cached_nvl_channel_tail_idx
=
ld_acquire_sys_global
(
nvl_channel_tail
.
buffer
(
lane_id
));
cached_nvl_channel_tail_idx
=
ld_relaxed_sys_global
(
nvl_channel_tail
.
buffer
(
lane_id
));
// Timeout check
// Timeout check
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
and
lane_id
<
NUM_MAX_NVL_PEERS
)
{
long
long
int
elapsed_time
=
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
and
lane_id
<
NUM_MAX_NVL_PEERS
)
{
printf
(
"DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d
\n
"
,
printf
(
"DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
dst_rdma_rank
,
cached_nvl_channel_tail_idx
,
token_idx
,
num_tokens_to_combine
,
sub_warp_id
,
kNumWarpsPerForwarder
,
expected_head
);
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
dst_rdma_rank
,
cached_nvl_channel_tail_idx
,
token_idx
,
num_tokens_to_combine
,
sub_warp_id
,
kNumWarpsPerForwarder
,
expected_head
);
trap
();
trap
();
}
}
__builtin_amdgcn_s_sleep
(
1
);
}
}
// Combine current token
// Combine current token
auto
rdma_slot_idx
=
token_idx
%
num_max_rdma_chunked_recv_tokens
;
auto
rdma_slot_idx
=
token_idx
%
num_max_rdma_chunked_recv_tokens
;
void
*
shifted
=
send_buffer
+
rdma_slot_idx
*
num_bytes_per_rdma_token
;
void
*
shifted
=
send_buffer
+
rdma_slot_idx
*
num_bytes_per_rdma_token
;
auto
get_addr
_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
*
{
return
reinterpret_cast
<
int4
*>
(
nvl_channel_x
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
hidden_int4
)
+
hidden_int4_idx
;
};
auto
recv
_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
{
return
ld_nc_global
(
nvl_channel_x
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
hidden_int4
+
hidden_int4_idx
)
;
};
auto
recv_tw_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
nvl_channel_topk_weights
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
num_topk
+
topk_idx
);
};
auto
recv_tw_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
nvl_channel_topk_weights
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
num_topk
+
topk_idx
);
};
combine_token
<
NUM_MAX_NVL_PEERS
,
dtype_t
,
NUM_MAX_NVL_PEERS
,
true
>
(
expected_head
>=
0
,
combine_token
<
NUM_MAX_NVL_PEERS
,
dtype_t
,
NUM_MAX_NVL_PEERS
,
kWarpHyb
>
(
expected_head
>=
0
,
expected_head
,
lane_id
,
expected_head
,
lane_id
,
hidden_int4
,
num_topk
,
hidden_int4
,
num_topk
,
reinterpret_cast
<
int4
*>
(
shifted
),
reinterpret_cast
<
int4
*>
(
shifted
),
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
int8_t
*>
(
shifted
)
+
hidden_bytes
+
sizeof
(
SourceMeta
)),
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
int8_t
*>
(
shifted
)
+
hidden_bytes
+
sizeof
(
SourceMeta
)),
num_max_nvl_chunked_recv_tokens_per_rdma
,
num_max_nvl_chunked_recv_tokens_per_rdma
,
recv_fn
,
recv_tw_fn
);
get_addr_fn
,
recv_tw_fn
);
// Update head
// Update head
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
expected_head
<
0
?
(
forwarder_nvl_head
[
target_rank
][
lane_id
]
=
-
expected_head
-
1
)
expected_head
<
0
?
(
forwarder_nvl_head
[
warp_id
][
lane_id
]
=
-
expected_head
-
1
)
:
(
forwarder_nvl_head
[
warp_id
][
lane_id
]
=
expected_head
+
1
);
:
(
forwarder_nvl_head
[
target_rank
][
lane_id
]
=
expected_head
+
1
);
}
}
}
// sync_large_warp();
sync_large_warp
(
token_start_idx
,
1
);
sync_large_warp
(
token_start_idx
,
1
);
// Issue RDMA send
// Issue RDMA send
if
(
sub_warp_id
==
kNumWarpsPerForwarder
-
1
)
{
// TODO: Switch back to put_nbi_wave function
if
(
dst_rdma_rank
!=
rdma_rank
)
{
if
(
sub_warp_id
==
kNumWarpsPerForwarder
-
1
)
{
if
(
dst_rdma_rank
!=
rdma_rank
)
{
auto
rdma_slot_idx
=
token_start_idx
%
num_max_rdma_chunked_recv_tokens
;
auto
rdma_slot_idx
=
token_start_idx
%
num_max_rdma_chunked_recv_tokens
;
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#ifdef FORCE_DUSHMEM_API
shmem_ctx_schar_put_nbi_warp
(
ctx
,
#else
shmemx_int8_put_nbi_warp
(
shmemx_int8_put_nbi_warp
(
#endif
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
rdma_slot_idx
*
num_bytes_per_rdma_token
,
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
rdma_slot_idx
*
num_bytes_per_rdma_token
,
num_chunked_tokens
*
num_bytes_per_rdma_token
,
num_chunked_tokens
*
num_bytes_per_rdma_token
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
#else
#if !defined(ROCM_DISABLE_CTX)
if
constexpr
(
kUseWave
){
shmem_ctx_schar_put_nbi_warp
(
ctx
,
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
num_chunked_tokens
*
num_bytes_per_rdma_token
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
#else
if
constexpr
(
kUseWave
){
shmemx_int8_put_nbi_warp
(
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
num_chunked_tokens
*
num_bytes_per_rdma_token
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
else
{
if
(
lane_id
==
0
)
shmemx_int8_put_nbi
(
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
num_chunked_tokens
*
num_bytes_per_rdma_token
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
#endif
#endif
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet
(
ctx
);
shmem_ctx_quiet
(
ctx
);
#else
#else
shmem_fence
();
shmem_fence
();
#endif
#endif
...
@@ -1765,167 +1775,100 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1765,167 +1775,100 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Write new RDMA tail
// Write new RDMA tail
syncwarp
();
syncwarp
();
if
(
lane_id
==
0
)
{
if
(
lane_id
==
0
)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add
(
ctx
,
shmem_ctx_ulong_atomic_add
(
ctx
,
#else
#else
shmem_signal_op_add
(
shmem_signal_op_add
(
#endif
#endif
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_chunked_tokens
,
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_chunked_tokens
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
}
}
}
}
// Retired
// Retired
syncwarp
();
syncwarp
();
if
(
lane_id
==
0
)
{
if
(
lane_id
==
0
)
forwarder_retired
[
target_rank
]
=
true
;
forwarder_retired
[
warp_id
]
=
true
;
}
}
else
if
(
warp_role
==
WarpRole
::
kRDMAReceiver
)
{
}
else
if
(
warp_role
==
WarpRole
::
kRDMACoordinator
)
{
// Coordinator
// Sync shared memory status
// sync_forwarder_smem();
__syncthreads
();
constexpr
int
num_warps_per_rdma_rank
=
kNumForwarders
/
kNumRDMARanks
;
int
last_nvl_head
[
kNumRDMARanks
]
=
{
0
};
int
dst_nvl_rank
=
lane_id
<
NUM_MAX_NVL_PEERS
?
lane_id
:
0
;
while
(
true
)
{
// Retired
if
(
__all_sync
(
kFullWarpMask
,
lane_id
>=
kNumForwarders
or
forwarder_retired
[
lane_id
]))
break
;
{
// Find minimum head for NVL ranks
#pragma unroll
for
(
int
i
=
0
;
i
<
kNumRDMARanks
;
++
i
)
{
int
min_head
=
std
::
numeric_limits
<
int
>::
max
();
#pragma unroll
for
(
int
j
=
0
;
j
<
num_warps_per_rdma_rank
;
++
j
)
if
(
not
forwarder_retired
[
i
*
num_warps_per_rdma_rank
+
j
])
min_head
=
min
(
min_head
,
forwarder_nvl_head
[
i
*
num_warps_per_rdma_rank
+
j
][
dst_nvl_rank
]);
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>
last_nvl_head
[
i
]
and
lane_id
<
NUM_MAX_NVL_PEERS
)
{
st_relaxed_sys_global
(
nvl_channel_head
.
buffer_by
(
dst_nvl_rank
)
+
i
,
last_nvl_head
[
i
]
=
min_head
);
}
}
}
// Nanosleep and let other warps work
__builtin_amdgcn_s_sleep
(
NUM_WAIT_CYCLES_TIMES_64
);
}
}
else
if
(
warp_role
==
WarpRole
::
kRDMAReceiver
)
{
// Receive from RDMA ranks and write to the output tensor
// Receive from RDMA ranks and write to the output tensor
// Clean shared memory and sync
// Clean shared memory and sync
EP_DEVICE_ASSERT
(
kNumRDMARanks
<=
kWarpSize
);
EP_DEVICE_ASSERT
(
kNumRDMARanks
<=
kWarpSize
);
lane_id
<
kNumRDMARanks
?
(
rdma_receiver_rdma_head
[
target_rank
][
lane_id
]
=
0
)
:
0
;
lane_id
<
kNumRDMARanks
?
(
rdma_receiver_rdma_head
[
warp_id
][
lane_id
]
=
0
)
:
0
;
lane_id
==
0
?
(
rdma_receiver_retired
[
target_rank
]
=
false
)
:
0
;
lane_id
==
0
?
(
rdma_receiver_retired
[
warp_id
]
=
false
)
:
0
;
// sync_rdma_receiver_smem();
sync_rdma_receiver_smem
();
__syncthreads
();
// The same tokens as the dispatch process
// The same tokens as the dispatch process
int
token_start_idx
,
token_end_idx
;
int
token_start_idx
,
token_end_idx
;
get_channel_task_range
(
num_combined_tokens
,
num_channels
,
channel_id
,
token_start_idx
,
token_end_idx
);
get_channel_task_range
(
num_combined_tokens
,
num_channels
,
channel_id
,
token_start_idx
,
token_end_idx
);
// ==================== Token 级展开 x4 ====================
// Iterate over all tokens and combine
constexpr
int
kTokenUnroll
=
4
;
int
cached_channel_tail_idx
=
0
;
int
cached_channel_tail_idx
=
0
;
for
(
int64_t
token_idx
=
token_start_idx
+
warp_id
;
token_idx
<
token_end_idx
;
token_idx
+=
kNumRDMAReceivers
)
{
for
(
int64_t
base
=
token_start_idx
+
target_rank
;
// Read expected head
base
<
token_end_idx
;
EP_STATIC_ASSERT
(
kNumRDMARanks
<=
kWarpSize
,
"Invalid number of RDMA peers"
);
base
+=
(
int64_t
)
kNumRDMAReceivers
*
kTokenUnroll
)
{
int
expected_head
=
-
1
;
if
(
lane_id
<
kNumRDMARanks
)
{
// ---- Phase 1: 批量预取所有 token 的 expected_head ----
expected_head
=
ld_nc_global
(
combined_rdma_head
+
token_idx
*
kNumRDMARanks
+
lane_id
);
int
cached_expected_head
[
kTokenUnroll
];
(
expected_head
<
0
)
?
(
rdma_receiver_rdma_head
[
warp_id
][
lane_id
]
=
-
expected_head
-
1
)
:
(
rdma_receiver_rdma_head
[
warp_id
][
lane_id
]
=
expected_head
);
int
max_expected_head
=
-
1
;
#pragma unroll
for
(
int
u
=
0
;
u
<
kTokenUnroll
;
++
u
)
{
int64_t
tidx
=
base
+
(
int64_t
)
u
*
kNumRDMAReceivers
;
cached_expected_head
[
u
]
=
-
1
;
if
(
tidx
<
token_end_idx
&&
lane_id
<
kNumRDMARanks
)
{
int
expected_head
=
ld_nc_global
(
combined_rdma_head
+
tidx
*
kNumRDMARanks
+
lane_id
);
cached_expected_head
[
u
]
=
expected_head
;
if
(
expected_head
>
max_expected_head
)
max_expected_head
=
expected_head
;
}
}
}
// ---- Phase 2: 一次等待,覆盖所有 token ----
// Wait lanes to be ready
if
(
max_expected_head
>=
0
)
{
auto
start_time
=
wall_clock64
();
auto
start_time
=
wall_clock64
();
while
(
cached_channel_tail_idx
<=
expected_head
)
{
while
(
cached_channel_tail_idx
<=
max_expected_head
)
{
cached_channel_tail_idx
=
static_cast
<
int
>
(
ld_acquire_sys_global
(
rdma_channel_tail
.
buffer
(
lane_id
)));
cached_channel_tail_idx
=
static_cast
<
int
>
(
ld_acquire_sys_global
(
rdma_channel_tail
.
buffer
(
lane_id
)));
// Timeout check
if
(
wall_clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
long
long
int
elapsed_time
=
wall_clock64
()
>
start_time
?
wall_clock64
()
-
start_time
:
0
;
printf
(
"DeepEP combine RDMA receiver timeout (unroll x%d), "
if
(
elapsed_time
>
NUM_TIMEOUT_CYCLES
)
{
"ch: %d, rdma: %d, nvl: %d, lane: %d, "
printf
(
"DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d
\n
"
,
"tail: %d, wait: %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
cached_channel_tail_idx
,
token_idx
,
expected_head
);
kTokenUnroll
,
channel_id
,
rdma_rank
,
nvl_rank
,
trap
();
lane_id
,
cached_channel_tail_idx
,
max_expected_head
);
trap
();
}
}
}
__builtin_amdgcn_s_sleep
(
1
);
}
}
syncwarp
();
syncwarp
();
// ---- Phase 3: 批量处理所有就绪 token ----
// Combine current token
#pragma unroll
auto
recv_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
{
return
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
)
+
hidden_int4_idx
);};
for
(
int
u
=
0
;
u
<
kTokenUnroll
;
++
u
)
{
auto
recv_tw_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
reinterpret_cast
<
const
float
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
+
hidden_bytes
+
sizeof
(
SourceMeta
))
+
topk_idx
);};
int64_t
tidx
=
base
+
(
int64_t
)
u
*
kNumRDMAReceivers
;
combine_token
<
kNumRDMARanks
,
dtype_t
,
kNumTopkRDMARanks
,
kWarpHyb
>
(
expected_head
>=
0
,
if
(
tidx
<
token_end_idx
)
{
expected_head
,
lane_id
,
int
expected_head
=
cached_expected_head
[
u
];
hidden_int4
,
num_topk
,
// Combine current token
combined_x
+
token_idx
*
hidden_int4
,
auto
get_addr_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
*
{
return
reinterpret_cast
<
int4
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
)
+
hidden_int4_idx
;
};
combined_topk_weights
+
token_idx
*
num_topk
,
auto
recv_tw_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
reinterpret_cast
<
const
float
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
+
hidden_bytes
+
sizeof
(
SourceMeta
))
+
topk_idx
);};
num_max_rdma_chunked_recv_tokens
,
recv_fn
,
recv_tw_fn
);
combine_token
<
kNumRDMARanks
,
dtype_t
,
kNumTopkRDMARanks
,
false
>
(
expected_head
>=
0
,
expected_head
,
lane_id
,
hidden_int4
,
num_topk
,
combined_x
+
tidx
*
hidden_int4
,
combined_topk_weights
+
tidx
*
num_topk
,
num_max_rdma_chunked_recv_tokens
,
get_addr_fn
,
recv_tw_fn
);
if
(
lane_id
<
kNumRDMARanks
)
{
rdma_receiver_rdma_head
[
target_rank
][
lane_id
]
=
expected_head
<
0
?
-
expected_head
-
1
:
expected_head
;
}
}
}
}
}
// Retired
// Retired
syncwarp
();
syncwarp
();
if
(
lane_id
==
0
)
{
if
(
lane_id
==
0
)
rdma_receiver_retired
[
t
ar
get_rank
]
=
true
;
rdma_receiver_retired
[
w
ar
p_id
]
=
true
;
}
}
else
{
}
else
if
(
warp_role
==
WarpRole
::
kNVLCoordinator
)
{
auto
lane_id
=
get_lane_id
();
// Coordinator
// Coordinator
// Sync shared memory status
// Sync shared memory status
// sync_rdma_receiver_smem();
is_rdma_receiver_sm
?
sync_rdma_receiver_smem
()
:
sync_forwarder_smem
();
__syncthreads
();
const
auto
num_warps_per_rdma_rank
=
kNumForwarders
/
kNumRDMARanks
;
const
auto
num_warps_per_rdma_rank
=
kNumForwarders
/
kNumRDMARanks
;
int
last_rdma_head
=
0
;
int
last_rdma_head
=
0
;
int
last_nvl_head
[
kNumRDMARanks
]
=
{
0
};
int
last_nvl_head
[
kNumRDMARanks
]
=
{
0
};
int
dst_rdma_rank
=
lane_id
<
kNumRDMARanks
?
lane_id
:
0
;
int
dst_rdma_rank
=
lane_id
<
kNumRDMARanks
?
lane_id
:
0
;
int
dst_nvl_rank
=
lane_id
<
NUM_MAX_NVL_PEERS
?
lane_id
:
0
;
int
dst_nvl_rank
=
lane_id
<
NUM_MAX_NVL_PEERS
?
lane_id
:
0
;
EP_STATIC_ASSERT
(
kNumCombineForwarderWarps
<=
kWarpSize
,
"Invalid number of forwarder warps"
);
while
(
true
)
{
while
(
true
)
{
// Retired
// Retired
if
(
__all_sync
(
kFullWarpMask
,
lane_id
>=
kNumRDMAReceivers
or
rdma_receiver_retired
[
lane_id
]))
if
(
is_rdma_receiver_sm
and
__all_sync
(
kFullWarpMask
,
lane_id
>=
kNumRDMAReceivers
or
rdma_receiver_retired
[
lane_id
]))
break
;
if
(
not
is_rdma_receiver_sm
and
__all_sync
(
kFullWarpMask
,
lane_id
>=
kNumForwarders
or
forwarder_retired
[
lane_id
]))
break
;
break
;
// Find minimum head for RDMA ranks
// Find minimum head for RDMA ranks
{
if
(
is_rdma_receiver_sm
)
{
int
min_head
=
std
::
numeric_limits
<
int
>::
max
();
int
min_head
=
std
::
numeric_limits
<
int
>::
max
();
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
kNumRDMAReceivers
;
++
i
)
for
(
int
i
=
0
;
i
<
kNumRDMAReceivers
;
++
i
)
if
(
not
rdma_receiver_retired
[
i
])
if
(
not
rdma_receiver_retired
[
i
])
min_head
=
min
(
min_head
,
rdma_receiver_rdma_head
[
i
][
dst_rdma_rank
]);
min_head
=
min
(
min_head
,
rdma_receiver_rdma_head
[
i
][
dst_rdma_rank
]);
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>=
last_rdma_head
+
num_max_rdma_chunked_send_tokens
and
lane_id
<
kNumRDMARanks
)
{
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>=
last_rdma_head
+
num_max_rdma_chunked_send_tokens
and
lane_id
<
kNumRDMARanks
)
{
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add
(
ctx
,
shmem_ctx_ulong_atomic_add
(
ctx
,
...
@@ -1933,10 +1876,21 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1933,10 +1876,21 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
shmem_signal_op_add
(
shmem_signal_op_add
(
#endif
#endif
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_rdma_head
,
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_rdma_head
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
last_rdma_head
=
min_head
;
last_rdma_head
=
min_head
;
}
}
}
else
{
// Find minimum head for NVL ranks
#pragma unroll
for
(
int
i
=
0
;
i
<
kNumRDMARanks
;
++
i
)
{
int
min_head
=
std
::
numeric_limits
<
int
>::
max
();
#pragma unroll
for
(
int
j
=
0
;
j
<
num_warps_per_rdma_rank
;
++
j
)
if
(
not
forwarder_retired
[
i
*
num_warps_per_rdma_rank
+
j
])
min_head
=
min
(
min_head
,
forwarder_nvl_head
[
i
*
num_warps_per_rdma_rank
+
j
][
dst_nvl_rank
]);
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>
last_nvl_head
[
i
]
and
lane_id
<
NUM_MAX_NVL_PEERS
)
st_relaxed_sys_global
(
nvl_channel_head
.
buffer_by
(
dst_nvl_rank
)
+
i
,
last_nvl_head
[
i
]
=
min_head
);
}
}
}
// Nanosleep and let other warps work
// Nanosleep and let other warps work
...
@@ -1949,46 +1903,80 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1949,46 +1903,80 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
#endif
#endif
}
}
void
combine
(
hipDataType
type
,
void
*
combined_x
,
float
*
combined_topk_weights
,
const
bool
*
is_combined_token_in_rank
,
const
void
*
x
,
const
float
*
topk_weights
,
const
void
*
bias_0
,
const
void
*
bias_1
,
const
int
*
combined_rdma_head
,
const
int
*
combined_nvl_head
,
const
void
*
src_meta
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
int
num_tokens
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_send_tokens
,
int
num_max_rdma_chunked_recv_tokens
,
void
**
buffer_ptrs
,
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
int
rank
,
int
num_ranks
,
hipStream_t
stream
,
int
num_channels
,
bool
low_latency_mode
)
{
constexpr
int
kNumCombineForwarderWarps
=
8
;
#define COMBINE_LAUNCH_CASE(num_rdma_ranks) \
void
combine
(
hipDataType
type
,
{ \
void
*
combined_x
,
auto combine_func = \
float
*
combined_topk_weights
,
low_latency_mode \
const
bool
*
is_combined_token_in_rank
,
? combine<true, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps> \
const
void
*
x
,
: combine<false, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps>; \
const
float
*
topk_weights
,
LAUNCH_KERNEL_NON_COOPERATIVE( \
const
void
*
bias_0
,
&cfg, combine_func, reinterpret_cast<int4 *>(combined_x), combined_topk_weights, \
const
void
*
bias_1
,
is_combined_token_in_rank, reinterpret_cast<const int4 *>(x), topk_weights, \
const
int
*
combined_rdma_head
,
reinterpret_cast<const int4 *>(bias_0), reinterpret_cast<const int4 *>(bias_1), \
const
int
*
combined_nvl_head
,
combined_rdma_head, combined_nvl_head, reinterpret_cast<const SourceMeta *>(src_meta), \
const
void
*
src_meta
,
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
const
int
*
rdma_channel_prefix_matrix
,
num_tokens, num_combined_tokens, hidden, num_topk, rdma_buffer_ptr, \
const
int
*
rdma_rank_prefix_sum
,
num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs, \
const
int
*
gbl_channel_prefix_matrix
,
num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks); \
int
num_tokens
,
} \
int
num_combined_tokens
,
break
int
hidden
,
int
num_topk
,
int
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
void
*
rdma_buffer_ptr
,
auto
num_warps_per_forwarder
=
std
::
max
(
kNumCombineForwarderWarps
/
num_rdma_ranks
,
1
);
int
num_max_rdma_chunked_send_tokens
,
int
num_forwarder_warps
=
num_rdma_ranks
*
num_warps_per_forwarder
;
int
num_max_rdma_chunked_recv_tokens
,
EP_HOST_ASSERT
(
num_forwarder_warps
>=
NUM_MAX_NVL_PEERS
);
void
**
buffer_ptrs
,
EP_HOST_ASSERT
(
num_forwarder_warps
>
0
and
num_forwarder_warps
%
num_rdma_ranks
==
0
);
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
int
rank
,
int
num_ranks
,
hipStream_t
stream
,
int
num_channels
,
bool
low_latency_mode
)
{
const
int
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
EP_HOST_ASSERT
(
num_rdma_ranks
>
0
);
EP_HOST_ASSERT
(
num_max_nvl_chunked_recv_tokens
%
num_rdma_ranks
==
0
);
EP_HOST_ASSERT
(
num_max_nvl_chunked_recv_tokens
%
num_rdma_ranks
==
0
);
EP_HOST_ASSERT
(
num_max_nvl_chunked_recv_tokens
/
num_rdma_ranks
>
std
::
max
(
num_max_rdma_chunked_send_tokens
,
num_max_nvl_chunked_send_tokens
));
EP_HOST_ASSERT
(
num_max_nvl_chunked_recv_tokens
/
num_rdma_ranks
>
std
::
max
(
num_max_rdma_chunked_send_tokens
,
num_max_nvl_chunked_send_tokens
));
EP_HOST_ASSERT
(
type
==
HIP_R_16BF
);
EP_HOST_ASSERT
(
type
==
HIP_R_16BF
);
SETUP_LAUNCH_CONFIG
(
num_channels
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
,
(
NUM_MAX_NVL_PEERS
+
1
)
*
kWarpSize
,
stream
);
// One case per compile-time NR specialization.
#define COMBINE_LAUNCH_CASE(NR) { \
/* Per-case compile-time constants */
\
constexpr int kNumCombineForwarderWarps = (NR < 9) ? 10 : 16; \
constexpr int kWarpsPerForwarder = (kNumCombineForwarderWarps/NR) > 0 \
? (kNumCombineForwarderWarps/NR) : 1; \
constexpr int kNumForwarders = NR * kWarpsPerForwarder; \
constexpr int kBlockThreads = (NR > 8) \
? ((NUM_MAX_NVL_PEERS + kNumForwarders) * kEmulatedWarpSize + kWarpSize) \
: ((NUM_MAX_NVL_PEERS/2 + 1 + kNumForwarders) * kWarpSize); \
\
SETUP_LAUNCH_CONFIG(num_channels * 2, kBlockThreads, stream); \
\
using scalar_t = hip_bfloat16; \
auto fn = low_latency_mode \
? combine<true, NR, scalar_t, kNumCombineForwarderWarps> \
: combine<false, NR, scalar_t, kNumCombineForwarderWarps>; \
\
/* Launch (backend-specific) */
\
\
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, fn, \
reinterpret_cast<int4*>(combined_x), combined_topk_weights, is_combined_token_in_rank, \
reinterpret_cast<const int4*>(x), topk_weights, \
reinterpret_cast<const int4 *>(bias_0), reinterpret_cast<const int4 *>(bias_1), \
combined_rdma_head, combined_nvl_head, \
reinterpret_cast<const SourceMeta*>(src_meta), rdma_channel_prefix_matrix, \
rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
num_tokens, num_combined_tokens, hidden, num_topk, \
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, \
num_max_rdma_chunked_recv_tokens, \
buffer_ptrs, num_max_nvl_chunked_send_tokens, \
num_max_nvl_chunked_recv_tokens, \
rank, num_ranks); \
} break
// Dispatch on the runtime num_rdma_ranks, but each case is compile-time specialized.
SWITCH_RDMA_RANKS
(
COMBINE_LAUNCH_CASE
);
SWITCH_RDMA_RANKS
(
COMBINE_LAUNCH_CASE
);
#undef COMBINE_LAUNCH_CASE
#undef COMBINE_LAUNCH_CASE
}
}
...
@@ -1997,8 +1985,4 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
...
@@ -1997,8 +1985,4 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
}
// namespace deep_ep
}
// namespace deep_ep
// #ifdef __clang__
// #pragma clang diagnostic pop
// #endif // __clang__
#endif
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment