Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
2d0cf41d
"ssh:/git@developer.sourcefind.cn:2222/tsoc/openmm.git" did not exist on "30fa6ecb8106793ce9464f2c0ab0264c1ea6805e"
Commit
2d0cf41d
authored
Mar 14, 2025
by
Shangyan Zhou
Browse files
Low latency kernels use rdma atomic to support AR.
parent
7128ba3e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
70 additions
and
102 deletions
+70
-102
csrc/kernels/ibgda_device.cuh
csrc/kernels/ibgda_device.cuh
+64
-56
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+5
-21
csrc/kernels/runtime.cu
csrc/kernels/runtime.cu
+1
-25
No files found.
csrc/kernels/ibgda_device.cuh
View file @
2d0cf41d
...
@@ -62,6 +62,12 @@ uint16_t HtoBE16(uint16_t x) {
...
@@ -62,6 +62,12 @@ uint16_t HtoBE16(uint16_t x) {
typedef
struct
mlx5_wqe_ctrl_seg
__attribute__
((
__aligned__
(
8
)))
ibgda_ctrl_seg_t
;
typedef
struct
mlx5_wqe_ctrl_seg
__attribute__
((
__aligned__
(
8
)))
ibgda_ctrl_seg_t
;
typedef
struct
{
uint32_t
add_data
;
uint32_t
field_boundary
;
uint64_t
reserved
;
}
__attribute__
((
__packed__
))
ibgda_atomic_32_masked_fa_seg_t
;
__device__
static
__forceinline__
__device__
static
__forceinline__
nvshmemi_ibgda_device_state_t
*
ibgda_get_state
()
{
nvshmemi_ibgda_device_state_t
*
ibgda_get_state
()
{
return
&
nvshmemi_ibgda_device_state_d
;
return
&
nvshmemi_ibgda_device_state_d
;
...
@@ -249,23 +255,6 @@ ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) {
...
@@ -249,23 +255,6 @@ ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) {
return
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
uintptr_t
>
(
qp
->
tx_wq
.
wqe
)
+
(
idx
<<
MLX5_SEND_WQE_SHIFT
));
return
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
uintptr_t
>
(
qp
->
tx_wq
.
wqe
)
+
(
idx
<<
MLX5_SEND_WQE_SHIFT
));
}
}
// Wait until wqe `idx - 1` is completed.
// This is a simplified version of NVSHMEM's `ibgda_poll_cq`. It can only be used for polling recv.
// Because we post recv and poll recv in the same thread, so we don't need to maintain queue status.
__device__
static
__forceinline__
void
nvshmemi_ibgda_poll_recv
(
int
dst_pe
,
int
qp_id
)
{
auto
qp
=
ibgda_get_rc
(
dst_pe
,
qp_id
);
auto
cq
=
qp
->
rx_wq
.
cq
;
const
uint32_t
ncqes
=
cq
->
ncqes
;
auto
*
cqe64
=
reinterpret_cast
<
struct
mlx5_cqe64
*>
(
cq
->
cqe
);
auto
old_cons_idx
=
*
cq
->
cons_idx
;
*
cq
->
cons_idx
=
old_cons_idx
+
1
;
// Wait until `wqe_counter >= old_cons_idx`
while
((
static_cast
<
uint16_t
>
(
old_cons_idx
-
HtoBE16
(
ld_na_relaxed
(
&
cqe64
->
wqe_counter
))
-
1
)
<
ncqes
));
}
__device__
static
__forceinline__
void
__device__
static
__forceinline__
void
nvshmemi_ibgda_rma_p
(
int
*
rptr
,
const
int
value
,
int
dst_pe
,
int
qp_id
,
uint32_t
imm
=
std
::
numeric_limits
<
uint32_t
>::
max
())
{
nvshmemi_ibgda_rma_p
(
int
*
rptr
,
const
int
value
,
int
dst_pe
,
int
qp_id
,
uint32_t
imm
=
std
::
numeric_limits
<
uint32_t
>::
max
())
{
// Get rkey
// Get rkey
...
@@ -336,45 +325,6 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
...
@@ -336,45 +325,6 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
data_seg_ptr
),
*
reinterpret_cast
<
const
int4
*>
(
&
data_seg
));
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
data_seg_ptr
),
*
reinterpret_cast
<
const
int4
*>
(
&
data_seg
));
}
}
__device__
static
__forceinline__
uint64_t
nvshmemi_ibgda_allocate_recvs
(
nvshmemi_ibgda_device_qp
*
qp
)
{
auto
mvars
=
&
qp
->
mvars
;
// Allocate if not enough
constexpr
int
kMinIBGDARecvs
=
32
;
auto
resv_head
=
mvars
->
rx_wq
.
resv_head
;
auto
num_valid_slots
=
resv_head
-
mvars
->
rx_wq
.
cons_idx
;
if
(
num_valid_slots
<
kMinIBGDARecvs
)
{
resv_head
=
mvars
->
rx_wq
.
cons_idx
+
qp
->
rx_wq
.
nwqes
;
mvars
->
rx_wq
.
resv_head
=
resv_head
;
// Ensure WQE is written before `dbrec`
__be32
dbrec_val
;
__be32
*
dbrec_ptr
=
qp
->
rx_wq
.
dbrec
;
// Compared to sending, for each QP, we only post recv in a single thread,
// so we don't need to do synchronization here
// This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(wqe_idx & 0xffff))`
asm
(
"{
\n\t
"
".reg .b32 dbrec_head_16b;
\n\t
"
".reg .b32 ign;
\n\t
"
"and.b32 dbrec_head_16b, %1, 0xffff;
\n\t
"
"prmt.b32 %0, dbrec_head_16b, ign, 0x123;
\n\t
"
"}"
:
"=r"
(
dbrec_val
)
:
"r"
(
static_cast
<
uint32_t
>
(
resv_head
)));
st_na_release
(
dbrec_ptr
,
dbrec_val
);
}
// Return old number of slots
return
num_valid_slots
;
}
__device__
static
__forceinline__
void
nvshmemi_ibgda_prepare_recvs
(
int
dst_rank
,
int
qp_id
)
{
// NOTES: only one thread can run this function
EP_DEVICE_ASSERT
(
nvshmemi_ibgda_allocate_recvs
(
ibgda_get_rc
(
dst_rank
,
qp_id
))
>
16
);
}
__device__
static
__forceinline__
void
__device__
static
__forceinline__
void
nvshmemi_ibgda_put_nbi_warp
(
uint64_t
req_rptr
,
uint64_t
req_lptr
,
size_t
bytes
,
int
dst_pe
,
int
qp_id
,
int
lane_id
,
int
message_idx
)
{
nvshmemi_ibgda_put_nbi_warp
(
uint64_t
req_rptr
,
uint64_t
req_lptr
,
size_t
bytes
,
int
dst_pe
,
int
qp_id
,
int
lane_id
,
int
message_idx
)
{
// Get lkey and rkey, store them into lanes
// Get lkey and rkey, store them into lanes
...
@@ -419,4 +369,62 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
...
@@ -419,4 +369,62 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
__syncwarp
();
__syncwarp
();
}
}
__device__
static
__forceinline__
void
ibgda_write_amo_add_wqe
(
nvshmemi_ibgda_device_qp_t
*
qp
,
const
int
&
value
,
uint64_t
laddr
,
__be32
lkey
,
uint64_t
raddr
,
__be32
rkey
,
uint16_t
wqe_idx
,
void
**
out_wqes
)
{
ibgda_ctrl_seg_t
ctrl_seg
=
{
0
};
struct
mlx5_wqe_raddr_seg
raddr_seg
;
struct
mlx5_wqe_atomic_seg
atomic_seg_1
;
struct
mlx5_wqe_data_seg
data_seg
;
auto
ctrl_seg_ptr
=
reinterpret_cast
<
ibgda_ctrl_seg_t
*>
(
out_wqes
[
0
]);
auto
raddr_seg_ptr
=
reinterpret_cast
<
mlx5_wqe_raddr_seg
*>
(
reinterpret_cast
<
uintptr_t
>
(
ctrl_seg_ptr
)
+
sizeof
(
*
ctrl_seg_ptr
));
auto
atomic_seg_ptr
=
reinterpret_cast
<
mlx5_wqe_atomic_seg
*>
(
reinterpret_cast
<
uintptr_t
>
(
raddr_seg_ptr
)
+
sizeof
(
*
raddr_seg_ptr
));
auto
data_seg_ptr
=
reinterpret_cast
<
mlx5_wqe_data_seg
*>
(
reinterpret_cast
<
uintptr_t
>
(
atomic_seg_ptr
)
+
sizeof
(
*
atomic_seg_ptr
));
raddr_seg
.
raddr
=
HtoBE64
(
raddr
);
raddr_seg
.
rkey
=
rkey
;
raddr_seg
.
reserved
=
0
;
// NOTES: `0x08000000` means `IBGDA_4_BYTE_EXT_AMO_OPMOD`
ctrl_seg
.
opmod_idx_opcode
=
HtoBE32
(
MLX5_OPCODE_ATOMIC_MASKED_FA
|
(
wqe_idx
<<
8
)
|
0x08000000
);
auto
atomic_32_masked_fa_seg
=
reinterpret_cast
<
ibgda_atomic_32_masked_fa_seg_t
*>
(
&
atomic_seg_1
);
atomic_32_masked_fa_seg
->
add_data
=
HtoBE32
(
value
);
atomic_32_masked_fa_seg
->
field_boundary
=
0
;
ctrl_seg
.
qpn_ds
=
HtoBE32
((
qp
->
qpn
<<
8
)
|
4
);
ctrl_seg
.
fm_ce_se
=
MLX5_WQE_CTRL_CQ_UPDATE
;
data_seg
.
byte_count
=
HtoBE32
(
sizeof
(
int
));
data_seg
.
lkey
=
lkey
;
data_seg
.
addr
=
HtoBE64
(
laddr
);
EP_STATIC_ASSERT
(
sizeof
(
*
ctrl_seg_ptr
)
==
sizeof
(
int4
),
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
sizeof
(
*
raddr_seg_ptr
)
==
sizeof
(
int4
),
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
sizeof
(
*
atomic_seg_ptr
)
==
sizeof
(
int4
),
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
sizeof
(
*
data_seg_ptr
)
==
sizeof
(
int4
),
"Invalid vectorization"
);
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
ctrl_seg_ptr
),
*
reinterpret_cast
<
int4
*>
(
&
ctrl_seg
));
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
raddr_seg_ptr
),
*
reinterpret_cast
<
int4
*>
(
&
raddr_seg
));
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
atomic_seg_ptr
),
*
reinterpret_cast
<
int4
*>
(
&
atomic_seg_1
));
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
data_seg_ptr
),
*
reinterpret_cast
<
int4
*>
(
&
data_seg
));
}
__device__
__forceinline__
void
nvshmemi_ibgda_amo_nonfetch_add
(
void
*
rptr
,
const
int
&
value
,
int
pe
,
int
qp_id
)
{
nvshmemi_ibgda_device_qp_t
*
qp
=
ibgda_get_rc
(
pe
,
qp_id
);
__be32
rkey
;
uint64_t
raddr
;
ibgda_get_rkey
(
reinterpret_cast
<
uint64_t
>
(
rptr
),
pe
,
&
raddr
,
&
rkey
);
void
*
wqe_ptrs
[
1
];
uint64_t
my_wqe_idx
=
ibgda_reserve_wqe_slots
(
qp
,
1
);
wqe_ptrs
[
0
]
=
ibgda_get_wqe_ptr
(
qp
,
my_wqe_idx
);
ibgda_write_amo_add_wqe
(
qp
,
value
,
reinterpret_cast
<
uint64_t
>
(
qp
->
ibuf
.
buf
),
qp
->
ibuf
.
lkey
,
raddr
,
rkey
,
my_wqe_idx
,
wqe_ptrs
);
ibgda_submit_requests
<
true
>
(
qp
,
my_wqe_idx
,
1
);
}
}
// namespace deep_ep
}
// namespace deep_ep
csrc/kernels/internode_ll.cu
View file @
2d0cf41d
...
@@ -215,9 +215,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -215,9 +215,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Wait local sends issued and send expert counts
// Wait local sends issued and send expert counts
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
nvshmemi_ibgda_rma_p
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
,
dst_expert_local_idx
);
dst_rank
,
dst_expert_local_idx
,
0
);
nvshmemi_ibgda_prepare_recvs
(
dst_rank
,
dst_expert_local_idx
);
}
else
{
}
else
{
st_na_release
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
);
st_na_release
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
);
}
}
...
@@ -262,13 +260,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -262,13 +260,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int
num_recv_tokens
,
recv_token_begin_idx
;
int
num_recv_tokens
,
recv_token_begin_idx
;
EP_STATIC_ASSERT
(
kNumWarpsPerGroup
>
1
,
"Requires more than one warp per group"
);
EP_STATIC_ASSERT
(
kNumWarpsPerGroup
>
1
,
"Requires more than one warp per group"
);
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
if
(
src_rank
!=
rank
)
{
nvshmemi_ibgda_poll_recv
(
src_rank
,
local_expert_idx
);
num_recv_tokens
=
ld_acquire_sys_global
(
rdma_recv_count
+
local_expert_idx
*
num_ranks
+
src_rank
);
EP_DEVICE_ASSERT
(
num_recv_tokens
!=
0
);
}
else
{
while
((
num_recv_tokens
=
ld_acquire_global
(
rdma_recv_count
+
local_expert_idx
*
num_ranks
+
src_rank
))
==
0
);
while
((
num_recv_tokens
=
ld_acquire_global
(
rdma_recv_count
+
local_expert_idx
*
num_ranks
+
src_rank
))
==
0
);
}
num_recv_tokens
=
-
num_recv_tokens
-
1
;
num_recv_tokens
=
-
num_recv_tokens
-
1
;
recv_token_begin_idx
=
atomicAdd
(
packed_recv_count
+
local_expert_idx
,
num_recv_tokens
);
recv_token_begin_idx
=
atomicAdd
(
packed_recv_count
+
local_expert_idx
,
num_recv_tokens
);
shared_num_recv_tokens
[
warp_group_id
]
=
num_recv_tokens
;
shared_num_recv_tokens
[
warp_group_id
]
=
num_recv_tokens
;
...
@@ -439,7 +431,7 @@ combine(void* combined_x,
...
@@ -439,7 +431,7 @@ combine(void* combined_x,
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
nvshmemi_ibgda_
rma_p
(
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
,
local_expert_idx
,
0
);
nvshmemi_ibgda_
amo_nonfetch_add
(
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
,
local_expert_idx
);
}
else
{
}
else
{
st_na_release
(
rdma_recv_flag
+
global_expert_idx
,
1
);
st_na_release
(
rdma_recv_flag
+
global_expert_idx
,
1
);
}
}
...
@@ -456,17 +448,9 @@ combine(void* combined_x,
...
@@ -456,17 +448,9 @@ combine(void* combined_x,
// Wait all ranks to arrive and notify PCIe usage
// Wait all ranks to arrive and notify PCIe usage
if
(
responsible_expert_idx
<
num_experts
)
{
if
(
responsible_expert_idx
<
num_experts
)
{
EP_STATIC_ASSERT
(
kNumWarpsPerGroup
>
1
,
"Invalid number of warps per group"
);
EP_STATIC_ASSERT
(
kNumWarpsPerGroup
>
1
,
"Invalid number of warps per group"
);
if
(
sub_warp_id
==
0
and
lane_id
==
0
)
{
if
(
sub_warp_id
==
0
and
lane_id
==
0
)
// TODO: refactor QP indices
auto
src_rank
=
responsible_expert_idx
/
num_local_experts
;
auto
src_expert_idx
=
responsible_expert_idx
%
num_local_experts
;
if
(
src_rank
!=
rank
)
{
nvshmemi_ibgda_poll_recv
(
src_rank
,
src_expert_idx
);
}
else
{
while
(
ld_acquire_global
(
rdma_recv_flag
+
responsible_expert_idx
)
==
0
);
while
(
ld_acquire_global
(
rdma_recv_flag
+
responsible_expert_idx
)
==
0
);
}
}
}
}
cg
::
this_grid
().
sync
();
cg
::
this_grid
().
sync
();
// Reduce tokens with FP8 cast
// Reduce tokens with FP8 cast
...
...
csrc/kernels/runtime.cu
View file @
2d0cf41d
...
@@ -41,27 +41,6 @@ std::vector<uint8_t> get_unique_id() {
...
@@ -41,27 +41,6 @@ std::vector<uint8_t> get_unique_id() {
return
result
;
return
result
;
}
}
__global__
void
ibgda_initialize_recv_queue
(
int
rank
)
{
auto
thread_idx
=
static_cast
<
int
>
(
threadIdx
.
x
);
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
);
auto
dst_rank
=
static_cast
<
int
>
(
blockIdx
.
x
);
if
(
dst_rank
!=
rank
)
{
for
(
int
qp_id
=
thread_idx
;
qp_id
<
ibgda_get_state
()
->
num_rc_per_pe
;
qp_id
+=
num_threads
)
{
auto
qp
=
ibgda_get_rc
(
dst_rank
,
qp_id
);
// Clean some necessary variables
for
(
int
i
=
0
;
i
<
qp
->
rx_wq
.
nwqes
;
++
i
)
ibgda_write_empty_recv_wqe
(
ibgda_get_wqe_ptr
(
qp
,
i
));
qp
->
mvars
.
rx_wq
.
resv_head
=
0
;
qp
->
mvars
.
rx_wq
.
cons_idx
=
0
;
// Allocate receive slots
nvshmemi_ibgda_allocate_recvs
(
qp
);
}
}
}
int
init
(
const
std
::
vector
<
uint8_t
>
&
root_unique_id_val
,
int
rank
,
int
num_ranks
,
bool
low_latency_mode
)
{
int
init
(
const
std
::
vector
<
uint8_t
>
&
root_unique_id_val
,
int
rank
,
int
num_ranks
,
bool
low_latency_mode
)
{
nvshmemx_uniqueid_t
root_unique_id
;
nvshmemx_uniqueid_t
root_unique_id
;
nvshmemx_init_attr_t
attr
;
nvshmemx_init_attr_t
attr
;
...
@@ -85,10 +64,7 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
...
@@ -85,10 +64,7 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
CUDA_CHECK
(
cudaGetSymbolAddress
(
reinterpret_cast
<
void
**>
(
&
dev_state_ptr
),
nvshmemi_device_state_d
));
CUDA_CHECK
(
cudaGetSymbolAddress
(
reinterpret_cast
<
void
**>
(
&
dev_state_ptr
),
nvshmemi_device_state_d
));
bool
ibgda_is_initialized
=
false
;
bool
ibgda_is_initialized
=
false
;
cudaMemcpy
(
&
dev_state_ptr
->
ibgda_is_initialized
,
&
ibgda_is_initialized
,
sizeof
(
bool
),
cudaMemcpyHostToDevice
);
CUDA_CHECK
(
cudaMemcpy
(
&
dev_state_ptr
->
ibgda_is_initialized
,
&
ibgda_is_initialized
,
sizeof
(
bool
),
cudaMemcpyHostToDevice
));
// Initialize recv queues for low-latency mode AR
ibgda_initialize_recv_queue
<<<
num_ranks
,
128
>>>
(
rank
);
}
}
nvshmem_barrier_all
();
nvshmem_barrier_all
();
return
nvshmem_my_pe
();
return
nvshmem_my_pe
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment