Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
a1382ed7
Commit
a1382ed7
authored
Jan 23, 2026
by
lishen
Browse files
接入ROCSHMEM的multiqp优化
parent
314d9021
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
129 additions
and
173 deletions
+129
-173
build.sh
build.sh
+24
-11
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+84
-119
csrc/kernels/shmem_wrapper.cuh
csrc/kernels/shmem_wrapper.cuh
+19
-1
deep_ep/version.py
deep_ep/version.py
+0
-40
setup.py
setup.py
+2
-2
No files found.
build.sh
View file @
a1382ed7
...
...
@@ -31,7 +31,7 @@ PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()[
USE_NVSHMEM
=
OFF
USE_ROCSHMEM
=
OFF
ROCM_DISABLE_CTX
=
OFF
ROCM_
US
E_MULTIQP
=
OFF
ROCM_
DISABL
E_MULTIQP
=
OFF
# 解析命令行参数
for
arg
in
"
$@
"
;
do
case
$arg
in
...
...
@@ -44,20 +44,33 @@ for arg in "$@"; do
ROCM_DISABLE_CTX
=
ON
)
ROCM_DISABLE_CTX
=
ON
;;
ROCM_
US
E_MULTIQP
=
ON
)
ROCM_
US
E_MULTIQP
=
ON
ROCM_
DISABL
E_MULTIQP
=
ON
)
ROCM_
DISABL
E_MULTIQP
=
ON
;;
*
)
echo
"Usage: ./build.sh rocshmem [ROCM_DISABLE_CTX=ON] [ROCM_
US
E_MULTIQP=ON] / ./build.sh dushmem"
echo
"Usage: ./build.sh rocshmem [ROCM_DISABLE_CTX=ON] [ROCM_
DISABL
E_MULTIQP=ON] / ./build.sh dushmem"
exit
1
;;
esac
done
detect_offload_arch
()
{
# 尝试使用 rocm_agent_enumerator 获取所有 gfx 架构,并按字典序降序取第一个(即“最新”)
if
command
-v
rocm_agent_enumerator
>
/dev/null 2>&1
;
then
arch
=
$(
rocm_agent_enumerator 2>/dev/null |
grep
-E
'^gfx[0-9]+'
|
sort
-r
|
head
-n1
)
if
[
-n
"
$arch
"
]
;
then
echo
"
$arch
"
return
0
fi
fi
}
DETECTED_ARCH
=
$(
detect_offload_arch
)
echo
"Using --offload-arch=
$DETECTED_ARCH
"
echo
"USE_NVSHMEM=
$USE_NVSHMEM
"
echo
"USE_ROCSHMEM=
$USE_ROCSHMEM
"
echo
"ROCM_DISABLE_CTX=
$ROCM_DISABLE_CTX
"
echo
"ROCM_
US
E_MULTIQP=
$ROCM_
US
E_MULTIQP
"
echo
"ROCM_
DISABL
E_MULTIQP=
$ROCM_
DISABL
E_MULTIQP
"
# -------------------------- With rocSHMEM -------------------------- #
build_rocshmem
()
...
...
@@ -72,7 +85,7 @@ build_rocshmem()
return
1
}
echo
"cd third-party/rocshmem/build"
../scripts/build_configs/gda_mlx5
bash
../scripts/build_configs/gda_mlx5
echo
"编译rocshmem成功"
cd
"
$src_path
"
}
...
...
@@ -89,12 +102,12 @@ if [ "$USE_ROCSHMEM" == "ON" ]; then
build_rocshmem
SHMEM_INSTALL_PREFIX
=
$(
pwd
)
/third-party/rocshmem_install
COMPILE_OPTIONS
=
${
COMPILE_OPTIONS
:
= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=
gfx936 --offload-arch=gfx938
-std=c++17 -Wno-return-type
}
COMPILE_OPTIONS
=
${
COMPILE_OPTIONS
:
= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=
$DETECTED_ARCH
-std=c++17 -Wno-return-type
}
if
[
"
$ROCM_DISABLE_CTX
"
==
"ON"
]
;
then
COMPILE_OPTIONS
=
"-DROCM_DISABLE_CTX
$COMPILE_OPTIONS
"
fi
if
[
"
$ROCM_
US
E_MULTIQP
"
==
"ON"
]
;
then
COMPILE_OPTIONS
=
"-DROCM_
US
E_MULTIQP
$COMPILE_OPTIONS
"
if
[
"
$ROCM_
DISABL
E_MULTIQP
"
==
"ON"
]
;
then
COMPILE_OPTIONS
=
"-DROCM_
DISABL
E_MULTIQP
$COMPILE_OPTIONS
"
fi
SHMEM_LINK_OPTIONS
=
${
SHMEM_LINK_OPTIONS
:
=
"-Wl,-rpath,
${
SHMEM_INSTALL_PREFIX
}
/lib/ -l:librocshmem.a"
}
fi
...
...
@@ -133,7 +146,7 @@ if [ "$USE_NVSHMEM" == "ON" ]; then
# build_dushmem
# SHMEM_INSTALL_PREFIX=$(pwd)/third-party/dushmem_install
SHMEM_INSTALL_PREFIX
=
${
ROCM_PATH
}
/dushmem
COMPILE_OPTIONS
=
${
COMPILE_OPTIONS
:
= -fPIC -DFORCE_DUSHMEM_API -DHIP_ENABLE_WARP_SYNC_BUILTINS -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=
gfx936 --offload-arch=gfx938
-std=c++17 -Wno-return-type
}
COMPILE_OPTIONS
=
${
COMPILE_OPTIONS
:
= -fPIC -DFORCE_DUSHMEM_API -DHIP_ENABLE_WARP_SYNC_BUILTINS -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=
$DETECTED_ARCH
-std=c++17 -Wno-return-type
}
SHMEM_LINK_OPTIONS
=
"-Wl,-rpath,
${
SHMEM_INSTALL_PREFIX
}
/lib/ -l:libdushmem_device.a -ldushmem_host"
fi
# -------------------------- duSHMEM END -------------------------- #
...
...
@@ -147,7 +160,7 @@ hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/internode_ll.cu
-o
build_/internode_ll.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/deep_ep.cu
-o
build_/deep_ep.o
${
COMPILE_OPTIONS
}
hipcc
-Wno-unused-result
-Wsign-compare
-DNDEBUG
-g
-fwrapv
-O2
-Wall
-g
-fstack-protector-strong
-Wformat
-Werror
=
format-security
-g
-fwrapv
-O2
-shared
-Wl
,-O1
-Wl
,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o
-L
${
SHMEM_INSTALL_PREFIX
}
/lib/
-L
/opt/mpi/lib
-L
/opt/dtk/hip/lib
-L
/usr/lib/x86_64-linux-gnu
-lhipblaslt
-lamdhip64
-o
deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,/opt/dtk/lib
-fgpu-rdc
--hip-link
--offload-arch
=
gfx936
--offload-arch
=
gfx938
-shared
-Wl
,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-L
"
${
llvm_path
}
/include/../lib/linux"
-lclang_rt
.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so
${
llvm_path
}
/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so
-L
${
PYTHON_PLATLIB
}
/torch/lib
-L
/opt/dtk/lib
-L
/opt/dtk/hip/lib
-L
/usr/local/lib
-lc10
-ltorch
-ltorch_cpu
-ltorch_python
-lamdhip64
-lc10_hip
-ltorch_hip
-lrocm-core
-lrocm_smi64
${
SHMEM_LINK_OPTIONS
}
-fgpu-rdc
--hip-link
-lamdhip64
-lhsa-runtime64
-l
:libmpi.so
-Wl
,-rpath,/opt/mpi/lib/
-libverbs
-lmlx5
hipcc
-Wno-unused-result
-Wsign-compare
-DNDEBUG
-g
-fwrapv
-O2
-Wall
-g
-fstack-protector-strong
-Wformat
-Werror
=
format-security
-g
-fwrapv
-O2
-shared
-Wl
,-O1
-Wl
,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o
-L
${
SHMEM_INSTALL_PREFIX
}
/lib/
-L
/opt/mpi/lib
-L
/opt/dtk/hip/lib
-L
/usr/lib/x86_64-linux-gnu
-lhipblaslt
-lamdhip64
-o
deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,/opt/dtk/lib
-fgpu-rdc
--hip-link
--offload-arch
=
$DETECTED_ARCH
-shared
-Wl
,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-L
"
${
llvm_path
}
/include/../lib/linux"
-lclang_rt
.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so
${
llvm_path
}
/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so
-L
${
PYTHON_PLATLIB
}
/torch/lib
-L
/opt/dtk/lib
-L
/opt/dtk/hip/lib
-L
/usr/local/lib
-lc10
-ltorch
-ltorch_cpu
-ltorch_python
-lamdhip64
-lc10_hip
-ltorch_hip
-lrocm-core
-lrocm_smi64
${
SHMEM_LINK_OPTIONS
}
-fgpu-rdc
--hip-link
-lamdhip64
-lhsa-runtime64
-l
:libmpi.so
-Wl
,-rpath,/opt/mpi/lib/
-libverbs
-lmlx5
# build whl
echo
"Using Python:
$(
which python3
)
"
...
...
csrc/kernels/internode_ll.cu
View file @
a1382ed7
...
...
@@ -80,6 +80,42 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
clean_0
,
num_clean_int_0
,
clean_1
,
num_clean_int_1
);
}
__device__
__forceinline__
void
internode_ll_putmem_nbi
(
void
*
dst_ptr
,
void
*
src_ptr
,
int
num_ranks
,
int
dst_rank
,
int
expert_idx
,
int
msg_bytes
)
{
#if defined(FORCE_NVSHMEM_API)
internode
::
shmemx_int8_put_nbi_warp
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
msg_bytes
,
dst_rank
);
#else
#if defined(ROCM_DISABLE_MULTIQP)
internode
::
shmemx_int8_put_nbi_warp
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
msg_bytes
,
dst_rank
);
#else
internode
::
shmemx_int8_put_nbi_warp_dp
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
msg_bytes
,
(
expert_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_NVSHMEM_API)
}
__device__
__forceinline__
void
internode_ll_long_atomic_add
(
long
*
dest
,
const
long
&
value
,
int
num_ranks
,
int
dst_rank
,
int
expert_idx
)
{
#if defined(FORCE_DUSHMEM_API)
internode
::
shmem_long_atomic_add
(
dest
,
value
,
dst_rank
);
#else
#if defined(ROCM_DISABLE_MULTIQP)
internode
::
shmem_long_atomic_add
(
dest
,
value
,
dst_rank
);
#else
internode
::
shmem_long_atomic_add_dp
(
dest
,
value
,
(
expert_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_DUSHMEM_API)
}
template
<
bool
kUseFP8
,
bool
kUseUE8M0
,
bool
kUseInt8
,
int
kHidden
>
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
...
...
@@ -118,9 +154,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use
using
vec_t
=
typename
std
::
conditional
<
kUseFP8
,
int2
,
int4
>::
type
;
const
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUseFP8
?
(
kHidden
+
kNumScales
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
const
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
EP_DEVICE_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
)
;
const
expr
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUseFP8
?
(
kHidden
+
kNumScales
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
EP_STATIC_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
,
"Invalid message size"
)
;
constexpr
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
// Expert counts
constexpr
int
kNumMaxWarpGroups
=
1024
/
kWarpSize
;
...
...
@@ -135,7 +171,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// 2. The last warp for reading `topk_idx` and count for per-expert information
if
(
warp_id
<
num_warps
)
{
constexpr
int
kNumElemsPerRead
=
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
);
EP_DEVICE_ASSERT
(
kHidden
%
kNumElemsPerRead
==
0
);
//
EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
%
kNumPerChannels
==
0
,
"Invalid vectorization"
);
const
auto
num_threads
=
(
num_warps
-
1
)
*
kWarpSize
;
constexpr
int
hidden_bf16_int4
=
kHidden
/
kNumElemsPerRead
;
...
...
@@ -256,34 +292,17 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
dst_expert_local_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
slot_idx
*
num_bytes_per_msg
;
if
(
dst_rank
!=
rank
)
{
#if defined(FORCE_DUSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
dushmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
dushmemi_device_state_d
.
heap_base
));
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
UNROLLED_WARP_COPY
(
8
,
lane_id
,
num_int4_per_msg
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
}
else
{
internode
::
shmemx_int8_put_nbi_warp
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
dst_rank
);
}
#else
#if !defined(ROCM_USE_MULTIQP)
internode
::
shmemx_int8_put_nbi_warp
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
dst_rank
);
#else
internode
::
shmemx_int8_put_nbi_warp_dp
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
(
dst_expert_local_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_DUSHMEM_API)
}
else
{
// 通过 shmem_get_p2p_ptr 获取 当前远程指针能否可达
uint64_t
p2p_ptr
=
internode
::
shmem_get_p2p_ptr
((
void
*
)
dst_ptr
,
rank
,
dst_rank
);
if
(
p2p_ptr
==
0
)
{
// RDMA
internode_ll_putmem_nbi
((
void
*
)
dst_ptr
,
(
void
*
)
src_ptr
,
num_ranks
,
dst_rank
,
dst_expert_local_idx
,
num_bytes_per_msg
);
}
else
{
// 本地 GPU 和 同一计算节点的 其他 GPU 地址
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst
_ptr
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
p2p
_ptr
);
UNROLLED_WARP_COPY_LL
(
8
,
lane_id
,
num_int4_per_msg
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
}
...
...
@@ -294,7 +313,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
}
if
(
warp_id
==
num_warps
-
1
)
{
EP_DEVICE_ASSERT
(
num_sms
>
1
);
//
EP_DEVICE_ASSERT(num_sms > 1);
if
(
sm_id
==
0
)
{
// The first SM is also responsible for checking QPs
// The first SM is also responsible for cleaning the next buffer
...
...
@@ -341,29 +360,15 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Wait local sends issued and send expert counts
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
if
(
dst_rank
!=
rank
)
{
#if defined(FORCE_DUSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
dushmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
// P2P enabled
int
*
rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
)
-
(
char
*
)(
dushmemi_device_state_d
.
heap_base
)));
st_na_release
(
rptr_actual
,
-
num_tokens_sent
-
1
);
}
else
{
internode
::
shmem_long_atomic_add
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
);
}
#else
#if !defined(ROCM_USE_MULTIQP)
internode
::
shmem_long_atomic_add
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
);
#else
internode
::
shmem_long_atomic_add_dp
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
(
dst_expert_local_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_DUSHMEM_API)
}
else
{
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
),
-
num_tokens_sent
-
1
);
auto
dst_ptr
=
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
;
// 通过 shmem_get_p2p_ptr 获取 当前远程指针能否可达
uint64_t
p2p_ptr
=
internode
::
shmem_get_p2p_ptr
((
void
*
)
dst_ptr
,
rank
,
dst_rank
);
if
(
p2p_ptr
==
0
)
{
// RDMA
internode_ll_long_atomic_add
(
dst_ptr
,
-
num_tokens_sent
-
1
,
num_ranks
,
dst_rank
,
dst_expert_local_idx
);
}
else
{
// 本地 GPU 和 同一计算节点的 其他 GPU 地址
st_na_release
(
reinterpret_cast
<
int
*>
(
p2p_ptr
),
-
num_tokens_sent
-
1
);
}
// Clean workspace for next use
...
...
@@ -419,7 +424,7 @@ LOW_LATENCY_DISPATCH_RECV:
// Wait tokens to arrive
// NOTES: using sub-warp 1 to overlap with sub-warp 0
int
num_recv_tokens
,
recv_token_begin_idx
;
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
);
//
EP_DEVICE_ASSERT(num_warps_per_group > 1);
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
while
((
num_recv_tokens
=
ld_acquire_global
(
reinterpret_cast
<
int
*>
(
rdma_recv_count
+
local_expert_idx
*
num_ranks
+
src_rank
)))
==
0
);
...
...
@@ -430,12 +435,6 @@ LOW_LATENCY_DISPATCH_RECV:
recv_range
[
src_rank
]
=
pack2
<
int
,
int64_t
>
(
num_recv_tokens
,
recv_token_begin_idx
);
}
#if defined(ROCM_USE_MULTIQP)
if
(
sub_warp_id
==
2
and
lane_id
==
0
)
{
internode
::
shmem_qp_quiet
(
num_ranks
+
responsible_expert_idx
);
}
#endif
// no needs to reset because there is no iteration
if
(
lane_id
==
0
){
volatile
int
ret
=
__hip_atomic_fetch_add
(
&
sync_large_warp_counters
[
warp_group_id
],
1
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
...
...
@@ -447,7 +446,7 @@ LOW_LATENCY_DISPATCH_RECV:
recv_token_begin_idx
=
shared_recv_token_begin_idx
[
warp_group_id
];
// Copy tokens
EP_
DEV
IC
E
_ASSERT
(
kNumScales
<=
64
);
EP_
STAT
IC_ASSERT
(
kNumScales
<=
64
,
"Invalid hidden size"
);
for
(
int
i
=
sub_warp_id
;
i
<
num_recv_tokens
;
i
+=
num_warps_per_group
)
{
// Copy source info
const
auto
src_src_idx
=
reinterpret_cast
<
int
*>
(
rdma_recv_x_uint8
+
i
*
num_bytes_per_msg
);
...
...
@@ -632,41 +631,26 @@ combine(void* combined_x,
const
auto
src_idx
=
__ldg
(
local_src_info
+
token_idx
);
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
;
if
(
dst_rank
==
rank
)
{
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst_ptr
);
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
}
else
{
uint64_t
p2p_ptr
=
internode
::
shmem_get_p2p_ptr
((
void
*
)
dst_ptr
,
rank
,
dst_rank
);
if
(
p2p_ptr
==
0
)
{
// RDMA
const
auto
buf_int4_ptr
=
reinterpret_cast
<
int4
*>
(
buf_ptr
);
if
(
not
zero_copy
)
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
#if defined(FORCE_DUSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
dushmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
dushmemi_device_state_d
.
heap_base
));
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
}
else
{
internode
::
shmemx_int8_put_nbi_warp
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
hidden
*
sizeof
(
hip_bfloat16
),
dst_rank
);
}
#else
#if !defined(ROCM_USE_MULTIQP)
internode
::
shmemx_int8_put_nbi_warp
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
hidden
*
sizeof
(
hip_bfloat16
),
dst_rank
);
#else
internode
::
shmemx_int8_put_nbi_warp_dp
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
hidden
*
sizeof
(
hip_bfloat16
),
(
local_expert_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_DUSHMEM_API)
internode_ll_putmem_nbi
((
void
*
)
dst_ptr
,
(
void
*
)
buf_ptr
,
num_ranks
,
dst_rank
,
local_expert_idx
,
hidden
*
sizeof
(
hip_bfloat16
));
}
else
{
// 本地 GPU 和 同一计算节点的 其他 GPU 地址
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
x_int4
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
p2p_ptr
);
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
}
}
// Put finishing flag
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
);
//
EP_DEVICE_ASSERT(num_warps_per_group > 1);
if
(
lane_id
==
0
){
volatile
int
ret
=
__hip_atomic_fetch_add
(
&
sync_large_warp_counters
[
warp_group_id
],
1
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
}
...
...
@@ -675,30 +659,16 @@ combine(void* combined_x,
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
if
(
dst_rank
!=
rank
)
{
#if defined(FORCE_DUSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
dushmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
int
*
req_rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_flag
+
global_expert_idx
)
-
(
char
*
)(
dushmemi_device_state_d
.
heap_base
)));
st_na_release
(
req_rptr_actual
,
1
);
}
else
{
internode
::
shmem_long_atomic_add
(
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
}
#else
#if !defined(ROCM_USE_MULTIQP)
internode
::
shmem_long_atomic_add
(
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
#else
internode
::
shmem_long_atomic_add_dp
(
rdma_recv_flag
+
global_expert_idx
,
1
,
(
local_expert_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_DUSHMEM_API)
}
else
{
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
global_expert_idx
),
1
);
auto
dst_ptr
=
rdma_recv_flag
+
global_expert_idx
;
// 通过 shmem_get_p2p_ptr 获取 当前远程指针能否可达
uint64_t
p2p_ptr
=
internode
::
shmem_get_p2p_ptr
((
void
*
)
dst_ptr
,
rank
,
dst_rank
);
if
(
p2p_ptr
==
0
)
{
// RDMA
internode_ll_long_atomic_add
(
dst_ptr
,
1
,
num_ranks
,
dst_rank
,
local_expert_idx
);
}
else
{
// 本地 GPU 和 同一计算节点的 其他 GPU 地址
st_na_release
(
reinterpret_cast
<
int
*>
(
p2p_ptr
),
1
);
}
atomic_add_release_global
(
atomic_clean_flag
,
-
1
);
}
syncwarp
();
...
...
@@ -711,7 +681,7 @@ LOW_LATENCY_COMBINE_RECV:
// Wait all ranks to arrive and notify PCIe usage
if
(
responsible_expert_idx
<
num_experts
)
{
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
);
//
EP_DEVICE_ASSERT(num_warps_per_group > 1);
if
(
sub_warp_id
==
0
and
lane_id
==
0
)
{
const
auto
src_rank
=
responsible_expert_idx
/
num_local_experts
;
auto
start_time
=
wall_clock64
();
...
...
@@ -730,16 +700,11 @@ LOW_LATENCY_COMBINE_RECV:
atomicAdd
(
reinterpret_cast
<
unsigned
long
long
*>
(
combine_wait_recv_cost_stats
+
src_rank
),
wait_recv_cost
);
}
}
#if defined(ROCM_USE_MULTIQP)
if
(
sub_warp_id
==
2
and
lane_id
==
0
)
{
internode
::
shmem_qp_quiet
(
num_ranks
+
responsible_expert_idx
);
}
#endif
}
grid_barrier
(
global_atomic_counter
,
num_sms
);
// Reduce tokens with FP8 cast
EP_DEVICE_ASSERT
(
num_topk
<=
kWarpSize
and
hidden_bf16_int4
<=
num_threads
);
//
EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
EP_STATIC_ASSERT
(
kHidden
%
(
kWarpSize
*
kNumElemsPerInt4
)
==
0
,
"Invalid vectorization"
);
if
(
thread_id
<
hidden_bf16_int4
)
{
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_combined_tokens
;
token_idx
+=
num_sms
)
{
...
...
csrc/kernels/shmem_wrapper.cuh
View file @
a1382ed7
...
...
@@ -116,7 +116,11 @@ __device__ inline void shmem_long_atomic_add(
rocshmem
::
rocshmem_long_atomic_add
(
dest
,
value
,
pe
);
}
#if defined(ROCM_USE_MULTIQP)
__device__
inline
uint64_t
shmem_get_p2p_ptr
(
void
*
dest
,
int
rank
,
int
dst_rank
)
{
return
rocshmem
::
rocshmem_get_p2p_ptr
(
dest
,
rank
,
dst_rank
);
}
#if !defined(ROCM_DISABLE_MULTIQP)
__device__
inline
void
shmem_qp_quiet
(
int
idx_qp
)
{
rocshmem
::
rocshmem_quiet_dp
(
idx_qp
);
}
...
...
@@ -273,6 +277,20 @@ __device__ inline void shmem_long_atomic_add(
dushmem_long_atomic_add
(
dest
,
value
,
pe
);
}
__device__
__forceinline__
uint64_t
shmem_get_p2p_ptr
(
void
*
dest
,
int
rank
,
int
dst_rank
)
{
// Local rank, no need for mapping
if
(
rank
==
dst_rank
)
return
reinterpret_cast
<
uint64_t
>
(
dest
);
auto
peer_base
=
__ldg
(
reinterpret_cast
<
uint64_t
*>
(
dushmemi_device_state_d
.
peer_heap_base_p2p
)
+
dst_rank
);
// RDMA connected
if
(
peer_base
==
0
)
return
0
;
// NVLink P2P is enabled
return
peer_base
+
(
reinterpret_cast
<
uint64_t
>
(
dest
)
-
reinterpret_cast
<
uint64_t
>
(
dushmemi_device_state_d
.
heap_base
));
}
#endif
}
// namespace deep_ep::internode
...
...
deep_ep/version.py
deleted
100644 → 0
View file @
314d9021
try
:
__version__
=
"1.0.0"
__version_tuple__
=
(
1
,
0
,
0
)
__hcu_version__
=
f
'1.0.0+das.opt1.dtk2504'
from
.version
import
__version__
,
__version_tuple__
,
__hcu_version__
except
Exception
as
e
:
import
warnings
warnings
.
warn
(
f
"Failed to read commit hash:
\n
+ str(e)"
,
RuntimeWarning
,
stacklevel
=
2
)
__version__
=
"dev"
__version_tuple__
=
(
0
,
0
,
__version__
)
def
_prev_minor_version_was
(
version_str
):
'''Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
For example - return True if the current version is 0.7.4 and the
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
'''
# Match anything if this is a dev tree
if
__version_tuple__
[
0
:
2
]
==
(
0
,
0
):
return
True
# Note - this won't do the right thing when we release 1.0!
# assert __version_tuple__[0] == 0
assert
isinstance
(
__version_tuple__
[
1
],
int
)
return
version_str
==
f
"
{
__version_tuple__
[
0
]
}
.
{
__version_tuple__
[
1
]
-
1
}
"
def
_prev_minor_version
():
'''For the purpose of testing, return a previous minor version number.'''
# In dev tree, this will return "0.-1", but that will work fine"
assert
isinstance
(
__version_tuple__
[
1
],
int
)
return
f
"
{
__version_tuple__
[
0
]
}
.
{
__version_tuple__
[
1
]
-
1
}
"
setup.py
View file @
a1382ed7
...
...
@@ -35,10 +35,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if
sha
!=
'Unknown'
:
if
sha
is
None
:
sha
=
subprocess
.
check_output
([
'git'
,
'rev-parse'
,
'HEAD'
],
cwd
=
deepep_root
).
decode
(
'ascii'
).
strip
()
if
(
major
,
minor
)
>=
(
'2'
,
'
5
'
):
if
(
major
,
minor
)
>=
(
'2'
,
'
4
'
):
version
=
'das.opt1.'
+
sha
[:
7
]
+
shmem
else
:
if
(
major
,
minor
)
>=
(
'2'
,
'
5
'
):
if
(
major
,
minor
)
>=
(
'2'
,
'
4
'
):
version
=
'das.opt1'
if
os
.
getenv
(
"ROCM_PATH"
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment