Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
007fcfcf
Unverified
Commit
007fcfcf
authored
Apr 22, 2025
by
Chenggang Zhao
Committed by
GitHub
Apr 22, 2025
Browse files
Merge pull request #130 from deepseek-ai/trmt/internode_multi_qp
Support multi-QP for normal kernels
parents
a84a2480
e255d57b
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
82 additions
and
62 deletions
+82
-62
README.md
README.md
+4
-2
csrc/kernels/ibgda_device.cuh
csrc/kernels/ibgda_device.cuh
+17
-11
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+38
-26
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+1
-1
csrc/kernels/runtime.cu
csrc/kernels/runtime.cu
+5
-7
deep_ep/buffer.py
deep_ep/buffer.py
+11
-12
tests/test_internode.py
tests/test_internode.py
+6
-3
No files found.
README.md
View file @
007fcfcf
...
...
@@ -18,8 +18,10 @@ We test normal kernels on H800 (~160 GB/s NVLink maximum bandwidth), with each c
|:---------:|:------------:|:--------------------:|:-----------:|:--------------------:|
| Intranode | 8 | 153 GB/s (NVLink) | 8 | 158 GB/s (NVLink) |
| Internode | 16 | 43 GB/s (RDMA) | 16 | 43 GB/s (RDMA) |
| Internode | 32 | 44 GB/s (RDMA) | 32 | 47 GB/s (RDMA) |
| Internode | 64 | 46 GB/s (RDMA) | 64 | 45 GB/s (RDMA) |
| Internode | 32 | 58 GB/s (RDMA) | 32 | 57 GB/s (RDMA) |
| Internode | 64 | 51 GB/s (RDMA) | 64 | 50 GB/s (RDMA) |
**News (2025.04.22)**: with optimizations from Tencent Network Platform Department, performance was improved by up to 30%; see [#130](https://github.com/deepseek-ai/DeepEP/pull/130) for more details. Thanks for the contribution!
### Low-latency kernels with pure RDMA
...
...
csrc/kernels/ibgda_device.cuh
View file @
007fcfcf
...
...
@@ -325,6 +325,7 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
data_seg_ptr
),
*
reinterpret_cast
<
const
int4
*>
(
&
data_seg
));
}
template
<
bool
kAlwaysDoPostSend
=
false
>
__device__
static
__forceinline__
void
nvshmemi_ibgda_put_nbi_warp
(
uint64_t
req_rptr
,
uint64_t
req_lptr
,
size_t
bytes
,
int
dst_pe
,
int
qp_id
,
int
lane_id
,
int
message_idx
)
{
// Get lkey and rkey, store them into lanes
...
...
@@ -365,7 +366,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
// Submit
if
(
lane_id
==
0
)
ibgda_submit_requests
<
false
>
(
qp
,
base_wqe_idx
,
num_wqes
,
message_idx
);
ibgda_submit_requests
<
kAlwaysDoPostSend
>
(
qp
,
base_wqe_idx
,
num_wqes
,
message_idx
);
__syncwarp
();
}
...
...
@@ -410,7 +411,11 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
data_seg_ptr
),
*
reinterpret_cast
<
int4
*>
(
&
data_seg
));
}
__device__
__forceinline__
void
nvshmemi_ibgda_amo_nonfetch_add
(
void
*
rptr
,
const
int
&
value
,
int
pe
,
int
qp_id
)
{
__device__
__forceinline__
void
nvshmemi_ibgda_amo_nonfetch_add
(
void
*
rptr
,
const
int
&
value
,
int
pe
,
int
qp_id
,
bool
is_local_copy
=
false
)
{
if
(
is_local_copy
)
{
// Fallback to NVSHMEM legacy API
nvshmemx_signal_op
(
static_cast
<
uint64_t
*>
(
rptr
),
value
,
NVSHMEM_SIGNAL_ADD
,
pe
);
}
else
{
nvshmemi_ibgda_device_qp_t
*
qp
=
ibgda_get_rc
(
pe
,
qp_id
);
__be32
rkey
;
...
...
@@ -424,6 +429,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons
qp
->
ibuf
.
lkey
,
raddr
,
rkey
,
my_wqe_idx
,
&
wqe_ptrs
);
ibgda_submit_requests
<
true
>
(
qp
,
my_wqe_idx
,
1
);
}
}
}
// namespace deep_ep
csrc/kernels/internode.cu
View file @
007fcfcf
...
...
@@ -3,6 +3,7 @@
#include "exception.cuh"
#include "launch.cuh"
#include "utils.cuh"
#include "ibgda_device.cuh"
namespace
deep_ep
{
...
...
@@ -479,6 +480,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
const
bool
is_forwarder
=
sm_id
%
2
==
0
;
const
auto
rdma_rank
=
rank
/
NUM_MAX_NVL_PEERS
,
nvl_rank
=
rank
%
NUM_MAX_NVL_PEERS
;
EP_DEVICE_ASSERT
(
ibgda_get_state
()
->
num_rc_per_pe
>=
num_channels
);
const
auto
role_meta
=
[
=
]()
->
std
::
pair
<
WarpRole
,
int
>
{
if
(
is_forwarder
)
{
if
(
warp_id
<
NUM_MAX_NVL_PEERS
)
{
...
...
@@ -555,19 +558,27 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Send number of tokens in this channel by `-value - 1`
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
*
2
+
2
<=
32
,
"Invalid number of NVL peers"
);
for
(
int
dst_rdma_rank
=
warp_id
;
dst_rdma_rank
<
kNumRDMARanks
;
dst_rdma_rank
+=
kNumDispatchRDMASenderWarps
)
{
auto
dst_ptr
=
dst_rdma_rank
==
rdma_rank
?
rdma_channel_meta
.
recv_buffer
(
dst_rdma_rank
)
:
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
);
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)
[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
)
*
num_channels
+
channel_id
-
1
])
-
1
;
dst_ptr
[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
)
*
num_channels
+
channel_id
-
1
])
-
1
;
}
else
if
(
lane_id
<
NUM_MAX_NVL_PEERS
*
2
)
{
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)
[
lane_id
]
=
-
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
-
NUM_MAX_NVL_PEERS
)
*
num_channels
+
channel_id
]
-
1
;
dst_ptr
[
lane_id
]
=
-
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
-
NUM_MAX_NVL_PEERS
)
*
num_channels
+
channel_id
]
-
1
;
}
else
if
(
lane_id
==
NUM_MAX_NVL_PEERS
*
2
)
{
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)
[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
-
1
])
-
1
;
dst_ptr
[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
-
1
])
-
1
;
}
else
if
(
lane_id
==
NUM_MAX_NVL_PEERS
*
2
+
1
)
{
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)[
lane_id
]
=
-
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
]
-
1
;
dst_ptr
[
lane_id
]
=
-
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
]
-
1
;
}
__syncwarp
();
// Issue RDMA for non-local ranks
if
(
dst_rdma_rank
!=
rdma_rank
)
{
nvshmemi_ibgda_put_nbi_warp
<
true
>
(
reinterpret_cast
<
uint64_t
>
(
rdma_channel_meta
.
recv_buffer
(
rdma_rank
)),
reinterpret_cast
<
uint64_t
>
(
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)),
sizeof
(
int
)
*
(
NUM_MAX_NVL_PEERS
*
2
+
2
),
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
lane_id
,
0
);
}
nvshmemx_int_put_nbi_warp
(
rdma_channel_meta
.
recv_buffer
(
rdma_rank
),
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
),
NUM_MAX_NVL_PEERS
*
2
+
2
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
nvshmem_fence
();
sync_rdma_sender_smem
();
// Iterate over tokens and copy into buffer
...
...
@@ -710,11 +721,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if
(
dst_rdma_rank
!=
rdma_rank
)
{
auto
dst_slot_idx
=
synced_last_issued_tail
%
num_max_rdma_chunked_recv_tokens
;
EP_DEVICE_ASSERT
(
dst_slot_idx
+
num_tokens_to_issue
<=
num_max_rdma_chunked_recv_tokens
);
nvshmemx_int8_put_nbi_warp
(
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
,
rdma_channel_data
.
send
_buffer
(
dst_
rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
,
num_bytes_per_rdma_token
*
num_tokens_to_issue
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
nvshmem_fence
(
);
const
size_t
num_bytes_per_msg
=
num_bytes_per_rdma_token
*
num_tokens_to_issue
;
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
recv
_buffer
(
rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
);
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
);
nvshmemi_ibgda_put_nbi_warp
<
true
>
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
lane_id
,
0
);
}
else
{
// Lighter fence for local RDMA rank
memory_fence
();
...
...
@@ -725,8 +736,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if
(
lane_id
==
dst_rdma_rank
)
{
last_issued_tail
+=
num_tokens_to_issue
;
num_tokens_to_send
-=
num_tokens_to_issue
;
nvshmem
x_signal_op
(
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_tokens_to_issue
,
NVSHMEM_SIGNAL_ADD
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
nvshmem
i_ibgda_amo_nonfetch_add
(
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_tokens_to_issue
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
)
,
channel_id
,
dst_rdma_rank
==
rdma_rank
);
}
}
}
...
...
@@ -926,8 +937,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Update remote head
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>=
last_head
+
num_max_rdma_chunked_send_tokens
and
lane_id
<
kNumRDMARanks
)
{
nvshmem
x_signal_op
(
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_head
,
NVSHMEM_SIGNAL_ADD
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
lane_id
,
nvl_rank
));
nvshmem
i_ibgda_amo_nonfetch_add
(
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_head
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
lane_id
,
nvl_rank
)
,
channel_id
,
lane_id
==
rdma_rank
);
last_head
=
min_head
;
}
...
...
@@ -1558,20 +1569,21 @@ combine(int4* combined_x, float* combined_topk_weights,
if
(
sub_warp_id
==
kNumWarpsPerForwarder
-
1
)
{
if
(
dst_rdma_rank
!=
rdma_rank
)
{
auto
rdma_slot_idx
=
token_start_idx
%
num_max_rdma_chunked_recv_tokens
;
nvshmemx_int8_put_nbi_warp
(
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
rdma_channel_data
.
send
_buffer
(
dst_
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
num_chunked_tokens
*
num_bytes_per_rdma_token
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
nvshmem_fence
(
);
const
size_t
num_bytes_per_msg
=
num_chunked_tokens
*
num_bytes_per_rdma_token
;
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
recv
_buffer
(
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
);
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
);
nvshmemi_ibgda_put_nbi_warp
<
true
>
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
lane_id
,
0
);
}
else
{
memory_fence
();
}
// Write new RDMA tail
__syncwarp
();
if
(
lane_id
==
0
)
nvshmemx_signal_op
(
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_chunked_tokens
,
NVSHMEM_SIGNAL_ADD
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
if
(
lane_id
==
0
)
{
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_chunked_tokens
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
dst_rdma_rank
==
rdma_rank
);
}
}
}
...
...
@@ -1656,8 +1668,8 @@ combine(int4* combined_x, float* combined_topk_weights,
for
(
int
i
=
0
;
i
<
kNumRDMAReceivers
;
++
i
)
if
(
not
rdma_receiver_retired
[
i
])
min_head
=
min
(
min_head
,
rdma_receiver_rdma_head
[
i
][
dst_rdma_rank
]);
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>=
last_rdma_head
+
num_max_rdma_chunked_send_tokens
and
lane_id
<
kNumRDMARanks
)
{
nvshmem
x_signal_op
(
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_rdma_head
,
NVSHMEM_SIGNAL_ADD
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
nvshmem
i_ibgda_amo_nonfetch_add
(
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_rdma_head
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
)
,
channel_id
,
dst_rdma_rank
==
rdma_rank
);
last_rdma_head
=
min_head
;
}
}
else
{
...
...
csrc/kernels/internode_ll.cu
View file @
007fcfcf
...
...
@@ -167,7 +167,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
EP_DEVICE_ASSERT
(
num_sms
>
1
);
if
(
sm_id
==
0
)
{
// The first SM is also responsible for checking QPs
EP_DEVICE_ASSERT
(
ibgda_get_state
()
->
num_rc_per_pe
=
=
num_local_experts
);
EP_DEVICE_ASSERT
(
ibgda_get_state
()
->
num_rc_per_pe
>
=
num_local_experts
);
// The first SM is also responsible for cleaning the next buffer
#pragma unroll
...
...
csrc/kernels/runtime.cu
View file @
007fcfcf
...
...
@@ -58,14 +58,12 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
EP_HOST_ASSERT
(
cpu_rdma_team
!=
NVSHMEM_TEAM_INVALID
);
}
// Normal operations use IBRC, while low-latency operations use IBGDA
if
(
low_latency_mode
)
{
// TODO: we still use `nvshmem_barrier` under IBRC mode, which should be switched to IBGDA mode later
nvshmemi_device_host_state_t
*
dev_state_ptr
=
nullptr
;
CUDA_CHECK
(
cudaGetSymbolAddress
(
reinterpret_cast
<
void
**>
(
&
dev_state_ptr
),
nvshmemi_device_state_d
));
bool
ibgda_is_initialized
=
false
;
CUDA_CHECK
(
cudaMemcpy
(
&
dev_state_ptr
->
ibgda_is_initialized
,
&
ibgda_is_initialized
,
sizeof
(
bool
),
cudaMemcpyHostToDevice
));
}
nvshmem_barrier_all
();
return
nvshmem_my_pe
();
}
...
...
deep_ep/buffer.py
View file @
007fcfcf
...
...
@@ -31,7 +31,7 @@ class Buffer:
def
__init__
(
self
,
group
:
dist
.
ProcessGroup
,
num_nvl_bytes
:
int
=
0
,
num_rdma_bytes
:
int
=
0
,
low_latency_mode
:
bool
=
False
,
num_qps_per_rank
:
int
=
1
)
->
None
:
low_latency_mode
:
bool
=
False
,
num_qps_per_rank
:
int
=
1
2
)
->
None
:
"""
Initialize the communication buffer.
...
...
@@ -66,8 +66,7 @@ class Buffer:
# Synchronize NVSHMEM unique IDs
root_unique_id
=
None
if
self
.
runtime
.
get_num_rdma_ranks
()
>
1
or
low_latency_mode
:
# Enable IBGDA for the low-latency mode, which refers to "no packet forwarding between NVLink and RDMA"
if
low_latency_mode
:
# Enable IBGDA
assert
num_qps_per_rank
>
0
os
.
environ
[
'NVSHMEM_DISABLE_P2P'
]
=
'1'
os
.
environ
[
'NVSHMEM_IB_ENABLE_IBGDA'
]
=
'1'
...
...
tests/test_internode.py
View file @
007fcfcf
...
...
@@ -219,16 +219,19 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
):
num_nodes
=
int
(
os
.
getenv
(
'WORLD_SIZE'
,
1
))
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
test_ll_compatibility
=
Fals
e
test_ll_compatibility
=
Tru
e
if
test_ll_compatibility
:
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
=
16
,
5120
,
256
,
9
num_sms
=
24
num_qps_per_rank
=
max
(
num_sms
//
2
,
ll_num_experts
//
num_ranks
if
test_ll_compatibility
else
0
)
buffer
=
deep_ep
.
Buffer
(
group
,
int
(
1e9
),
int
(
1e9
),
low_latency_mode
=
test_ll_compatibility
,
num_qps_per_rank
=
(
ll_num_experts
//
num_ranks
if
test_ll_compatibility
else
1
)
)
num_qps_per_rank
=
num_qps_per_rank
)
assert
num_local_ranks
==
8
and
num_ranks
>
8
torch
.
manual_seed
(
rank
)
for
i
in
(
24
,
):
for
i
in
(
num_sms
,
):
test_main
(
i
,
local_rank
,
num_local_ranks
,
num_ranks
,
num_nodes
,
rank
,
buffer
,
group
)
if
local_rank
==
0
:
print
(
''
,
flush
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment