Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
e255d57b
Commit
e255d57b
authored
Apr 22, 2025
by
Shangyan Zhou
Browse files
Use `put_nbi_warp`.
parent
3b1045db
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
52 deletions
+13
-52
csrc/kernels/ibgda_device.cuh
csrc/kernels/ibgda_device.cuh
+2
-36
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+10
-15
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+1
-1
No files found.
csrc/kernels/ibgda_device.cuh
View file @
e255d57b
...
@@ -325,6 +325,7 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
...
@@ -325,6 +325,7 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
data_seg_ptr
),
*
reinterpret_cast
<
const
int4
*>
(
&
data_seg
));
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
data_seg_ptr
),
*
reinterpret_cast
<
const
int4
*>
(
&
data_seg
));
}
}
template
<
bool
kAlwaysDoPostSend
=
false
>
__device__
static
__forceinline__
void
__device__
static
__forceinline__
void
nvshmemi_ibgda_put_nbi_warp
(
uint64_t
req_rptr
,
uint64_t
req_lptr
,
size_t
bytes
,
int
dst_pe
,
int
qp_id
,
int
lane_id
,
int
message_idx
)
{
nvshmemi_ibgda_put_nbi_warp
(
uint64_t
req_rptr
,
uint64_t
req_lptr
,
size_t
bytes
,
int
dst_pe
,
int
qp_id
,
int
lane_id
,
int
message_idx
)
{
// Get lkey and rkey, store them into lanes
// Get lkey and rkey, store them into lanes
...
@@ -365,7 +366,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
...
@@ -365,7 +366,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
// Submit
// Submit
if
(
lane_id
==
0
)
if
(
lane_id
==
0
)
ibgda_submit_requests
<
false
>
(
qp
,
base_wqe_idx
,
num_wqes
,
message_idx
);
ibgda_submit_requests
<
kAlwaysDoPostSend
>
(
qp
,
base_wqe_idx
,
num_wqes
,
message_idx
);
__syncwarp
();
__syncwarp
();
}
}
...
@@ -431,39 +432,4 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons
...
@@ -431,39 +432,4 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons
}
}
}
}
/**
 * Non-blocking put issued by the calling thread: transfers `bytes` bytes from
 * `req_lptr` to `req_rptr` on `dst_pe`.
 *
 * When `is_local_copy` is set, the whole transfer is delegated to
 * `nvshmem_uint8_put_nbi`. Otherwise the range is consumed chunk by chunk:
 * every iteration asks `ibgda_get_lkey_and_rkey` for the keys and an upper
 * bound on the chunk size, reserves one WQE slot on the QP returned by
 * `ibgda_get_rc(dst_pe, qp_id)`, and fills it via `ibgda_write_rdma_write_wqe`.
 * All WQEs written this way are handed to `ibgda_submit_requests<true>` in a
 * single call after the loop (also when `bytes` is 0 and nothing was written).
 */
__device__ static __forceinline__ void
nvshmemi_ibgda_put_nbi_thread(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
                              int dst_pe, int qp_id, bool is_local_copy) {
    if (is_local_copy) {
        // Fallback to NVSHMEM legacy API
        // TODO: rewrite local API copy with unrolling and vectorization
        nvshmem_uint8_put_nbi(reinterpret_cast<uint8_t*>(req_rptr),
                              reinterpret_cast<uint8_t*>(req_lptr),
                              bytes, dst_pe);
        return;
    }

    auto qp = ibgda_get_rc(dst_pe, qp_id);
    uint64_t first_wqe_idx = 0;  // index of the first WQE reserved below
    uint32_t wqe_count = 0;      // number of WQEs written so far

    while (bytes > 0) {
        __be32 lkey, rkey;
        uint64_t laddr, raddr, chunk_size;

        // The chunk is bounded by both the remaining bytes and the value
        // reported by the key lookup.
        laddr = req_lptr;
        chunk_size = min(bytes, ibgda_get_lkey_and_rkey(laddr, &lkey, req_rptr, dst_pe, &raddr, &rkey));
        bytes -= chunk_size;

        // One WQE per chunk.
        auto wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
        auto wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx);

        // NOTE(review): the original comment here read "Only the last WQE
        // should send imm"; no immediate value is visible in this call - confirm.
        ibgda_write_rdma_write_wqe(qp, laddr, lkey, raddr, rkey, chunk_size, wqe_idx, &wqe_ptr);

        // Advance both cursors past the chunk just described.
        req_lptr += chunk_size;
        req_rptr += chunk_size;

        // Remember where the batch starts so it can be submitted as a whole.
        if (wqe_count == 0)
            first_wqe_idx = wqe_idx;
        ++wqe_count;
    }

    ibgda_submit_requests<true>(qp, first_wqe_idx, wqe_count);
}
}
// namespace deep_ep
}
// namespace deep_ep
csrc/kernels/internode.cu
View file @
e255d57b
...
@@ -571,12 +571,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
...
@@ -571,12 +571,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
__syncwarp
();
__syncwarp
();
// Issue RDMA for non-local ranks
// Issue RDMA for non-local ranks
if
(
dst_rdma_rank
!=
rdma_rank
and
lane_id
==
0
)
{
if
(
dst_rdma_rank
!=
rdma_rank
)
{
nvshmemi_ibgda_put_nbi_
thread
(
reinterpret_cast
<
uint64_t
>
(
rdma_channel_meta
.
recv_buffer
(
rdma_rank
)),
nvshmemi_ibgda_put_nbi_
warp
<
true
>
(
reinterpret_cast
<
uint64_t
>
(
rdma_channel_meta
.
recv_buffer
(
rdma_rank
)),
reinterpret_cast
<
uint64_t
>
(
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)),
reinterpret_cast
<
uint64_t
>
(
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
)),
sizeof
(
int
)
*
(
NUM_MAX_NVL_PEERS
*
2
+
2
),
sizeof
(
int
)
*
(
NUM_MAX_NVL_PEERS
*
2
+
2
),
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
false
);
channel_id
,
lane_id
,
0
);
}
}
}
}
sync_rdma_sender_smem
();
sync_rdma_sender_smem
();
...
@@ -724,10 +724,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
...
@@ -724,10 +724,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
const
size_t
num_bytes_per_msg
=
num_bytes_per_rdma_token
*
num_tokens_to_issue
;
const
size_t
num_bytes_per_msg
=
num_bytes_per_rdma_token
*
num_tokens_to_issue
;
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
);
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
);
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
);
if
(
lane_id
==
dst_rdma_rank
)
{
nvshmemi_ibgda_put_nbi_warp
<
true
>
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
nvshmemi_ibgda_put_nbi_thread
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
lane_id
,
0
);
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
false
);
}
}
else
{
}
else
{
// Lighter fence for local RDMA rank
// Lighter fence for local RDMA rank
memory_fence
();
memory_fence
();
...
@@ -1574,11 +1572,8 @@ combine(int4* combined_x, float* combined_topk_weights,
...
@@ -1574,11 +1572,8 @@ combine(int4* combined_x, float* combined_topk_weights,
const
size_t
num_bytes_per_msg
=
num_chunked_tokens
*
num_bytes_per_rdma_token
;
const
size_t
num_bytes_per_msg
=
num_chunked_tokens
*
num_bytes_per_rdma_token
;
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
);
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
);
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
);
if
(
lane_id
==
0
)
{
nvshmemi_ibgda_put_nbi_warp
<
true
>
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
// TODO: use the full warp to do this
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
lane_id
,
0
);
nvshmemi_ibgda_put_nbi_thread
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
false
);
}
}
else
{
}
else
{
memory_fence
();
memory_fence
();
}
}
...
...
csrc/kernels/internode_ll.cu
View file @
e255d57b
...
@@ -215,7 +215,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -215,7 +215,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Wait local sends issued and send expert counts
// Wait local sends issued and send expert counts
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
,
dst_expert_local_idx
,
false
);
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
,
dst_expert_local_idx
);
}
else
{
}
else
{
st_na_release
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
);
st_na_release
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment