Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
007fcfcf
Unverified
Commit
007fcfcf
authored
Apr 22, 2025
by
Chenggang Zhao
Committed by
GitHub
Apr 22, 2025
Browse files
Merge pull request #130 from deepseek-ai/trmt/internode_multi_qp
Support multi-QP for normal kernels
parents
a84a2480
e255d57b
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
82 additions
and
62 deletions
+82
-62
README.md
README.md
+4
-2
csrc/kernels/ibgda_device.cuh
csrc/kernels/ibgda_device.cuh
+17
-11
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+38
-26
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+1
-1
csrc/kernels/runtime.cu
csrc/kernels/runtime.cu
+5
-7
deep_ep/buffer.py
deep_ep/buffer.py
+11
-12
tests/test_internode.py
tests/test_internode.py
+6
-3
No files found.
README.md
View file @
007fcfcf
...
@@ -18,8 +18,10 @@ We test normal kernels on H800 (~160 GB/s NVLink maximum bandwidth), with each c
...
@@ -18,8 +18,10 @@ We test normal kernels on H800 (~160 GB/s NVLink maximum bandwidth), with each c
|:---------:|:------------:|:--------------------:|:-----------:|:--------------------:|
|:---------:|:------------:|:--------------------:|:-----------:|:--------------------:|
| Intranode | 8 | 153 GB/s (NVLink) | 8 | 158 GB/s (NVLink) |
| Intranode | 8 | 153 GB/s (NVLink) | 8 | 158 GB/s (NVLink) |
| Internode | 16 | 43 GB/s (RDMA) | 16 | 43 GB/s (RDMA) |
| Internode | 16 | 43 GB/s (RDMA) | 16 | 43 GB/s (RDMA) |
| Internode | 32 | 44 GB/s (RDMA) | 32 | 47 GB/s (RDMA) |
| Internode | 32 | 58 GB/s (RDMA) | 32 | 57 GB/s (RDMA) |
| Internode | 64 | 46 GB/s (RDMA) | 64 | 45 GB/s (RDMA) |
| Internode | 64 | 51 GB/s (RDMA) | 64 | 50 GB/s (RDMA) |
**News (2025.04.22)**
: with optimizations from the Tencent Network Platform Department, performance has been improved by up to 30%; see
[
#130
](
https://github.com/deepseek-ai/DeepEP/pull/130
)
for more details. Thanks for the contribution!
### Low-latency kernels with pure RDMA
### Low-latency kernels with pure RDMA
...
...
csrc/kernels/ibgda_device.cuh
View file @
007fcfcf
...
@@ -325,6 +325,7 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
...
@@ -325,6 +325,7 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
data_seg_ptr
),
*
reinterpret_cast
<
const
int4
*>
(
&
data_seg
));
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
data_seg_ptr
),
*
reinterpret_cast
<
const
int4
*>
(
&
data_seg
));
}
}
template
<
bool
kAlwaysDoPostSend
=
false
>
__device__
static
__forceinline__
void
__device__
static
__forceinline__
void
nvshmemi_ibgda_put_nbi_warp
(
uint64_t
req_rptr
,
uint64_t
req_lptr
,
size_t
bytes
,
int
dst_pe
,
int
qp_id
,
int
lane_id
,
int
message_idx
)
{
nvshmemi_ibgda_put_nbi_warp
(
uint64_t
req_rptr
,
uint64_t
req_lptr
,
size_t
bytes
,
int
dst_pe
,
int
qp_id
,
int
lane_id
,
int
message_idx
)
{
// Get lkey and rkey, store them into lanes
// Get lkey and rkey, store them into lanes
...
@@ -365,7 +366,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
...
@@ -365,7 +366,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
// Submit
// Submit
if
(
lane_id
==
0
)
if
(
lane_id
==
0
)
ibgda_submit_requests
<
false
>
(
qp
,
base_wqe_idx
,
num_wqes
,
message_idx
);
ibgda_submit_requests
<
kAlwaysDoPostSend
>
(
qp
,
base_wqe_idx
,
num_wqes
,
message_idx
);
__syncwarp
();
__syncwarp
();
}
}
...
@@ -410,20 +411,25 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
...
@@ -410,20 +411,25 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
data_seg_ptr
),
*
reinterpret_cast
<
int4
*>
(
&
data_seg
));
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
data_seg_ptr
),
*
reinterpret_cast
<
int4
*>
(
&
data_seg
));
}
}
__device__
__forceinline__
void
nvshmemi_ibgda_amo_nonfetch_add
(
void
*
rptr
,
const
int
&
value
,
int
pe
,
int
qp_id
)
{
__device__
__forceinline__
void
nvshmemi_ibgda_amo_nonfetch_add
(
void
*
rptr
,
const
int
&
value
,
int
pe
,
int
qp_id
,
bool
is_local_copy
=
false
)
{
nvshmemi_ibgda_device_qp_t
*
qp
=
ibgda_get_rc
(
pe
,
qp_id
);
if
(
is_local_copy
)
{
// Fallback to NVSHMEM legacy API
nvshmemx_signal_op
(
static_cast
<
uint64_t
*>
(
rptr
),
value
,
NVSHMEM_SIGNAL_ADD
,
pe
);
}
else
{
nvshmemi_ibgda_device_qp_t
*
qp
=
ibgda_get_rc
(
pe
,
qp_id
);
__be32
rkey
;
__be32
rkey
;
uint64_t
raddr
;
uint64_t
raddr
;
ibgda_get_rkey
(
reinterpret_cast
<
uint64_t
>
(
rptr
),
pe
,
&
raddr
,
&
rkey
);
ibgda_get_rkey
(
reinterpret_cast
<
uint64_t
>
(
rptr
),
pe
,
&
raddr
,
&
rkey
);
uint64_t
my_wqe_idx
=
ibgda_reserve_wqe_slots
(
qp
,
1
);
uint64_t
my_wqe_idx
=
ibgda_reserve_wqe_slots
(
qp
,
1
);
void
*
wqe_ptrs
=
ibgda_get_wqe_ptr
(
qp
,
my_wqe_idx
);
void
*
wqe_ptrs
=
ibgda_get_wqe_ptr
(
qp
,
my_wqe_idx
);
ibgda_write_amo_add_wqe
(
qp
,
value
,
reinterpret_cast
<
uint64_t
>
(
qp
->
ibuf
.
buf
),
ibgda_write_amo_add_wqe
(
qp
,
value
,
reinterpret_cast
<
uint64_t
>
(
qp
->
ibuf
.
buf
),
qp
->
ibuf
.
lkey
,
raddr
,
rkey
,
my_wqe_idx
,
&
wqe_ptrs
);
qp
->
ibuf
.
lkey
,
raddr
,
rkey
,
my_wqe_idx
,
&
wqe_ptrs
);
ibgda_submit_requests
<
true
>
(
qp
,
my_wqe_idx
,
1
);
ibgda_submit_requests
<
true
>
(
qp
,
my_wqe_idx
,
1
);
}
}
}
}
// namespace deep_ep
}
// namespace deep_ep
csrc/kernels/internode.cu
View file @
007fcfcf
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include "exception.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "launch.cuh"
#include "utils.cuh"
#include "utils.cuh"
#include "ibgda_device.cuh"
namespace
deep_ep
{
namespace
deep_ep
{
...
@@ -479,6 +480,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
...
@@ -479,6 +480,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
const
bool
is_forwarder
=
sm_id
%
2
==
0
;
const
bool
is_forwarder
=
sm_id
%
2
==
0
;
const
auto
rdma_rank
=
rank
/
NUM_MAX_NVL_PEERS
,
nvl_rank
=
rank
%
NUM_MAX_NVL_PEERS
;
const
auto
rdma_rank
=
rank
/
NUM_MAX_NVL_PEERS
,
nvl_rank
=
rank
%
NUM_MAX_NVL_PEERS
;
EP_DEVICE_ASSERT
(
ibgda_get_state
()
->
num_rc_per_pe
>=
num_channels
);
const
auto
role_meta
=
[
=
]()
->
std
::
pair
<
WarpRole
,
int
>
{
const
auto
role_meta
=
[
=
]()
->
std
::
pair
<
WarpRole
,
int
>
{
if
(
is_forwarder
)
{
if
(
is_forwarder
)
{
if
(
warp_id
<
NUM_MAX_NVL_PEERS
)
{
if
(
warp_id
<
NUM_MAX_NVL_PEERS
)
{
...
@@ -555,19 +558,27 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
...
@@ -555,19 +558,27 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Send number of tokens in this channel by `-value - 1`
// Send number of tokens in this channel by `-value - 1`
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
*
2
+
2
<=
32
,
"Invalid number of NVL peers"
);
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
*
2
+
2
<=
32
,
"Invalid number of NVL peers"
);
for
(
int
dst_rdma_rank
=
warp_id
;
dst_rdma_rank
<
kNumRDMARanks
;
dst_rdma_rank
+=
kNumDispatchRDMASenderWarps
)
{
for
(
int
dst_rdma_rank
=
warp_id
;
dst_rdma_rank
<
kNumRDMARanks
;
dst_rdma_rank
+=
kNumDispatchRDMASenderWarps
)
{
auto
dst_ptr
=
dst_rdma_rank
==
rdma_rank
?
rdma_channel_meta
.
recv_buffer
(
dst_rdma_rank
)
:
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
);
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)
[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
)
*
num_channels
+
channel_id
-
1
])
-
1
;
dst_ptr
[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
)
*
num_channels
+
channel_id
-
1
])
-
1
;
}
else
if
(
lane_id
<
NUM_MAX_NVL_PEERS
*
2
)
{
}
else
if
(
lane_id
<
NUM_MAX_NVL_PEERS
*
2
)
{
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)
[
lane_id
]
=
-
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
-
NUM_MAX_NVL_PEERS
)
*
num_channels
+
channel_id
]
-
1
;
dst_ptr
[
lane_id
]
=
-
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
-
NUM_MAX_NVL_PEERS
)
*
num_channels
+
channel_id
]
-
1
;
}
else
if
(
lane_id
==
NUM_MAX_NVL_PEERS
*
2
)
{
}
else
if
(
lane_id
==
NUM_MAX_NVL_PEERS
*
2
)
{
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)
[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
-
1
])
-
1
;
dst_ptr
[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
-
1
])
-
1
;
}
else
if
(
lane_id
==
NUM_MAX_NVL_PEERS
*
2
+
1
)
{
}
else
if
(
lane_id
==
NUM_MAX_NVL_PEERS
*
2
+
1
)
{
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)[
lane_id
]
=
-
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
]
-
1
;
dst_ptr
[
lane_id
]
=
-
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
]
-
1
;
}
__syncwarp
();
// Issue RDMA for non-local ranks
if
(
dst_rdma_rank
!=
rdma_rank
)
{
nvshmemi_ibgda_put_nbi_warp
<
true
>
(
reinterpret_cast
<
uint64_t
>
(
rdma_channel_meta
.
recv_buffer
(
rdma_rank
)),
reinterpret_cast
<
uint64_t
>
(
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)),
sizeof
(
int
)
*
(
NUM_MAX_NVL_PEERS
*
2
+
2
),
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
lane_id
,
0
);
}
}
nvshmemx_int_put_nbi_warp
(
rdma_channel_meta
.
recv_buffer
(
rdma_rank
),
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
),
NUM_MAX_NVL_PEERS
*
2
+
2
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
}
nvshmem_fence
();
sync_rdma_sender_smem
();
sync_rdma_sender_smem
();
// Iterate over tokens and copy into buffer
// Iterate over tokens and copy into buffer
...
@@ -710,11 +721,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
...
@@ -710,11 +721,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if
(
dst_rdma_rank
!=
rdma_rank
)
{
if
(
dst_rdma_rank
!=
rdma_rank
)
{
auto
dst_slot_idx
=
synced_last_issued_tail
%
num_max_rdma_chunked_recv_tokens
;
auto
dst_slot_idx
=
synced_last_issued_tail
%
num_max_rdma_chunked_recv_tokens
;
EP_DEVICE_ASSERT
(
dst_slot_idx
+
num_tokens_to_issue
<=
num_max_rdma_chunked_recv_tokens
);
EP_DEVICE_ASSERT
(
dst_slot_idx
+
num_tokens_to_issue
<=
num_max_rdma_chunked_recv_tokens
);
nvshmemx_int8_put_nbi_warp
(
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
,
const
size_t
num_bytes_per_msg
=
num_bytes_per_rdma_token
*
num_tokens_to_issue
;
rdma_channel_data
.
send
_buffer
(
dst_
rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
,
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
recv
_buffer
(
rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
);
num_bytes_per_rdma_token
*
num_tokens_to_issue
,
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
);
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
nvshmemi_ibgda_put_nbi_warp
<
true
>
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
nvshmem_fence
(
);
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
lane_id
,
0
);
}
else
{
}
else
{
// Lighter fence for local RDMA rank
// Lighter fence for local RDMA rank
memory_fence
();
memory_fence
();
...
@@ -725,8 +736,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
...
@@ -725,8 +736,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if
(
lane_id
==
dst_rdma_rank
)
{
if
(
lane_id
==
dst_rdma_rank
)
{
last_issued_tail
+=
num_tokens_to_issue
;
last_issued_tail
+=
num_tokens_to_issue
;
num_tokens_to_send
-=
num_tokens_to_issue
;
num_tokens_to_send
-=
num_tokens_to_issue
;
nvshmem
x_signal_op
(
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_tokens_to_issue
,
NVSHMEM_SIGNAL_ADD
,
nvshmem
i_ibgda_amo_nonfetch_add
(
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_tokens_to_issue
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
)
,
channel_id
,
dst_rdma_rank
==
rdma_rank
);
}
}
}
}
}
}
...
@@ -926,8 +937,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
...
@@ -926,8 +937,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Update remote head
// Update remote head
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>=
last_head
+
num_max_rdma_chunked_send_tokens
and
lane_id
<
kNumRDMARanks
)
{
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>=
last_head
+
num_max_rdma_chunked_send_tokens
and
lane_id
<
kNumRDMARanks
)
{
nvshmem
x_signal_op
(
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_head
,
NVSHMEM_SIGNAL_ADD
,
nvshmem
i_ibgda_amo_nonfetch_add
(
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_head
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
lane_id
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
lane_id
,
nvl_rank
)
,
channel_id
,
lane_id
==
rdma_rank
);
last_head
=
min_head
;
last_head
=
min_head
;
}
}
...
@@ -1558,20 +1569,21 @@ combine(int4* combined_x, float* combined_topk_weights,
...
@@ -1558,20 +1569,21 @@ combine(int4* combined_x, float* combined_topk_weights,
if
(
sub_warp_id
==
kNumWarpsPerForwarder
-
1
)
{
if
(
sub_warp_id
==
kNumWarpsPerForwarder
-
1
)
{
if
(
dst_rdma_rank
!=
rdma_rank
)
{
if
(
dst_rdma_rank
!=
rdma_rank
)
{
auto
rdma_slot_idx
=
token_start_idx
%
num_max_rdma_chunked_recv_tokens
;
auto
rdma_slot_idx
=
token_start_idx
%
num_max_rdma_chunked_recv_tokens
;
nvshmemx_int8_put_nbi_warp
(
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
const
size_t
num_bytes_per_msg
=
num_chunked_tokens
*
num_bytes_per_rdma_token
;
rdma_channel_data
.
send
_buffer
(
dst_
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
recv
_buffer
(
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
);
num_chunked_tokens
*
num_bytes_per_rdma_token
,
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
);
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
nvshmemi_ibgda_put_nbi_warp
<
true
>
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
nvshmem_fence
(
);
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
lane_id
,
0
);
}
else
{
}
else
{
memory_fence
();
memory_fence
();
}
}
// Write new RDMA tail
// Write new RDMA tail
__syncwarp
();
__syncwarp
();
if
(
lane_id
==
0
)
if
(
lane_id
==
0
)
{
nvshmemx_signal_op
(
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_chunked_tokens
,
NVSHMEM_SIGNAL_ADD
,
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_chunked_tokens
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
dst_rdma_rank
==
rdma_rank
);
}
}
}
}
}
...
@@ -1656,8 +1668,8 @@ combine(int4* combined_x, float* combined_topk_weights,
...
@@ -1656,8 +1668,8 @@ combine(int4* combined_x, float* combined_topk_weights,
for
(
int
i
=
0
;
i
<
kNumRDMAReceivers
;
++
i
)
if
(
not
rdma_receiver_retired
[
i
])
for
(
int
i
=
0
;
i
<
kNumRDMAReceivers
;
++
i
)
if
(
not
rdma_receiver_retired
[
i
])
min_head
=
min
(
min_head
,
rdma_receiver_rdma_head
[
i
][
dst_rdma_rank
]);
min_head
=
min
(
min_head
,
rdma_receiver_rdma_head
[
i
][
dst_rdma_rank
]);
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>=
last_rdma_head
+
num_max_rdma_chunked_send_tokens
and
lane_id
<
kNumRDMARanks
)
{
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>=
last_rdma_head
+
num_max_rdma_chunked_send_tokens
and
lane_id
<
kNumRDMARanks
)
{
nvshmem
x_signal_op
(
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_rdma_head
,
NVSHMEM_SIGNAL_ADD
,
nvshmem
i_ibgda_amo_nonfetch_add
(
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_rdma_head
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
)
,
channel_id
,
dst_rdma_rank
==
rdma_rank
);
last_rdma_head
=
min_head
;
last_rdma_head
=
min_head
;
}
}
}
else
{
}
else
{
...
...
csrc/kernels/internode_ll.cu
View file @
007fcfcf
...
@@ -167,7 +167,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -167,7 +167,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
EP_DEVICE_ASSERT
(
num_sms
>
1
);
EP_DEVICE_ASSERT
(
num_sms
>
1
);
if
(
sm_id
==
0
)
{
if
(
sm_id
==
0
)
{
// The first SM is also responsible for checking QPs
// The first SM is also responsible for checking QPs
EP_DEVICE_ASSERT
(
ibgda_get_state
()
->
num_rc_per_pe
=
=
num_local_experts
);
EP_DEVICE_ASSERT
(
ibgda_get_state
()
->
num_rc_per_pe
>
=
num_local_experts
);
// The first SM is also responsible for cleaning the next buffer
// The first SM is also responsible for cleaning the next buffer
#pragma unroll
#pragma unroll
...
...
csrc/kernels/runtime.cu
View file @
007fcfcf
...
@@ -58,14 +58,12 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
...
@@ -58,14 +58,12 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
EP_HOST_ASSERT
(
cpu_rdma_team
!=
NVSHMEM_TEAM_INVALID
);
EP_HOST_ASSERT
(
cpu_rdma_team
!=
NVSHMEM_TEAM_INVALID
);
}
}
// Normal operations use IBRC, while low-latency operations use IBGDA
// TODO: we still use `nvshmem_barrier` under IBRC mode, which should be switched to IBGDA mode later
if
(
low_latency_mode
)
{
nvshmemi_device_host_state_t
*
dev_state_ptr
=
nullptr
;
nvshmemi_device_host_state_t
*
dev_state_ptr
=
nullptr
;
CUDA_CHECK
(
cudaGetSymbolAddress
(
reinterpret_cast
<
void
**>
(
&
dev_state_ptr
),
nvshmemi_device_state_d
));
CUDA_CHECK
(
cudaGetSymbolAddress
(
reinterpret_cast
<
void
**>
(
&
dev_state_ptr
),
nvshmemi_device_state_d
));
bool
ibgda_is_initialized
=
false
;
bool
ibgda_is_initialized
=
false
;
CUDA_CHECK
(
cudaMemcpy
(
&
dev_state_ptr
->
ibgda_is_initialized
,
&
ibgda_is_initialized
,
sizeof
(
bool
),
cudaMemcpyHostToDevice
));
CUDA_CHECK
(
cudaMemcpy
(
&
dev_state_ptr
->
ibgda_is_initialized
,
&
ibgda_is_initialized
,
sizeof
(
bool
),
cudaMemcpyHostToDevice
));
}
nvshmem_barrier_all
();
nvshmem_barrier_all
();
return
nvshmem_my_pe
();
return
nvshmem_my_pe
();
}
}
...
...
deep_ep/buffer.py
View file @
007fcfcf
...
@@ -31,7 +31,7 @@ class Buffer:
...
@@ -31,7 +31,7 @@ class Buffer:
def
__init__
(
self
,
group
:
dist
.
ProcessGroup
,
def
__init__
(
self
,
group
:
dist
.
ProcessGroup
,
num_nvl_bytes
:
int
=
0
,
num_rdma_bytes
:
int
=
0
,
num_nvl_bytes
:
int
=
0
,
num_rdma_bytes
:
int
=
0
,
low_latency_mode
:
bool
=
False
,
num_qps_per_rank
:
int
=
1
)
->
None
:
low_latency_mode
:
bool
=
False
,
num_qps_per_rank
:
int
=
1
2
)
->
None
:
"""
"""
Initialize the communication buffer.
Initialize the communication buffer.
...
@@ -66,17 +66,16 @@ class Buffer:
...
@@ -66,17 +66,16 @@ class Buffer:
# Synchronize NVSHMEM unique IDs
# Synchronize NVSHMEM unique IDs
root_unique_id
=
None
root_unique_id
=
None
if
self
.
runtime
.
get_num_rdma_ranks
()
>
1
or
low_latency_mode
:
if
self
.
runtime
.
get_num_rdma_ranks
()
>
1
or
low_latency_mode
:
# Enable IBGDA for the low-latency mode, which refers to "no packet forwarding between NVLink and RDMA"
# Enable IBGDA
if
low_latency_mode
:
assert
num_qps_per_rank
>
0
assert
num_qps_per_rank
>
0
os
.
environ
[
'NVSHMEM_DISABLE_P2P'
]
=
'1'
os
.
environ
[
'NVSHMEM_DISABLE_P2P'
]
=
'1'
os
.
environ
[
'NVSHMEM_IB_ENABLE_IBGDA'
]
=
'1'
os
.
environ
[
'NVSHMEM_IB_ENABLE_IBGDA'
]
=
'1'
os
.
environ
[
'NVSHMEM_IBGDA_NIC_HANDLER'
]
=
'gpu'
os
.
environ
[
'NVSHMEM_IBGDA_NIC_HANDLER'
]
=
'gpu'
os
.
environ
[
'NVSHMEM_IBGDA_NUM_RC_PER_PE'
]
=
f
'
{
num_qps_per_rank
}
'
os
.
environ
[
'NVSHMEM_IBGDA_NUM_RC_PER_PE'
]
=
f
'
{
num_qps_per_rank
}
'
# Make sure QP depth is always larger than the number of in-flight WRs, so that we can skip WQ slot check
# Make sure QP depth is always larger than the number of in-flight WRs, so that we can skip WQ slot check
os
.
environ
[
'NVSHMEM_QP_DEPTH'
]
=
'1024'
os
.
environ
[
'NVSHMEM_QP_DEPTH'
]
=
'1024'
# NOTES: NVSHMEM initialization requires at least 256 MiB
# NOTES: NVSHMEM initialization requires at least 256 MiB
os
.
environ
[
'NVSHMEM_CUMEM_GRANULARITY'
]
=
f
'
{
2
**
29
}
'
os
.
environ
[
'NVSHMEM_CUMEM_GRANULARITY'
]
=
f
'
{
2
**
29
}
'
# Synchronize using the root ID
# Synchronize using the root ID
nvshmem_unique_ids
=
[
None
,
]
*
self
.
group_size
nvshmem_unique_ids
=
[
None
,
]
*
self
.
group_size
...
...
tests/test_internode.py
View file @
007fcfcf
...
@@ -219,16 +219,19 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
...
@@ -219,16 +219,19 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
):
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
):
num_nodes
=
int
(
os
.
getenv
(
'WORLD_SIZE'
,
1
))
num_nodes
=
int
(
os
.
getenv
(
'WORLD_SIZE'
,
1
))
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
test_ll_compatibility
=
Fals
e
test_ll_compatibility
=
Tru
e
if
test_ll_compatibility
:
if
test_ll_compatibility
:
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
=
16
,
5120
,
256
,
9
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
=
16
,
5120
,
256
,
9
num_sms
=
24
num_qps_per_rank
=
max
(
num_sms
//
2
,
ll_num_experts
//
num_ranks
if
test_ll_compatibility
else
0
)
buffer
=
deep_ep
.
Buffer
(
group
,
int
(
1e9
),
int
(
1e9
),
low_latency_mode
=
test_ll_compatibility
,
buffer
=
deep_ep
.
Buffer
(
group
,
int
(
1e9
),
int
(
1e9
),
low_latency_mode
=
test_ll_compatibility
,
num_qps_per_rank
=
(
ll_num_experts
//
num_ranks
if
test_ll_compatibility
else
1
)
)
num_qps_per_rank
=
num_qps_per_rank
)
assert
num_local_ranks
==
8
and
num_ranks
>
8
assert
num_local_ranks
==
8
and
num_ranks
>
8
torch
.
manual_seed
(
rank
)
torch
.
manual_seed
(
rank
)
for
i
in
(
24
,
):
for
i
in
(
num_sms
,
):
test_main
(
i
,
local_rank
,
num_local_ranks
,
num_ranks
,
num_nodes
,
rank
,
buffer
,
group
)
test_main
(
i
,
local_rank
,
num_local_ranks
,
num_ranks
,
num_nodes
,
rank
,
buffer
,
group
)
if
local_rank
==
0
:
if
local_rank
==
0
:
print
(
''
,
flush
=
True
)
print
(
''
,
flush
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment