Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
20b2aaaf
Commit
20b2aaaf
authored
Apr 22, 2025
by
Shangyan Zhou
Browse files
Refactor some code.
parent
c07fdd19
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
90 additions
and
61 deletions
+90
-61
csrc/kernels/ibgda_device.cuh
csrc/kernels/ibgda_device.cuh
+50
-10
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+33
-45
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+2
-2
tests/test_internode.py
tests/test_internode.py
+5
-4
No files found.
csrc/kernels/ibgda_device.cuh
View file @
20b2aaaf
...
...
@@ -410,20 +410,60 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
data_seg_ptr
),
*
reinterpret_cast
<
int4
*>
(
&
data_seg
));
}
__device__
__forceinline__
void
nvshmemi_ibgda_amo_nonfetch_add
(
void
*
rptr
,
const
int
&
value
,
int
pe
,
int
qp_id
)
{
nvshmemi_ibgda_device_qp_t
*
qp
=
ibgda_get_rc
(
pe
,
qp_id
);
__device__
__forceinline__
void
nvshmemi_ibgda_amo_nonfetch_add
(
void
*
rptr
,
const
int
&
value
,
int
pe
,
int
qp_id
,
bool
is_local_copy
=
false
)
{
if
(
is_local_copy
)
{
// Fallback to NVSHMEM legacy API
nvshmemx_signal_op
(
reinterpret_cast
<
uint64_t
*>
(
rptr
),
value
,
NVSHMEM_SIGNAL_ADD
,
pe
);
}
else
{
nvshmemi_ibgda_device_qp_t
*
qp
=
ibgda_get_rc
(
pe
,
qp_id
);
__be32
rkey
;
uint64_t
raddr
;
ibgda_get_rkey
(
reinterpret_cast
<
uint64_t
>
(
rptr
),
pe
,
&
raddr
,
&
rkey
);
__be32
rkey
;
uint64_t
raddr
;
ibgda_get_rkey
(
reinterpret_cast
<
uint64_t
>
(
rptr
),
pe
,
&
raddr
,
&
rkey
);
uint64_t
my_wqe_idx
=
ibgda_reserve_wqe_slots
(
qp
,
1
);
void
*
wqe_ptrs
=
ibgda_get_wqe_ptr
(
qp
,
my_wqe_idx
);
uint64_t
my_wqe_idx
=
ibgda_reserve_wqe_slots
(
qp
,
1
);
void
*
wqe_ptrs
=
ibgda_get_wqe_ptr
(
qp
,
my_wqe_idx
);
ibgda_write_amo_add_wqe
(
qp
,
value
,
reinterpret_cast
<
uint64_t
>
(
qp
->
ibuf
.
buf
),
qp
->
ibuf
.
lkey
,
raddr
,
rkey
,
my_wqe_idx
,
&
wqe_ptrs
);
ibgda_write_amo_add_wqe
(
qp
,
value
,
reinterpret_cast
<
uint64_t
>
(
qp
->
ibuf
.
buf
),
qp
->
ibuf
.
lkey
,
raddr
,
rkey
,
my_wqe_idx
,
&
wqe_ptrs
);
ibgda_submit_requests
<
true
>
(
qp
,
my_wqe_idx
,
1
);
}
}
ibgda_submit_requests
<
true
>
(
qp
,
my_wqe_idx
,
1
);
__device__
static
__forceinline__
void
nvshmemi_ibgda_put_nbi_thread
(
uint64_t
req_rptr
,
uint64_t
req_lptr
,
size_t
bytes
,
int
dst_pe
,
int
qp_id
,
bool
is_local_copy
)
{
if
(
is_local_copy
)
{
// Fallback to NVSHMEM legacy API
// TODO: rewrite local API copy with unrolling and vectorization
nvshmem_uint8_put_nbi
(
reinterpret_cast
<
uint8_t
*>
(
req_rptr
),
reinterpret_cast
<
uint8_t
*>
(
req_lptr
),
bytes
,
dst_pe
);
}
else
{
uint32_t
num_wqes
=
0
;
uint64_t
base_wqe_idx
=
0
;
auto
qp
=
ibgda_get_rc
(
dst_pe
,
qp_id
);
while
(
bytes
>
0
)
{
__be32
lkey
,
rkey
;
uint64_t
laddr
,
raddr
,
chunk_size
;
chunk_size
=
min
(
bytes
,
ibgda_get_lkey_and_rkey
(
laddr
=
req_lptr
,
&
lkey
,
req_rptr
,
dst_pe
,
&
raddr
,
&
rkey
));
bytes
-=
chunk_size
;
auto
wqe_idx
=
ibgda_reserve_wqe_slots
(
qp
,
1
);
auto
wqe_ptr
=
ibgda_get_wqe_ptr
(
qp
,
wqe_idx
);
// Only the last WQE should send imm
ibgda_write_rdma_write_wqe
(
qp
,
laddr
,
lkey
,
raddr
,
rkey
,
chunk_size
,
wqe_idx
,
&
wqe_ptr
);
req_lptr
+=
chunk_size
;
req_rptr
+=
chunk_size
;
if
((
num_wqes
++
)
==
0
)
base_wqe_idx
=
wqe_idx
;
}
ibgda_submit_requests
<
true
>
(
qp
,
base_wqe_idx
,
num_wqes
);
}
}
}
// namespace deep_ep
csrc/kernels/internode.cu
View file @
20b2aaaf
...
...
@@ -480,6 +480,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
const
bool
is_forwarder
=
sm_id
%
2
==
0
;
const
auto
rdma_rank
=
rank
/
NUM_MAX_NVL_PEERS
,
nvl_rank
=
rank
%
NUM_MAX_NVL_PEERS
;
EP_DEVICE_ASSERT
(
ibgda_get_state
()
->
num_rc_per_pe
>=
num_channels
);
const
auto
role_meta
=
[
=
]()
->
std
::
pair
<
WarpRole
,
int
>
{
if
(
is_forwarder
)
{
if
(
warp_id
<
NUM_MAX_NVL_PEERS
)
{
...
...
@@ -556,19 +558,27 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Send number of tokens in this channel by `-value - 1`
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
*
2
+
2
<=
32
,
"Invalid number of NVL peers"
);
for
(
int
dst_rdma_rank
=
warp_id
;
dst_rdma_rank
<
kNumRDMARanks
;
dst_rdma_rank
+=
kNumDispatchRDMASenderWarps
)
{
auto
dst_ptr
=
dst_rdma_rank
==
rdma_rank
?
rdma_channel_meta
.
recv_buffer
(
dst_rdma_rank
)
:
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
);
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)
[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
)
*
num_channels
+
channel_id
-
1
])
-
1
;
dst_ptr
[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
)
*
num_channels
+
channel_id
-
1
])
-
1
;
}
else
if
(
lane_id
<
NUM_MAX_NVL_PEERS
*
2
)
{
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)
[
lane_id
]
=
-
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
-
NUM_MAX_NVL_PEERS
)
*
num_channels
+
channel_id
]
-
1
;
dst_ptr
[
lane_id
]
=
-
gbl_channel_prefix_matrix
[(
dst_rdma_rank
*
NUM_MAX_NVL_PEERS
+
lane_id
-
NUM_MAX_NVL_PEERS
)
*
num_channels
+
channel_id
]
-
1
;
}
else
if
(
lane_id
==
NUM_MAX_NVL_PEERS
*
2
)
{
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)
[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
-
1
])
-
1
;
dst_ptr
[
lane_id
]
=
-
(
channel_id
==
0
?
0
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
-
1
])
-
1
;
}
else
if
(
lane_id
==
NUM_MAX_NVL_PEERS
*
2
+
1
)
{
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)[
lane_id
]
=
-
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
]
-
1
;
dst_ptr
[
lane_id
]
=
-
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
]
-
1
;
}
__syncwarp
();
// Issue RDMA for non-local ranks
if
(
dst_rdma_rank
!=
rdma_rank
and
lane_id
==
0
)
{
nvshmemi_ibgda_put_nbi_thread
(
reinterpret_cast
<
uint64_t
>
(
rdma_channel_meta
.
recv_buffer
(
rdma_rank
)),
reinterpret_cast
<
uint64_t
>
(
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)),
sizeof
(
int
)
*
(
NUM_MAX_NVL_PEERS
*
2
+
2
),
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
false
);
}
nvshmemx_int_put_nbi_warp
(
rdma_channel_meta
.
recv_buffer
(
rdma_rank
),
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
),
NUM_MAX_NVL_PEERS
*
2
+
2
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
nvshmem_fence
();
sync_rdma_sender_smem
();
// Iterate over tokens and copy into buffer
...
...
@@ -711,12 +721,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if
(
dst_rdma_rank
!=
rdma_rank
)
{
auto
dst_slot_idx
=
synced_last_issued_tail
%
num_max_rdma_chunked_recv_tokens
;
EP_DEVICE_ASSERT
(
dst_slot_idx
+
num_tokens_to_issue
<=
num_max_rdma_chunked_recv_tokens
);
const
size_t
num_bytes_per_msg
=
(
num_bytes_per_rdma_token
*
num_tokens_to_issue
)
*
sizeof
(
int8_t
)
;
const
size_t
num_bytes_per_msg
=
num_bytes_per_rdma_token
*
num_tokens_to_issue
;
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
);
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
);
nvshmemi_ibgda_put_nbi_warp
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
lane_id
,
3
);
nvshmem_fence
(
);
if
(
lane_id
==
dst_rdma_rank
)
nvshmemi_ibgda_put_nbi_thread
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
false
);
}
else
{
// Lighter fence for local RDMA rank
memory_fence
();
...
...
@@ -727,13 +737,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if
(
lane_id
==
dst_rdma_rank
)
{
last_issued_tail
+=
num_tokens_to_issue
;
num_tokens_to_send
-=
num_tokens_to_issue
;
if
(
dst_rdma_rank
!=
rdma_rank
)
{
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_tokens_to_issue
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
);
}
else
{
nvshmemx_signal_op
(
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_tokens_to_issue
,
NVSHMEM_SIGNAL_ADD
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_tokens_to_issue
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
dst_rdma_rank
==
rdma_rank
);
}
}
}
...
...
@@ -933,13 +938,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Update remote head
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>=
last_head
+
num_max_rdma_chunked_send_tokens
and
lane_id
<
kNumRDMARanks
)
{
if
(
lane_id
!=
rdma_rank
)
{
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_head
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
lane_id
,
nvl_rank
),
channel_id
);
}
else
{
nvshmemx_signal_op
(
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_head
,
NVSHMEM_SIGNAL_ADD
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
lane_id
,
nvl_rank
));
}
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_head
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
lane_id
,
nvl_rank
),
channel_id
,
lane_id
==
rdma_rank
);
last_head
=
min_head
;
}
...
...
@@ -1570,12 +1570,12 @@ combine(int4* combined_x, float* combined_topk_weights,
if
(
sub_warp_id
==
kNumWarpsPerForwarder
-
1
)
{
if
(
dst_rdma_rank
!=
rdma_rank
)
{
auto
rdma_slot_idx
=
token_start_idx
%
num_max_rdma_chunked_recv_tokens
;
const
size_t
num_bytes_per_msg
=
(
num_chunked_tokens
*
num_bytes_per_rdma_token
)
*
sizeof
(
int8_t
)
;
const
size_t
num_bytes_per_msg
=
num_chunked_tokens
*
num_bytes_per_rdma_token
;
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
);
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
);
nvshmemi_ibgda_put_nbi_warp
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
lane_id
,
3
);
nvshmem_fence
(
);
if
(
lane_id
==
0
)
nvshmemi_ibgda_put_nbi_thread
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
false
);
}
else
{
memory_fence
();
}
...
...
@@ -1583,13 +1583,8 @@ combine(int4* combined_x, float* combined_topk_weights,
// Write new RDMA tail
__syncwarp
();
if
(
lane_id
==
0
)
{
if
(
dst_rdma_rank
!=
rdma_rank
)
{
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_chunked_tokens
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
);
}
else
{
nvshmemx_signal_op
(
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_chunked_tokens
,
NVSHMEM_SIGNAL_ADD
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_chunked_tokens
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
dst_rdma_rank
==
rdma_rank
);
}
}
}
...
...
@@ -1675,15 +1670,8 @@ combine(int4* combined_x, float* combined_topk_weights,
for
(
int
i
=
0
;
i
<
kNumRDMAReceivers
;
++
i
)
if
(
not
rdma_receiver_retired
[
i
])
min_head
=
min
(
min_head
,
rdma_receiver_rdma_head
[
i
][
dst_rdma_rank
]);
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>=
last_rdma_head
+
num_max_rdma_chunked_send_tokens
and
lane_id
<
kNumRDMARanks
)
{
// nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
// translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
if
(
dst_rdma_rank
!=
rdma_rank
)
{
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_rdma_head
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
);
}
else
{
nvshmemx_signal_op
(
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_rdma_head
,
NVSHMEM_SIGNAL_ADD
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_rdma_head
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
dst_rdma_rank
==
rdma_rank
);
last_rdma_head
=
min_head
;
}
}
else
{
...
...
csrc/kernels/internode_ll.cu
View file @
20b2aaaf
...
...
@@ -167,7 +167,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
EP_DEVICE_ASSERT
(
num_sms
>
1
);
if
(
sm_id
==
0
)
{
// The first SM is also responsible for checking QPs
EP_DEVICE_ASSERT
(
ibgda_get_state
()
->
num_rc_per_pe
=
=
num_local_experts
);
EP_DEVICE_ASSERT
(
ibgda_get_state
()
->
num_rc_per_pe
>
=
num_local_experts
);
// The first SM is also responsible for cleaning the next buffer
#pragma unroll
...
...
@@ -215,7 +215,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Wait local sends issued and send expert counts
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
if
(
dst_rank
!=
rank
)
{
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
,
dst_expert_local_idx
);
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
,
dst_expert_local_idx
,
false
);
}
else
{
st_na_release
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
);
}
...
...
tests/test_internode.py
View file @
20b2aaaf
...
...
@@ -218,15 +218,16 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
# noinspection PyUnboundLocalVariable
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
):
num_nodes
=
int
(
os
.
getenv
(
'WORLD_SIZE'
,
1
))
num_sms
=
24
qp_num
=
num_sms
//
2
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
test_ll_compatibility
=
Fals
e
test_ll_compatibility
=
Tru
e
if
test_ll_compatibility
:
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
=
16
,
5120
,
256
,
9
num_sms
=
24
num_qps_per_rank
=
max
(
num_sms
//
2
,
ll_num_experts
//
num_ranks
if
test_ll_compatibility
else
0
)
buffer
=
deep_ep
.
Buffer
(
group
,
int
(
1e9
),
int
(
1e9
),
low_latency_mode
=
test_ll_compatibility
,
num_qps_per_rank
=
(
ll_num_experts
//
num_ranks
if
test_ll_compatibility
else
qp_num
)
)
num_qps_per_rank
=
num_qps_per_rank
)
assert
num_local_ranks
==
8
and
num_ranks
>
8
torch
.
manual_seed
(
rank
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment