Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
9fe9021f
Unverified
Commit
9fe9021f
authored
May 28, 2025
by
Shangyan Zhou
Committed by
GitHub
May 28, 2025
Browse files
Use IBGDA only (#177)
parent
aae9fa9a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
77 additions
and
20 deletions
+77
-20
csrc/kernels/ibgda_device.cuh
csrc/kernels/ibgda_device.cuh
+48
-2
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+29
-12
csrc/kernels/runtime.cu
csrc/kernels/runtime.cu
+0
-6
No files found.
csrc/kernels/ibgda_device.cuh
View file @
9fe9021f
...
@@ -413,8 +413,7 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
...
@@ -413,8 +413,7 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
__device__
__forceinline__
void
nvshmemi_ibgda_amo_nonfetch_add
(
void
*
rptr
,
const
int
&
value
,
int
pe
,
int
qp_id
,
bool
is_local_copy
=
false
)
{
__device__
__forceinline__
void
nvshmemi_ibgda_amo_nonfetch_add
(
void
*
rptr
,
const
int
&
value
,
int
pe
,
int
qp_id
,
bool
is_local_copy
=
false
)
{
if
(
is_local_copy
)
{
if
(
is_local_copy
)
{
// Fallback to NVSHMEM legacy API
atomicAdd
(
static_cast
<
unsigned
long
long
*>
(
rptr
),
value
);
nvshmemx_signal_op
(
static_cast
<
uint64_t
*>
(
rptr
),
value
,
NVSHMEM_SIGNAL_ADD
,
pe
);
}
else
{
}
else
{
nvshmemi_ibgda_device_qp_t
*
qp
=
ibgda_get_rc
(
pe
,
qp_id
);
nvshmemi_ibgda_device_qp_t
*
qp
=
ibgda_get_rc
(
pe
,
qp_id
);
...
@@ -446,4 +445,51 @@ __device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, co
...
@@ -446,4 +445,51 @@ __device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, co
return
peer_base
+
(
ptr
-
reinterpret_cast
<
uint64_t
>
(
nvshmemi_device_state_d
.
heap_base
));
return
peer_base
+
(
ptr
-
reinterpret_cast
<
uint64_t
>
(
nvshmemi_device_state_d
.
heap_base
));
}
}
// Simplified variant of NVSHMEM's `ibgda_poll_cq`: spin until the completion
// queue `cq` reports that the WQE just before `idx` has completed, then record
// `idx` as the new consumer index.
//
// This implementation is NOT thread safe: the caller must guarantee that no
// other thread is concurrently using the same QP.
//
// Parameters:
//   cq  - completion queue to poll; its `cqe`, `ncqes` and `cons_idx` fields are used.
//   idx - software index to wait for (software keeps "hardware index + 1").
// Returns: always 0.
__device__ static __forceinline__ int ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx) {
    struct mlx5_cqe64 *cqe = (struct mlx5_cqe64 *)cq->cqe;
    const uint32_t num_cqes = cq->ncqes;
    uint16_t hw_counter;
    uint16_t distance;

    memory_fence_cta();

    // `hw_counter` is the hardware consumer index, while software always tracks
    // index + 1, so the value comparable with `idx` is `hw_counter + 1`.
    // The counter is 16 bits wide and may wrap around. As long as
    // `idx - hw_counter - 1 < ncqes` holds, `hw_counter + 1` is still behind
    // `idx` and we must keep waiting. No wait is needed once
    // `idx == hw_counter + 1`; subtracting 2 (instead of 1) makes exactly that
    // case wrap around to a huge value and terminate the loop.
    // Example: for idx = 10 we spin until hw_counter = 9, where
    // idx - hw_counter - 2 = 65535 > ncqes.
    do {
        hw_counter = ld_na_relaxed(&cqe->wqe_counter);
        hw_counter = HtoBE16(hw_counter);
        distance = (uint16_t)((uint16_t)idx - hw_counter - (uint16_t)2);
    } while (distance < num_cqes);

    *cq->cons_idx = idx;

    // Keep this function from being reordered with the instructions that follow it.
    memory_fence_cta();
    return 0;
}
// Quiet the RC queue pair selected by (`dst_pe`, `qp_id`): read the current
// producer index of its send work queue and poll the associated completion
// queue with it, i.e. wait until WQE `prod_idx - 1` is completed.
// The polling goes through `ibgda_poll_cq`, so the same QP must not be used
// by other threads concurrently.
__device__ static __forceinline__ void nvshmemi_ibgda_quiet(int dst_pe, int qp_id) {
    auto rc_qp = ibgda_get_rc(dst_pe, qp_id);
    const uint64_t posted_idx = ld_na_relaxed(rc_qp->tx_wq.prod_idx);
    ibgda_poll_cq(rc_qp->tx_wq.cq, posted_idx);
}
}
// namespace deep_ep
}
// namespace deep_ep
csrc/kernels/internode.cu
View file @
9fe9021f
...
@@ -193,8 +193,8 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
...
@@ -193,8 +193,8 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
}
}
template
<
bool
kLowLatencyMode
>
template
<
bool
kLowLatencyMode
>
__forceinline__
__device__
void
nvshmem_
barrier
_with_same_gpu_idx
(
const
nvshmem_team_t
&
rdma_team
)
{
__forceinline__
__device__
void
nvshmem_
sync
_with_same_gpu_idx
(
const
nvshmem_team_t
&
rdma_team
)
{
kLowLatencyMode
?
void
(
nvshmem_
barrier
(
rdma_team
))
:
nvshmem_
barrier
_all
();
kLowLatencyMode
?
void
(
nvshmem_
sync
(
rdma_team
))
:
nvshmem_
sync
_all
();
}
}
template
<
bool
kLowLatencyMode
,
int
kNumRDMARanks
>
template
<
bool
kLowLatencyMode
,
int
kNumRDMARanks
>
...
@@ -223,7 +223,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
...
@@ -223,7 +223,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
EP_DEVICE_ASSERT
(
num_warps
>
1
);
EP_DEVICE_ASSERT
(
num_warps
>
1
);
EP_DEVICE_ASSERT
(
kNumRDMARanks
<=
num_threads
);
EP_DEVICE_ASSERT
(
kNumRDMARanks
<=
num_threads
);
if
(
thread_id
==
32
)
if
(
thread_id
==
32
)
nvshmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
nvshmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
barrier_device
<
NUM_MAX_NVL_PEERS
>
(
task_fifo_ptrs
,
head
,
nvl_rank
);
barrier_device
<
NUM_MAX_NVL_PEERS
>
(
task_fifo_ptrs
,
head
,
nvl_rank
);
move_fifo_slots
<
NUM_MAX_NVL_PEERS
>
(
head
);
move_fifo_slots
<
NUM_MAX_NVL_PEERS
>
(
head
);
__syncthreads
();
__syncthreads
();
...
@@ -252,14 +252,25 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
...
@@ -252,14 +252,25 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
// Issue send
// Issue send
// TODO: use a lighter-weight fence, barrier, or signaling mechanism
// TODO: use a lighter-weight fence, barrier, or signaling mechanism
// TODO: overlap EP barrier and NVL cleaning
// TODO: overlap EP barrier and NVL cleaning
if
(
thread_id
<
kNumRDMARanks
)
{
for
(
int
i
=
0
;
i
<
kNumRDMARanks
;
++
i
)
{
nvshmem_int_put_nbi
(
rdma_recv_num_tokens_mixed
.
recv_buffer
(
rdma_rank
),
rdma_recv_num_tokens_mixed
.
send_buffer
(
thread_id
),
if
(
i
!=
rdma_rank
)
{
NUM_MAX_NVL_PEERS
+
num_rdma_experts
+
1
,
nvshmemi_ibgda_put_nbi_warp
<
true
>
(
reinterpret_cast
<
uint64_t
>
(
rdma_recv_num_tokens_mixed
.
recv_buffer
(
rdma_rank
)),
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
thread_id
,
nvl_rank
));
reinterpret_cast
<
uint64_t
>
(
rdma_recv_num_tokens_mixed
.
send_buffer
(
i
)),
(
NUM_MAX_NVL_PEERS
+
num_rdma_experts
+
1
)
*
sizeof
(
int
),
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
i
,
nvl_rank
),
0
,
lane_id
,
0
);
}
else
{
UNROLLED_WARP_COPY
(
1
,
lane_id
,
NUM_MAX_NVL_PEERS
+
num_rdma_experts
+
1
,
rdma_recv_num_tokens_mixed
.
recv_buffer
(
rdma_rank
),
rdma_recv_num_tokens_mixed
.
send_buffer
(
i
),
ld_volatile_global
,
st_na_global
);
}
}
}
if
(
thread_id
<
kNumRDMARanks
and
thread_id
!=
rdma_rank
)
nvshmemi_ibgda_quiet
(
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
thread_id
,
nvl_rank
),
0
);
__syncthreads
();
__syncthreads
();
if
(
thread_id
==
0
)
if
(
thread_id
==
0
)
nvshmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
nvshmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
__syncthreads
();
__syncthreads
();
// NVL buffers
// NVL buffers
...
@@ -345,7 +356,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
...
@@ -345,7 +356,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
// Finally barrier
// Finally barrier
__syncthreads
();
__syncthreads
();
if
(
thread_id
==
32
)
if
(
thread_id
==
32
)
nvshmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
nvshmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
barrier_device
<
NUM_MAX_NVL_PEERS
>
(
task_fifo_ptrs
,
head
,
nvl_rank
);
barrier_device
<
NUM_MAX_NVL_PEERS
>
(
task_fifo_ptrs
,
head
,
nvl_rank
);
move_fifo_slots
<
NUM_MAX_NVL_PEERS
>
(
head
);
move_fifo_slots
<
NUM_MAX_NVL_PEERS
>
(
head
);
}
else
{
}
else
{
...
@@ -701,7 +712,14 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
...
@@ -701,7 +712,14 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Iterate all RDMA ranks
// Iterate all RDMA ranks
int
last_issued_tail
=
0
;
int
last_issued_tail
=
0
;
auto
start_time
=
clock64
();
while
(
__any_sync
(
0xffffffff
,
num_tokens_to_send
>
0
))
{
while
(
__any_sync
(
0xffffffff
,
num_tokens_to_send
>
0
))
{
// Timeout check
if
(
clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
and
lane_id
<
kNumRDMARanks
)
{
printf
(
"DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail %d, num_tokens_to_send %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
last_issued_tail
,
num_tokens_to_send
);
trap
();
}
for
(
int
i
=
0
,
synced_num_tokens_to_send
;
i
<
kNumRDMARanks
;
++
i
)
{
for
(
int
i
=
0
,
synced_num_tokens_to_send
;
i
<
kNumRDMARanks
;
++
i
)
{
// To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels
// To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels
int
dst_rdma_rank
=
(
i
+
channel_id
+
rdma_rank
)
%
kNumRDMARanks
;
int
dst_rdma_rank
=
(
i
+
channel_id
+
rdma_rank
)
%
kNumRDMARanks
;
...
@@ -1103,7 +1121,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
...
@@ -1103,7 +1121,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
if
(
sm_id
==
0
)
{
if
(
sm_id
==
0
)
{
// Barrier for RDMA
// Barrier for RDMA
if
(
thread_id
==
0
)
if
(
thread_id
==
0
)
nvshmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
nvshmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
__syncthreads
();
__syncthreads
();
// Clean
// Clean
...
@@ -1111,12 +1129,11 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
...
@@ -1111,12 +1129,11 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
#pragma unroll
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
rdma_num_int_clean
;
i
+=
num_threads
)
for
(
int
i
=
thread_id
;
i
<
rdma_num_int_clean
;
i
+=
num_threads
)
rdma_buffer_ptr_int
[
rdma_clean_offset
+
i
]
=
0
;
rdma_buffer_ptr_int
[
rdma_clean_offset
+
i
]
=
0
;
nvshmem_fence
();
__syncthreads
();
__syncthreads
();
// Barrier again
// Barrier again
if
(
thread_id
==
0
)
if
(
thread_id
==
0
)
nvshmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
nvshmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
}
else
if
(
sm_id
==
1
)
{
}
else
if
(
sm_id
==
1
)
{
// Barrier for NVL
// Barrier for NVL
barrier_device
<
NUM_MAX_NVL_PEERS
>
(
task_fifo_ptrs
,
head
,
nvl_rank
);
barrier_device
<
NUM_MAX_NVL_PEERS
>
(
task_fifo_ptrs
,
head
,
nvl_rank
);
...
...
csrc/kernels/runtime.cu
View file @
9fe9021f
...
@@ -58,12 +58,6 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
...
@@ -58,12 +58,6 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
EP_HOST_ASSERT
(
cpu_rdma_team
!=
NVSHMEM_TEAM_INVALID
);
EP_HOST_ASSERT
(
cpu_rdma_team
!=
NVSHMEM_TEAM_INVALID
);
}
}
// TODO: we still use `nvshmem_barrier` under IBRC mode, which should be switched to IBGDA mode later
nvshmemi_device_host_state_t
*
dev_state_ptr
=
nullptr
;
CUDA_CHECK
(
cudaGetSymbolAddress
(
reinterpret_cast
<
void
**>
(
&
dev_state_ptr
),
nvshmemi_device_state_d
));
bool
ibgda_is_initialized
=
false
;
CUDA_CHECK
(
cudaMemcpy
(
&
dev_state_ptr
->
ibgda_is_initialized
,
&
ibgda_is_initialized
,
sizeof
(
bool
),
cudaMemcpyHostToDevice
));
nvshmem_barrier_all
();
nvshmem_barrier_all
();
return
nvshmem_my_pe
();
return
nvshmem_my_pe
();
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment