Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
9fe9021f
"...ssh:/git@developer.sourcefind.cn:2222/tsoc/openmm.git" did not exist on "ce9fcace1c7c3835ab3970781a2bceebb3e563e7"
Unverified
Commit
9fe9021f
authored
May 28, 2025
by
Shangyan Zhou
Committed by
GitHub
May 28, 2025
Browse files
Use IBGDA only (#177)
parent
aae9fa9a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
77 additions
and
20 deletions
+77
-20
csrc/kernels/ibgda_device.cuh
csrc/kernels/ibgda_device.cuh
+48
-2
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+29
-12
csrc/kernels/runtime.cu
csrc/kernels/runtime.cu
+0
-6
No files found.
csrc/kernels/ibgda_device.cuh
View file @
9fe9021f
...
@@ -413,8 +413,7 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
...
@@ -413,8 +413,7 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
__device__
__forceinline__
void
nvshmemi_ibgda_amo_nonfetch_add
(
void
*
rptr
,
const
int
&
value
,
int
pe
,
int
qp_id
,
bool
is_local_copy
=
false
)
{
__device__
__forceinline__
void
nvshmemi_ibgda_amo_nonfetch_add
(
void
*
rptr
,
const
int
&
value
,
int
pe
,
int
qp_id
,
bool
is_local_copy
=
false
)
{
if
(
is_local_copy
)
{
if
(
is_local_copy
)
{
// Fallback to NVSHMEM legacy API
atomicAdd
(
static_cast
<
unsigned
long
long
*>
(
rptr
),
value
);
nvshmemx_signal_op
(
static_cast
<
uint64_t
*>
(
rptr
),
value
,
NVSHMEM_SIGNAL_ADD
,
pe
);
}
else
{
}
else
{
nvshmemi_ibgda_device_qp_t
*
qp
=
ibgda_get_rc
(
pe
,
qp_id
);
nvshmemi_ibgda_device_qp_t
*
qp
=
ibgda_get_rc
(
pe
,
qp_id
);
...
@@ -446,4 +445,51 @@ __device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, co
...
@@ -446,4 +445,51 @@ __device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, co
return
peer_base
+
(
ptr
-
reinterpret_cast
<
uint64_t
>
(
nvshmemi_device_state_d
.
heap_base
));
return
peer_base
+
(
ptr
-
reinterpret_cast
<
uint64_t
>
(
nvshmemi_device_state_d
.
heap_base
));
}
}
// This is a simplified version of NVSHMEM's `ibgda_poll_cq`.
// Note that this implementation does not guarantee thread safety,
// so we must ensure that no other threads are concurrently using the same QP.
__device__
static
__forceinline__
int
ibgda_poll_cq
(
nvshmemi_ibgda_device_cq_t
*
cq
,
uint64_t
idx
)
{
int
status
=
0
;
struct
mlx5_cqe64
*
cqe64
=
(
struct
mlx5_cqe64
*
)
cq
->
cqe
;
const
uint32_t
ncqes
=
cq
->
ncqes
;
uint16_t
wqe_counter
;
uint16_t
new_wqe_counter
;
memory_fence_cta
();
do
{
new_wqe_counter
=
ld_na_relaxed
(
&
cqe64
->
wqe_counter
);
new_wqe_counter
=
HtoBE16
(
new_wqe_counter
);
wqe_counter
=
new_wqe_counter
;
}
// NOTE: This while loop is part of do while above.
// wqe_counter is the HW consumer index. However, we always maintain index
// + 1 in SW. To be able to compare with idx, we need to use wqe_counter +
// 1. Because wqe_counter is uint16_t, it may wraparound. Still we know for
// sure that if idx - wqe_counter - 1 < ncqes, wqe_counter + 1 is less than
// idx, and thus we need to wait. We don't need to wait when idx ==
// wqe_counter + 1. That's why we use - (uint16_t)2 here to make this case
// wraparound.
// Example:
// if idx = 10, we wait until wqe_counter = 9, idx - wqe_counter - 2 = 65535 > ncqes.
while
(((
uint16_t
)((
uint16_t
)
idx
-
wqe_counter
-
(
uint16_t
)
2
)
<
ncqes
));
*
cq
->
cons_idx
=
idx
;
// Prevent reordering of this function and subsequent instructions
memory_fence_cta
();
return
status
;
}
// Wait until wqe `idx - 1` is completed.
__device__
static
__forceinline__
void
nvshmemi_ibgda_quiet
(
int
dst_pe
,
int
qp_id
)
{
auto
qp
=
ibgda_get_rc
(
dst_pe
,
qp_id
);
uint64_t
prod_idx
=
ld_na_relaxed
(
qp
->
tx_wq
.
prod_idx
);
ibgda_poll_cq
(
qp
->
tx_wq
.
cq
,
prod_idx
);
}
}
// namespace deep_ep
}
// namespace deep_ep
csrc/kernels/internode.cu
View file @
9fe9021f
...
@@ -193,8 +193,8 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
...
@@ -193,8 +193,8 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
}
}
template
<
bool
kLowLatencyMode
>
template
<
bool
kLowLatencyMode
>
__forceinline__
__device__
void
nvshmem_
barrier
_with_same_gpu_idx
(
const
nvshmem_team_t
&
rdma_team
)
{
__forceinline__
__device__
void
nvshmem_
sync
_with_same_gpu_idx
(
const
nvshmem_team_t
&
rdma_team
)
{
kLowLatencyMode
?
void
(
nvshmem_
barrier
(
rdma_team
))
:
nvshmem_
barrier
_all
();
kLowLatencyMode
?
void
(
nvshmem_
sync
(
rdma_team
))
:
nvshmem_
sync
_all
();
}
}
template
<
bool
kLowLatencyMode
,
int
kNumRDMARanks
>
template
<
bool
kLowLatencyMode
,
int
kNumRDMARanks
>
...
@@ -223,7 +223,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
...
@@ -223,7 +223,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
EP_DEVICE_ASSERT
(
num_warps
>
1
);
EP_DEVICE_ASSERT
(
num_warps
>
1
);
EP_DEVICE_ASSERT
(
kNumRDMARanks
<=
num_threads
);
EP_DEVICE_ASSERT
(
kNumRDMARanks
<=
num_threads
);
if
(
thread_id
==
32
)
if
(
thread_id
==
32
)
nvshmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
nvshmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
barrier_device
<
NUM_MAX_NVL_PEERS
>
(
task_fifo_ptrs
,
head
,
nvl_rank
);
barrier_device
<
NUM_MAX_NVL_PEERS
>
(
task_fifo_ptrs
,
head
,
nvl_rank
);
move_fifo_slots
<
NUM_MAX_NVL_PEERS
>
(
head
);
move_fifo_slots
<
NUM_MAX_NVL_PEERS
>
(
head
);
__syncthreads
();
__syncthreads
();
...
@@ -252,14 +252,25 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
...
@@ -252,14 +252,25 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
// Issue send
// Issue send
// TODO: more light fence or barrier or signaling
// TODO: more light fence or barrier or signaling
// TODO: overlap EP barrier and NVL cleaning
// TODO: overlap EP barrier and NVL cleaning
if
(
thread_id
<
kNumRDMARanks
)
{
for
(
int
i
=
0
;
i
<
kNumRDMARanks
;
++
i
)
{
nvshmem_int_put_nbi
(
rdma_recv_num_tokens_mixed
.
recv_buffer
(
rdma_rank
),
rdma_recv_num_tokens_mixed
.
send_buffer
(
thread_id
),
if
(
i
!=
rdma_rank
)
{
NUM_MAX_NVL_PEERS
+
num_rdma_experts
+
1
,
nvshmemi_ibgda_put_nbi_warp
<
true
>
(
reinterpret_cast
<
uint64_t
>
(
rdma_recv_num_tokens_mixed
.
recv_buffer
(
rdma_rank
)),
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
thread_id
,
nvl_rank
));
reinterpret_cast
<
uint64_t
>
(
rdma_recv_num_tokens_mixed
.
send_buffer
(
i
)),
(
NUM_MAX_NVL_PEERS
+
num_rdma_experts
+
1
)
*
sizeof
(
int
),
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
i
,
nvl_rank
),
0
,
lane_id
,
0
);
}
else
{
UNROLLED_WARP_COPY
(
1
,
lane_id
,
NUM_MAX_NVL_PEERS
+
num_rdma_experts
+
1
,
rdma_recv_num_tokens_mixed
.
recv_buffer
(
rdma_rank
),
rdma_recv_num_tokens_mixed
.
send_buffer
(
i
),
ld_volatile_global
,
st_na_global
);
}
}
}
if
(
thread_id
<
kNumRDMARanks
and
thread_id
!=
rdma_rank
)
nvshmemi_ibgda_quiet
(
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
thread_id
,
nvl_rank
),
0
);
__syncthreads
();
__syncthreads
();
if
(
thread_id
==
0
)
if
(
thread_id
==
0
)
nvshmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
nvshmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
__syncthreads
();
__syncthreads
();
// NVL buffers
// NVL buffers
...
@@ -345,7 +356,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
...
@@ -345,7 +356,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
// Finally barrier
// Finally barrier
__syncthreads
();
__syncthreads
();
if
(
thread_id
==
32
)
if
(
thread_id
==
32
)
nvshmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
nvshmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
barrier_device
<
NUM_MAX_NVL_PEERS
>
(
task_fifo_ptrs
,
head
,
nvl_rank
);
barrier_device
<
NUM_MAX_NVL_PEERS
>
(
task_fifo_ptrs
,
head
,
nvl_rank
);
move_fifo_slots
<
NUM_MAX_NVL_PEERS
>
(
head
);
move_fifo_slots
<
NUM_MAX_NVL_PEERS
>
(
head
);
}
else
{
}
else
{
...
@@ -701,7 +712,14 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
...
@@ -701,7 +712,14 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Iterate all RDMA ranks
// Iterate all RDMA ranks
int
last_issued_tail
=
0
;
int
last_issued_tail
=
0
;
auto
start_time
=
clock64
();
while
(
__any_sync
(
0xffffffff
,
num_tokens_to_send
>
0
))
{
while
(
__any_sync
(
0xffffffff
,
num_tokens_to_send
>
0
))
{
// Timeout check
if
(
clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
and
lane_id
<
kNumRDMARanks
)
{
printf
(
"DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail %d, num_tokens_to_send %d
\n
"
,
channel_id
,
rdma_rank
,
nvl_rank
,
lane_id
,
last_issued_tail
,
num_tokens_to_send
);
trap
();
}
for
(
int
i
=
0
,
synced_num_tokens_to_send
;
i
<
kNumRDMARanks
;
++
i
)
{
for
(
int
i
=
0
,
synced_num_tokens_to_send
;
i
<
kNumRDMARanks
;
++
i
)
{
// To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels
// To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels
int
dst_rdma_rank
=
(
i
+
channel_id
+
rdma_rank
)
%
kNumRDMARanks
;
int
dst_rdma_rank
=
(
i
+
channel_id
+
rdma_rank
)
%
kNumRDMARanks
;
...
@@ -1103,7 +1121,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
...
@@ -1103,7 +1121,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
if
(
sm_id
==
0
)
{
if
(
sm_id
==
0
)
{
// Barrier for RDMA
// Barrier for RDMA
if
(
thread_id
==
0
)
if
(
thread_id
==
0
)
nvshmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
nvshmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
__syncthreads
();
__syncthreads
();
// Clean
// Clean
...
@@ -1111,12 +1129,11 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
...
@@ -1111,12 +1129,11 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
#pragma unroll
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
rdma_num_int_clean
;
i
+=
num_threads
)
for
(
int
i
=
thread_id
;
i
<
rdma_num_int_clean
;
i
+=
num_threads
)
rdma_buffer_ptr_int
[
rdma_clean_offset
+
i
]
=
0
;
rdma_buffer_ptr_int
[
rdma_clean_offset
+
i
]
=
0
;
nvshmem_fence
();
__syncthreads
();
__syncthreads
();
// Barrier again
// Barrier again
if
(
thread_id
==
0
)
if
(
thread_id
==
0
)
nvshmem_
barrier
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
nvshmem_
sync
_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
}
else
if
(
sm_id
==
1
)
{
}
else
if
(
sm_id
==
1
)
{
// Barrier for NVL
// Barrier for NVL
barrier_device
<
NUM_MAX_NVL_PEERS
>
(
task_fifo_ptrs
,
head
,
nvl_rank
);
barrier_device
<
NUM_MAX_NVL_PEERS
>
(
task_fifo_ptrs
,
head
,
nvl_rank
);
...
...
csrc/kernels/runtime.cu
View file @
9fe9021f
...
@@ -58,12 +58,6 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
...
@@ -58,12 +58,6 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
EP_HOST_ASSERT
(
cpu_rdma_team
!=
NVSHMEM_TEAM_INVALID
);
EP_HOST_ASSERT
(
cpu_rdma_team
!=
NVSHMEM_TEAM_INVALID
);
}
}
// TODO: we still use `nvshmem_barrier` under IBRC mode, which should be switch to IBGDA mode later
nvshmemi_device_host_state_t
*
dev_state_ptr
=
nullptr
;
CUDA_CHECK
(
cudaGetSymbolAddress
(
reinterpret_cast
<
void
**>
(
&
dev_state_ptr
),
nvshmemi_device_state_d
));
bool
ibgda_is_initialized
=
false
;
CUDA_CHECK
(
cudaMemcpy
(
&
dev_state_ptr
->
ibgda_is_initialized
,
&
ibgda_is_initialized
,
sizeof
(
bool
),
cudaMemcpyHostToDevice
));
nvshmem_barrier_all
();
nvshmem_barrier_all
();
return
nvshmem_my_pe
();
return
nvshmem_my_pe
();
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment