Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
bc11ea32
Commit
bc11ea32
authored
Dec 26, 2025
by
lishen
Browse files
ROCSHMEM加入multiqp接口,编译时添加选项
parent
1b00b9d8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
92 additions
and
83 deletions
+92
-83
build.sh
build.sh
+27
-15
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+53
-68
csrc/kernels/shmem_wrapper.cuh
csrc/kernels/shmem_wrapper.cuh
+12
-0
No files found.
build.sh
View file @
bc11ea32
...
...
@@ -31,24 +31,33 @@ PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()[
USE_NVSHMEM
=
OFF
USE_ROCSHMEM
=
OFF
ROCM_DISABLE_CTX
=
OFF
case
"
$1
"
in
rocshmem
)
USE_ROCSHMEM
=
ON
;;
nvshmem|dushmem
)
USE_NVSHMEM
=
ON
;;
*
)
echo
"Usage: ./build.sh rocshmem [ROCM_DISABLE_CTX] / ./build.sh nvshmem"
exit
1
;;
esac
if
[
"
${
2
:-}
"
=
"ROCM_DISABLE_CTX"
]
;
then
ROCM_DISABLE_CTX
=
ON
fi
ROCM_USE_MULTIQP
=
OFF
# 解析命令行参数
for
arg
in
"
$@
"
;
do
case
$arg
in
rocshmem
)
USE_ROCSHMEM
=
ON
;;
nvshmem|dushmem
)
USE_NVSHMEM
=
ON
;;
ROCM_DISABLE_CTX
=
ON
)
ROCM_DISABLE_CTX
=
ON
;;
ROCM_USE_MULTIQP
=
ON
)
ROCM_USE_MULTIQP
=
ON
;;
*
)
echo
"Usage: ./build.sh rocshmem [ROCM_DISABLE_CTX=ON] [ROCM_USE_MULTIQP=ON] / ./build.sh nvshmem"
exit
1
;;
esac
done
echo
"USE_NVSHMEM=
$USE_NVSHMEM
"
echo
"USE_ROCSHMEM=
$USE_ROCSHMEM
"
echo
"ROCM_DISABLE_CTX=
$ROCM_DISABLE_CTX
"
echo
"ROCM_USE_MULTIQP=
$ROCM_USE_MULTIQP
"
# -------------------------- With rocSHMEM -------------------------- #
build_rocshmem
()
...
...
@@ -84,6 +93,9 @@ if [ "$USE_ROCSHMEM" == "ON" ]; then
if
[
"
$ROCM_DISABLE_CTX
"
==
"ON"
]
;
then
COMPILE_OPTIONS
=
"-DROCM_DISABLE_CTX
$COMPILE_OPTIONS
"
fi
if
[
"
$ROCM_USE_MULTIQP
"
==
"ON"
]
;
then
COMPILE_OPTIONS
=
"-DROCM_USE_MULTIQP
$COMPILE_OPTIONS
"
fi
SHMEM_LINK_OPTIONS
=
${
SHMEM_LINK_OPTIONS
:
=
"-Wl,-rpath,
${
SHMEM_INSTALL_PREFIX
}
/lib/ -l:librocshmem.a"
}
fi
# -------------------------- rocSHMEM END -------------------------- #
...
...
csrc/kernels/internode_ll.cu
View file @
bc11ea32
...
...
@@ -8,9 +8,6 @@
#include "hip/hip_runtime.h"
// low latency+RocSHMEM has issue with CTX.
#define ROCM_DISABLE_CTX
#include "shmem_wrapper.cuh"
namespace
deep_ep
{
...
...
@@ -133,11 +130,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
goto
LOW_LATENCY_DISPATCH_RECV
;
#if !defined(ROCM_DISABLE_CTX)
__shared__
internode
::
shmem_ctx_t
ctx
;
internode
::
shmem_wg_ctx_create
(
&
ctx
);
#endif
// There are 2 kinds of warps in this part:
// 1. The first-kind warps for FP8 cast and sending top-k tokens
// 2. The last warp for reading `topk_idx` and count for per-expert information
...
...
@@ -265,24 +257,29 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
slot_idx
*
num_bytes_per_msg
;
if
(
dst_rank
!=
rank
)
{
#if defined(FORCE_NVSHMEM_API)
#if defined(FORCE_NVSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
));
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
UNROLLED_WARP_COPY
(
8
,
lane_id
,
num_int4_per_msg
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
}
else
#endif
{
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_ctx_schar_put_nbi_warp
(
ctx
,
}
else
{
internode
::
shmemx_int8_put_nbi_warp
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
dst_rank
);
}
#else
#if !defined(ROCM_USE_MULTIQP)
internode
::
shmemx_int8_put_nbi_warp
(
#endif
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
dst_rank
);
}
#else
internode
::
shmemx_int8_put_nbi_warp_dp
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
(
dst_expert_local_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_NVSHMEM_API)
}
else
{
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
...
...
@@ -345,22 +342,26 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Wait local sends issued and send expert counts
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
if
(
dst_rank
!=
rank
)
{
#if defined(FORCE_NVSHMEM_API)
#if defined(FORCE_NVSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
// P2P enabled
int
*
rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
)
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
)));
st_na_release
(
rptr_actual
,
-
num_tokens_sent
-
1
);
}
else
#endif
{
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_ctx_long_atomic_add
(
ctx
,
}
else
{
internode
::
shmem_long_atomic_add
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
);
}
#else
#if !defined(ROCM_USE_MULTIQP)
internode
::
shmem_long_atomic_add
(
#endif
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
);
}
#else
internode
::
shmem_long_atomic_add_dp
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
(
dst_expert_local_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_NVSHMEM_API)
}
else
{
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
),
-
num_tokens_sent
-
1
);
}
...
...
@@ -375,10 +376,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
syncwarp
();
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_wg_ctx_destroy
(
&
ctx
);
#endif
// Receiving phase
LOW_LATENCY_DISPATCH_RECV:
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
...
...
@@ -591,11 +588,6 @@ combine(void* combined_x,
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
goto
LOW_LATENCY_COMBINE_RECV
;
#if !defined(ROCM_DISABLE_CTX)
__shared__
internode
::
shmem_ctx_t
ctx
;
internode
::
shmem_wg_ctx_create
(
&
ctx
);
#endif
// Clean up next buffer
if
(
sm_id
==
0
and
warp_group_id
==
0
and
sub_warp_id
==
0
)
{
#pragma unroll
...
...
@@ -642,23 +634,28 @@ combine(void* combined_x,
if
(
not
zero_copy
)
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
#if defined(FORCE_NVSHMEM_API)
#if defined(FORCE_NVSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
));
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
}
else
#endif
{
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_ctx_schar_put_nbi_warp
(
ctx
,
}
else
{
internode
::
shmemx_int8_put_nbi_warp
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
hidden
*
sizeof
(
hip_bfloat16
),
dst_rank
);
}
#else
#if !defined(ROCM_USE_MULTIQP)
internode
::
shmemx_int8_put_nbi_warp
(
#endif
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
hidden
*
sizeof
(
hip_bfloat16
),
dst_rank
);
}
#else
internode
::
shmemx_int8_put_nbi_warp_dp
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
hidden
*
sizeof
(
hip_bfloat16
),
(
local_expert_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_NVSHMEM_API)
}
}
...
...
@@ -673,22 +670,26 @@ combine(void* combined_x,
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
if
(
dst_rank
!=
rank
)
{
#if defined(FORCE_NVSHMEM_API)
#if defined(FORCE_NVSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
int
*
req_rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_flag
+
global_expert_idx
)
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
)));
st_na_release
(
req_rptr_actual
,
1
);
}
else
#endif
{
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_ctx_long_atomic_add
(
ctx
,
}
else
{
internode
::
shmem_long_atomic_add
(
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
}
#else
#if !defined(ROCM_USE_MULTIQP)
internode
::
shmem_long_atomic_add
(
#endif
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
}
#else
internode
::
shmem_long_atomic_add_dp
(
rdma_recv_flag
+
global_expert_idx
,
1
,
(
local_expert_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_NVSHMEM_API)
}
else
{
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
global_expert_idx
),
1
);
}
...
...
@@ -696,32 +697,16 @@ combine(void* combined_x,
}
syncwarp
();
if
(
num_ranks
>
8
){
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_ctx_quiet
(
ctx
);
#else
internode
::
shmem_fence
();
#endif
}
// if (num_ranks > 8){
// internode::shmem_fence();
// }
}
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_wg_ctx_destroy
(
&
ctx
);
#endif
// Receiving phase
LOW_LATENCY_COMBINE_RECV:
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
return
;
// if (num_ranks > 8){
// #if !defined(ROCM_DISABLE_CTX)
// internode::shmem_ctx_quiet(ctx);
// #else
// internode::shmem_fence();
// #endif
// }
// Wait all ranks to arrive and notify PCIe usage
if
(
responsible_expert_idx
<
num_experts
)
{
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
);
...
...
csrc/kernels/shmem_wrapper.cuh
View file @
bc11ea32
...
...
@@ -116,6 +116,18 @@ __device__ inline void shmem_long_atomic_add(
rocshmem
::
rocshmem_long_atomic_add
(
dest
,
value
,
pe
);
}
#if defined(ROCM_USE_MULTIQP)
__device__
inline
void
shmemx_int8_put_nbi_warp_dp
(
signed
char
*
dest
,
const
signed
char
*
source
,
size_t
nelems
,
int
qp_idx
,
int
pe
)
{
rocshmem
::
rocshmem_schar_put_nbi_wave_dp
(
dest
,
source
,
nelems
,
qp_idx
,
pe
);
}
__device__
inline
void
shmem_long_atomic_add_dp
(
long
*
dest
,
long
value
,
int
qp_idx
,
int
pe
)
{
rocshmem
::
rocshmem_long_atomic_add_dp
(
dest
,
value
,
qp_idx
,
pe
);
}
#endif
#if !defined(ROCM_DISABLE_CTX)
using
shmem_ctx_t
=
rocshmem
::
rocshmem_ctx_t
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment