Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
043fa5fa
Unverified
Commit
043fa5fa
authored
Mar 14, 2025
by
Chenggang Zhao
Committed by
GitHub
Mar 14, 2025
Browse files
Merge pull request #73 from deepseek-ai/p2p-signal
Low latency kernels use rdma atomic to support AR
parents
7128ba3e
38cdaf39
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
69 additions
and
102 deletions
+69
-102
csrc/kernels/ibgda_device.cuh
csrc/kernels/ibgda_device.cuh
+63
-56
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+5
-21
csrc/kernels/runtime.cu
csrc/kernels/runtime.cu
+1
-25
No files found.
csrc/kernels/ibgda_device.cuh
View file @
043fa5fa
...
@@ -62,6 +62,12 @@ uint16_t HtoBE16(uint16_t x) {
...
@@ -62,6 +62,12 @@ uint16_t HtoBE16(uint16_t x) {
typedef
struct
mlx5_wqe_ctrl_seg
__attribute__
((
__aligned__
(
8
)))
ibgda_ctrl_seg_t
;
typedef
struct
mlx5_wqe_ctrl_seg
__attribute__
((
__aligned__
(
8
)))
ibgda_ctrl_seg_t
;
typedef
struct
{
uint32_t
add_data
;
uint32_t
field_boundary
;
uint64_t
reserved
;
}
__attribute__
((
__packed__
))
ibgda_atomic_32_masked_fa_seg_t
;
__device__
static
__forceinline__
__device__
static
__forceinline__
nvshmemi_ibgda_device_state_t
*
ibgda_get_state
()
{
nvshmemi_ibgda_device_state_t
*
ibgda_get_state
()
{
return
&
nvshmemi_ibgda_device_state_d
;
return
&
nvshmemi_ibgda_device_state_d
;
...
@@ -249,23 +255,6 @@ ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) {
...
@@ -249,23 +255,6 @@ ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) {
return
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
uintptr_t
>
(
qp
->
tx_wq
.
wqe
)
+
(
idx
<<
MLX5_SEND_WQE_SHIFT
));
return
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
uintptr_t
>
(
qp
->
tx_wq
.
wqe
)
+
(
idx
<<
MLX5_SEND_WQE_SHIFT
));
}
}
// Wait until wqe `idx - 1` is completed.
// This is a simplified version of NVSHMEM's `ibgda_poll_cq`. It can only be used for polling recv.
// Because we post recv and poll recv in the same thread, so we don't need to maintain queue status.
__device__
static
__forceinline__
void
nvshmemi_ibgda_poll_recv
(
int
dst_pe
,
int
qp_id
)
{
auto
qp
=
ibgda_get_rc
(
dst_pe
,
qp_id
);
auto
cq
=
qp
->
rx_wq
.
cq
;
const
uint32_t
ncqes
=
cq
->
ncqes
;
auto
*
cqe64
=
reinterpret_cast
<
struct
mlx5_cqe64
*>
(
cq
->
cqe
);
auto
old_cons_idx
=
*
cq
->
cons_idx
;
*
cq
->
cons_idx
=
old_cons_idx
+
1
;
// Wait until `wqe_counter >= old_cons_idx`
while
((
static_cast
<
uint16_t
>
(
old_cons_idx
-
HtoBE16
(
ld_na_relaxed
(
&
cqe64
->
wqe_counter
))
-
1
)
<
ncqes
));
}
__device__
static
__forceinline__
void
__device__
static
__forceinline__
void
nvshmemi_ibgda_rma_p
(
int
*
rptr
,
const
int
value
,
int
dst_pe
,
int
qp_id
,
uint32_t
imm
=
std
::
numeric_limits
<
uint32_t
>::
max
())
{
nvshmemi_ibgda_rma_p
(
int
*
rptr
,
const
int
value
,
int
dst_pe
,
int
qp_id
,
uint32_t
imm
=
std
::
numeric_limits
<
uint32_t
>::
max
())
{
// Get rkey
// Get rkey
...
@@ -336,45 +325,6 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
...
@@ -336,45 +325,6 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
data_seg_ptr
),
*
reinterpret_cast
<
const
int4
*>
(
&
data_seg
));
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
data_seg_ptr
),
*
reinterpret_cast
<
const
int4
*>
(
&
data_seg
));
}
}
__device__
static
__forceinline__
uint64_t
nvshmemi_ibgda_allocate_recvs
(
nvshmemi_ibgda_device_qp
*
qp
)
{
auto
mvars
=
&
qp
->
mvars
;
// Allocate if not enough
constexpr
int
kMinIBGDARecvs
=
32
;
auto
resv_head
=
mvars
->
rx_wq
.
resv_head
;
auto
num_valid_slots
=
resv_head
-
mvars
->
rx_wq
.
cons_idx
;
if
(
num_valid_slots
<
kMinIBGDARecvs
)
{
resv_head
=
mvars
->
rx_wq
.
cons_idx
+
qp
->
rx_wq
.
nwqes
;
mvars
->
rx_wq
.
resv_head
=
resv_head
;
// Ensure WQE is written before `dbrec`
__be32
dbrec_val
;
__be32
*
dbrec_ptr
=
qp
->
rx_wq
.
dbrec
;
// Compared to sending, for each QP, we only post recv in a single thread,
// so we don't need to do synchronization here
// This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(wqe_idx & 0xffff))`
asm
(
"{
\n\t
"
".reg .b32 dbrec_head_16b;
\n\t
"
".reg .b32 ign;
\n\t
"
"and.b32 dbrec_head_16b, %1, 0xffff;
\n\t
"
"prmt.b32 %0, dbrec_head_16b, ign, 0x123;
\n\t
"
"}"
:
"=r"
(
dbrec_val
)
:
"r"
(
static_cast
<
uint32_t
>
(
resv_head
)));
st_na_release
(
dbrec_ptr
,
dbrec_val
);
}
// Return old number of slots
return
num_valid_slots
;
}
__device__
static
__forceinline__
void
nvshmemi_ibgda_prepare_recvs
(
int
dst_rank
,
int
qp_id
)
{
// NOTES: only one thread can run this function
EP_DEVICE_ASSERT
(
nvshmemi_ibgda_allocate_recvs
(
ibgda_get_rc
(
dst_rank
,
qp_id
))
>
16
);
}
__device__
static
__forceinline__
void
__device__
static
__forceinline__
void
nvshmemi_ibgda_put_nbi_warp
(
uint64_t
req_rptr
,
uint64_t
req_lptr
,
size_t
bytes
,
int
dst_pe
,
int
qp_id
,
int
lane_id
,
int
message_idx
)
{
nvshmemi_ibgda_put_nbi_warp
(
uint64_t
req_rptr
,
uint64_t
req_lptr
,
size_t
bytes
,
int
dst_pe
,
int
qp_id
,
int
lane_id
,
int
message_idx
)
{
// Get lkey and rkey, store them into lanes
// Get lkey and rkey, store them into lanes
...
@@ -419,4 +369,61 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
...
@@ -419,4 +369,61 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
__syncwarp
();
__syncwarp
();
}
}
__device__
static
__forceinline__
void
ibgda_write_amo_add_wqe
(
nvshmemi_ibgda_device_qp_t
*
qp
,
const
int
&
value
,
uint64_t
laddr
,
__be32
lkey
,
uint64_t
raddr
,
__be32
rkey
,
uint16_t
wqe_idx
,
void
**
out_wqes
)
{
ibgda_ctrl_seg_t
ctrl_seg
=
{
0
};
struct
mlx5_wqe_raddr_seg
raddr_seg
;
struct
mlx5_wqe_atomic_seg
atomic_seg_1
;
struct
mlx5_wqe_data_seg
data_seg
;
auto
ctrl_seg_ptr
=
reinterpret_cast
<
ibgda_ctrl_seg_t
*>
(
out_wqes
[
0
]);
auto
raddr_seg_ptr
=
reinterpret_cast
<
mlx5_wqe_raddr_seg
*>
(
reinterpret_cast
<
uintptr_t
>
(
ctrl_seg_ptr
)
+
sizeof
(
*
ctrl_seg_ptr
));
auto
atomic_seg_ptr
=
reinterpret_cast
<
mlx5_wqe_atomic_seg
*>
(
reinterpret_cast
<
uintptr_t
>
(
raddr_seg_ptr
)
+
sizeof
(
*
raddr_seg_ptr
));
auto
data_seg_ptr
=
reinterpret_cast
<
mlx5_wqe_data_seg
*>
(
reinterpret_cast
<
uintptr_t
>
(
atomic_seg_ptr
)
+
sizeof
(
*
atomic_seg_ptr
));
raddr_seg
.
raddr
=
HtoBE64
(
raddr
);
raddr_seg
.
rkey
=
rkey
;
raddr_seg
.
reserved
=
0
;
// NOTES: `0x08000000` means `IBGDA_4_BYTE_EXT_AMO_OPMOD`
ctrl_seg
.
opmod_idx_opcode
=
HtoBE32
(
MLX5_OPCODE_ATOMIC_MASKED_FA
|
(
wqe_idx
<<
8
)
|
0x08000000
);
auto
atomic_32_masked_fa_seg
=
reinterpret_cast
<
ibgda_atomic_32_masked_fa_seg_t
*>
(
&
atomic_seg_1
);
atomic_32_masked_fa_seg
->
add_data
=
HtoBE32
(
value
);
atomic_32_masked_fa_seg
->
field_boundary
=
0
;
ctrl_seg
.
qpn_ds
=
HtoBE32
((
qp
->
qpn
<<
8
)
|
4
);
ctrl_seg
.
fm_ce_se
=
MLX5_WQE_CTRL_CQ_UPDATE
;
data_seg
.
byte_count
=
HtoBE32
(
sizeof
(
int
));
data_seg
.
lkey
=
lkey
;
data_seg
.
addr
=
HtoBE64
(
laddr
);
EP_STATIC_ASSERT
(
sizeof
(
*
ctrl_seg_ptr
)
==
sizeof
(
int4
),
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
sizeof
(
*
raddr_seg_ptr
)
==
sizeof
(
int4
),
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
sizeof
(
*
atomic_seg_ptr
)
==
sizeof
(
int4
),
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
sizeof
(
*
data_seg_ptr
)
==
sizeof
(
int4
),
"Invalid vectorization"
);
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
ctrl_seg_ptr
),
*
reinterpret_cast
<
int4
*>
(
&
ctrl_seg
));
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
raddr_seg_ptr
),
*
reinterpret_cast
<
int4
*>
(
&
raddr_seg
));
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
atomic_seg_ptr
),
*
reinterpret_cast
<
int4
*>
(
&
atomic_seg_1
));
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
data_seg_ptr
),
*
reinterpret_cast
<
int4
*>
(
&
data_seg
));
}
__device__
__forceinline__
void
nvshmemi_ibgda_amo_nonfetch_add
(
void
*
rptr
,
const
int
&
value
,
int
pe
,
int
qp_id
)
{
nvshmemi_ibgda_device_qp_t
*
qp
=
ibgda_get_rc
(
pe
,
qp_id
);
__be32
rkey
;
uint64_t
raddr
;
ibgda_get_rkey
(
reinterpret_cast
<
uint64_t
>
(
rptr
),
pe
,
&
raddr
,
&
rkey
);
uint64_t
my_wqe_idx
=
ibgda_reserve_wqe_slots
(
qp
,
1
);
void
*
wqe_ptrs
=
ibgda_get_wqe_ptr
(
qp
,
my_wqe_idx
);
ibgda_write_amo_add_wqe
(
qp
,
value
,
reinterpret_cast
<
uint64_t
>
(
qp
->
ibuf
.
buf
),
qp
->
ibuf
.
lkey
,
raddr
,
rkey
,
my_wqe_idx
,
&
wqe_ptrs
);
ibgda_submit_requests
<
true
>
(
qp
,
my_wqe_idx
,
1
);
}
}
// namespace deep_ep
}
// namespace deep_ep
csrc/kernels/internode_ll.cu
View file @
043fa5fa
...
@@ -215,9 +215,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -215,9 +215,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Wait local sends issued and send expert counts
// Wait local sends issued and send expert counts
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
nvshmemi_ibgda_rma_p
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
,
dst_expert_local_idx
);
dst_rank
,
dst_expert_local_idx
,
0
);
nvshmemi_ibgda_prepare_recvs
(
dst_rank
,
dst_expert_local_idx
);
}
else
{
}
else
{
st_na_release
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
);
st_na_release
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
);
}
}
...
@@ -262,13 +260,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -262,13 +260,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int
num_recv_tokens
,
recv_token_begin_idx
;
int
num_recv_tokens
,
recv_token_begin_idx
;
EP_STATIC_ASSERT
(
kNumWarpsPerGroup
>
1
,
"Requires more than one warp per group"
);
EP_STATIC_ASSERT
(
kNumWarpsPerGroup
>
1
,
"Requires more than one warp per group"
);
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
if
(
src_rank
!=
rank
)
{
while
((
num_recv_tokens
=
ld_acquire_global
(
rdma_recv_count
+
local_expert_idx
*
num_ranks
+
src_rank
))
==
0
);
nvshmemi_ibgda_poll_recv
(
src_rank
,
local_expert_idx
);
num_recv_tokens
=
ld_acquire_sys_global
(
rdma_recv_count
+
local_expert_idx
*
num_ranks
+
src_rank
);
EP_DEVICE_ASSERT
(
num_recv_tokens
!=
0
);
}
else
{
while
((
num_recv_tokens
=
ld_acquire_global
(
rdma_recv_count
+
local_expert_idx
*
num_ranks
+
src_rank
))
==
0
);
}
num_recv_tokens
=
-
num_recv_tokens
-
1
;
num_recv_tokens
=
-
num_recv_tokens
-
1
;
recv_token_begin_idx
=
atomicAdd
(
packed_recv_count
+
local_expert_idx
,
num_recv_tokens
);
recv_token_begin_idx
=
atomicAdd
(
packed_recv_count
+
local_expert_idx
,
num_recv_tokens
);
shared_num_recv_tokens
[
warp_group_id
]
=
num_recv_tokens
;
shared_num_recv_tokens
[
warp_group_id
]
=
num_recv_tokens
;
...
@@ -439,7 +431,7 @@ combine(void* combined_x,
...
@@ -439,7 +431,7 @@ combine(void* combined_x,
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
nvshmemi_ibgda_
rma_p
(
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
,
local_expert_idx
,
0
);
nvshmemi_ibgda_
amo_nonfetch_add
(
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
,
local_expert_idx
);
}
else
{
}
else
{
st_na_release
(
rdma_recv_flag
+
global_expert_idx
,
1
);
st_na_release
(
rdma_recv_flag
+
global_expert_idx
,
1
);
}
}
...
@@ -456,16 +448,8 @@ combine(void* combined_x,
...
@@ -456,16 +448,8 @@ combine(void* combined_x,
// Wait all ranks to arrive and notify PCIe usage
// Wait all ranks to arrive and notify PCIe usage
if
(
responsible_expert_idx
<
num_experts
)
{
if
(
responsible_expert_idx
<
num_experts
)
{
EP_STATIC_ASSERT
(
kNumWarpsPerGroup
>
1
,
"Invalid number of warps per group"
);
EP_STATIC_ASSERT
(
kNumWarpsPerGroup
>
1
,
"Invalid number of warps per group"
);
if
(
sub_warp_id
==
0
and
lane_id
==
0
)
{
if
(
sub_warp_id
==
0
and
lane_id
==
0
)
// TODO: refactor QP indices
while
(
ld_acquire_global
(
rdma_recv_flag
+
responsible_expert_idx
)
==
0
);
auto
src_rank
=
responsible_expert_idx
/
num_local_experts
;
auto
src_expert_idx
=
responsible_expert_idx
%
num_local_experts
;
if
(
src_rank
!=
rank
)
{
nvshmemi_ibgda_poll_recv
(
src_rank
,
src_expert_idx
);
}
else
{
while
(
ld_acquire_global
(
rdma_recv_flag
+
responsible_expert_idx
)
==
0
);
}
}
}
}
cg
::
this_grid
().
sync
();
cg
::
this_grid
().
sync
();
...
...
csrc/kernels/runtime.cu
View file @
043fa5fa
...
@@ -41,27 +41,6 @@ std::vector<uint8_t> get_unique_id() {
...
@@ -41,27 +41,6 @@ std::vector<uint8_t> get_unique_id() {
return
result
;
return
result
;
}
}
__global__
void
ibgda_initialize_recv_queue
(
int
rank
)
{
auto
thread_idx
=
static_cast
<
int
>
(
threadIdx
.
x
);
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
);
auto
dst_rank
=
static_cast
<
int
>
(
blockIdx
.
x
);
if
(
dst_rank
!=
rank
)
{
for
(
int
qp_id
=
thread_idx
;
qp_id
<
ibgda_get_state
()
->
num_rc_per_pe
;
qp_id
+=
num_threads
)
{
auto
qp
=
ibgda_get_rc
(
dst_rank
,
qp_id
);
// Clean some necessary variables
for
(
int
i
=
0
;
i
<
qp
->
rx_wq
.
nwqes
;
++
i
)
ibgda_write_empty_recv_wqe
(
ibgda_get_wqe_ptr
(
qp
,
i
));
qp
->
mvars
.
rx_wq
.
resv_head
=
0
;
qp
->
mvars
.
rx_wq
.
cons_idx
=
0
;
// Allocate receive slots
nvshmemi_ibgda_allocate_recvs
(
qp
);
}
}
}
int
init
(
const
std
::
vector
<
uint8_t
>
&
root_unique_id_val
,
int
rank
,
int
num_ranks
,
bool
low_latency_mode
)
{
int
init
(
const
std
::
vector
<
uint8_t
>
&
root_unique_id_val
,
int
rank
,
int
num_ranks
,
bool
low_latency_mode
)
{
nvshmemx_uniqueid_t
root_unique_id
;
nvshmemx_uniqueid_t
root_unique_id
;
nvshmemx_init_attr_t
attr
;
nvshmemx_init_attr_t
attr
;
...
@@ -85,10 +64,7 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
...
@@ -85,10 +64,7 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
CUDA_CHECK
(
cudaGetSymbolAddress
(
reinterpret_cast
<
void
**>
(
&
dev_state_ptr
),
nvshmemi_device_state_d
));
CUDA_CHECK
(
cudaGetSymbolAddress
(
reinterpret_cast
<
void
**>
(
&
dev_state_ptr
),
nvshmemi_device_state_d
));
bool
ibgda_is_initialized
=
false
;
bool
ibgda_is_initialized
=
false
;
cudaMemcpy
(
&
dev_state_ptr
->
ibgda_is_initialized
,
&
ibgda_is_initialized
,
sizeof
(
bool
),
cudaMemcpyHostToDevice
);
CUDA_CHECK
(
cudaMemcpy
(
&
dev_state_ptr
->
ibgda_is_initialized
,
&
ibgda_is_initialized
,
sizeof
(
bool
),
cudaMemcpyHostToDevice
));
// Initialize recv queues for low-latency mode AR
ibgda_initialize_recv_queue
<<<
num_ranks
,
128
>>>
(
rank
);
}
}
nvshmem_barrier_all
();
nvshmem_barrier_all
();
return
nvshmem_my_pe
();
return
nvshmem_my_pe
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment