Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
ee3551ab
Commit
ee3551ab
authored
Nov 07, 2025
by
lishen
Browse files
修改为兼容rocSHMEM和nvSHMEM的代码
parent
e18f726a
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
413 additions
and
100 deletions
+413
-100
build.sh
build.sh
+13
-2
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+86
-47
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+19
-20
csrc/kernels/runtime.cu
csrc/kernels/runtime.cu
+34
-30
csrc/kernels/shmem_wrapper.cuh
csrc/kernels/shmem_wrapper.cuh
+260
-0
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+1
-1
No files found.
build.sh
View file @
ee3551ab
...
@@ -8,10 +8,21 @@ fi
...
@@ -8,10 +8,21 @@ fi
PYTHON_INCLUDE
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['include'])"
)
PYTHON_INCLUDE
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['include'])"
)
PYTHON_PLATLIB
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['platlib'])"
)
PYTHON_PLATLIB
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['platlib'])"
)
# --------------------------------------------------------------------- #
USE_NVSHMEM
=
${
USE_NVSHMEM
:
=OFF
}
ROCSHMEM_INSTALL_PREFIX
=
${
ROCSHMEM_INSTALL_PREFIX
:
=
$(
pwd
)
/rocshmem_dir
}
ROCSHMEM_INSTALL_PREFIX
=
${
ROCSHMEM_INSTALL_PREFIX
:
=
$(
pwd
)
/rocshmem_dir
}
COMPILE_OPTIONS
=
${
COMPILE_OPTIONS
:
= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type
}
SHMEM_LINK_OPTIONS
=
${
SHMEM_LINK_OPTIONS
:
=
"-Wl,-rpath,
${
ROCSHMEM_INSTALL_PREFIX
}
/lib/ -l:librocshmem.a"
}
####
# 检查是否设置了USE_NVSHMEM环境变量
if
[
"
$USE_NVSHMEM
"
==
"ON"
]
;
then
COMPILE_OPTIONS+
=
" -DFORCE_NVSHMEM_API"
ROCSHMEM_INSTALL_PREFIX
=
???/dushmem_dir
SHMEM_LINK_OPTIONS
=
"-Wl,-rpath,
${
ROCSHMEM_INSTALL_PREFIX
}
/lib/ -l:libnvshmem_device.a -lnvshmem_host"
fi
INCLUDE_PATHS
=
${
INCLUDE_PATHS
:
=-Icsrc/ -I
${
ROCSHMEM_INSTALL_PREFIX
}
/include/ -I/opt/mpi/include -I
${
PYTHON_PLATLIB
}
/torch/include -I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include -I
${
PYTHON_PLATLIB
}
/torch/include/TH -I
${
PYTHON_PLATLIB
}
/torch/include/THC -I
${
PYTHON_PLATLIB
}
/torch/include/THH -I/opt/dtk/include -I
${
PYTHON_INCLUDE
}}
INCLUDE_PATHS
=
${
INCLUDE_PATHS
:
=-Icsrc/ -I
${
ROCSHMEM_INSTALL_PREFIX
}
/include/ -I/opt/mpi/include -I
${
PYTHON_PLATLIB
}
/torch/include -I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include -I
${
PYTHON_PLATLIB
}
/torch/include/TH -I
${
PYTHON_PLATLIB
}
/torch/include/THC -I
${
PYTHON_PLATLIB
}
/torch/include/THH -I/opt/dtk/include -I
${
PYTHON_INCLUDE
}}
COMPILE_OPTIONS
=
${
COMPILE_OPTIONS
:
= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/runtime.cu
-o
build_/runtime.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/runtime.cu
-o
build_/runtime.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/layout.cu
-o
build_/layout.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/layout.cu
-o
build_/layout.o
${
COMPILE_OPTIONS
}
...
@@ -20,7 +31,7 @@ hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o
...
@@ -20,7 +31,7 @@ hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/internode_ll.cu
-o
build_/internode_ll.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/internode_ll.cu
-o
build_/internode_ll.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/deep_ep.cu
-o
build_/deep_ep.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/deep_ep.cu
-o
build_/deep_ep.o
${
COMPILE_OPTIONS
}
hipcc
-Wno-unused-result
-Wsign-compare
-DNDEBUG
-g
-fwrapv
-O2
-Wall
-g
-fstack-protector-strong
-Wformat
-Werror
=
format-security
-g
-fwrapv
-O2
-shared
-Wl
,-O1
-Wl
,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o
-L
${
ROCSHMEM_INSTALL_PREFIX
}
/lib/
-L
/opt/mpi/lib
-L
/opt/dtk/hip/lib
-L
/usr/lib/x86_64-linux-gnu
-lhipblaslt
-lamdhip64
-o
deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,/opt/dtk/lib
-fgpu-rdc
--hip-link
--offload-arch
=
gfx936
-shared
-Wl
,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,
$$
{
ROCSHMEM_INSTALL_PREFIX
}
/lib/
-L
"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux"
-lclang_rt
.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0
-L
${
PYTHON_PLATLIB
}
/torch/lib
-L
/opt/dtk/lib
-L
/opt/dtk/hip/lib
-L
/usr/local/lib
-lc10
-ltorch
-ltorch_cpu
-ltorch_python
-lamdhip64
-lc10_hip
-ltorch_hip
-lrocm-core
-lrocm_smi64
-l
:librocshmem.a
-fgpu-rdc
--hip-link
-lamdhip64
-lhsa-runtime64
-l
:libmpi.so
-Wl
,-rpath,/opt/mpi/lib/
-libverbs
-lmlx5
hipcc
-Wno-unused-result
-Wsign-compare
-DNDEBUG
-g
-fwrapv
-O2
-Wall
-g
-fstack-protector-strong
-Wformat
-Werror
=
format-security
-g
-fwrapv
-O2
-shared
-Wl
,-O1
-Wl
,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o
-L
${
ROCSHMEM_INSTALL_PREFIX
}
/lib/
-L
/opt/mpi/lib
-L
/opt/dtk/hip/lib
-L
/usr/lib/x86_64-linux-gnu
-lhipblaslt
-lamdhip64
-o
deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,/opt/dtk/lib
-fgpu-rdc
--hip-link
--offload-arch
=
gfx936
-shared
-Wl
,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-L
"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux"
-lclang_rt
.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0
-L
${
PYTHON_PLATLIB
}
/torch/lib
-L
/opt/dtk/lib
-L
/opt/dtk/hip/lib
-L
/usr/local/lib
-lc10
-ltorch
-ltorch_cpu
-ltorch_python
-lamdhip64
-lc10_hip
-ltorch_hip
-lrocm-core
-lrocm_smi64
${
SHMEM_LINK_OPTIONS
}
-fgpu-rdc
--hip-link
-lamdhip64
-lhsa-runtime64
-l
:libmpi.so
-Wl
,-rpath,/opt/mpi/lib/
-libverbs
-lmlx5
# build whl
# build whl
echo
"Using Python:
$(
which python3
)
"
echo
"Using Python:
$(
which python3
)
"
...
...
csrc/kernels/internode.cu
View file @
ee3551ab
...
@@ -3,11 +3,10 @@
...
@@ -3,11 +3,10 @@
#include "configs.cuh"
#include "configs.cuh"
#include "launch.cuh"
#include "launch.cuh"
#include "utils.cuh"
#include "utils.cuh"
#include "shmem_wrapper.cuh"
#ifndef DISABLE_ROCSHMEM
#ifndef DISABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
// TODO: fix unroll warnings
// TODO: fix unroll warnings
// #ifdef __clang__
// #ifdef __clang__
// #pragma clang diagnostic push
// #pragma clang diagnostic push
...
@@ -19,7 +18,7 @@ namespace deep_ep {
...
@@ -19,7 +18,7 @@ namespace deep_ep {
namespace
internode
{
namespace
internode
{
extern
rocshmem
::
roc
shmem_team_t
cpu_rdma_team
;
extern
shmem_team_t
cpu_rdma_team
;
struct
SourceMeta
{
struct
SourceMeta
{
int
src_rdma_rank
,
is_token_in_nvl_rank_bits
;
int
src_rdma_rank
,
is_token_in_nvl_rank_bits
;
...
@@ -51,9 +50,8 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_
...
@@ -51,9 +50,8 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_
int
num_topk_idx
,
int
num_topk_idx
,
int
num_topk_weights
)
{
int
num_topk_weights
)
{
return
static_cast
<
int
>
(
ALIGN
(
hidden_int4
*
sizeof
(
int4
)
+
sizeof
(
SourceMeta
)
+
return
static_cast
<
int
>
(
ALIGN
(
hidden_int4
*
sizeof
(
int4
)
+
sizeof
(
SourceMeta
)
+
num_scales
*
sizeof
(
float
)
+
num_topk_idx
*
sizeof
(
int
)
+
num_scales
*
sizeof
(
float
)
+
num_topk_idx
*
sizeof
(
int
)
+
num_topk_weights
*
sizeof
(
float
),
num_topk_weights
*
sizeof
(
float
),
sizeof
(
int4
)));
sizeof
(
int4
)));
}
}
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
...
@@ -61,9 +59,8 @@ get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_t
...
@@ -61,9 +59,8 @@ get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_t
int
num_rdma_ranks
,
int
num_rdma_recv_buffer_tokens
,
int
num_sms
)
{
int
num_rdma_ranks
,
int
num_rdma_recv_buffer_tokens
,
int
num_sms
)
{
// Return `int32_t` offset and count to clean
// Return `int32_t` offset and count to clean
return
{(
get_num_bytes_per_rdma_token
(
hidden_int4
,
num_scales
,
num_topk_idx
,
num_topk_weights
)
*
return
{(
get_num_bytes_per_rdma_token
(
hidden_int4
,
num_scales
,
num_topk_idx
,
num_topk_weights
)
*
num_rdma_recv_buffer_tokens
*
num_rdma_ranks
*
2
*
num_sms
)
/
num_rdma_recv_buffer_tokens
*
num_rdma_ranks
*
2
*
num_sms
)
/
sizeof
(
int
),
sizeof
(
int
),
(
NUM_MAX_NVL_PEERS
*
2
+
4
)
*
num_rdma_ranks
*
2
*
num_sms
};
(
NUM_MAX_NVL_PEERS
*
2
+
4
)
*
num_rdma_ranks
*
2
*
num_sms
};
}
}
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
get_nvl_clean_meta
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
get_nvl_clean_meta
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
...
@@ -74,10 +71,9 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to
...
@@ -74,10 +71,9 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to
"Invalid size of `SourceMeta`"
);
"Invalid size of `SourceMeta`"
);
return
{
return
{
(
num_nvl_recv_buffer_tokens
*
(
num_nvl_recv_buffer_tokens
*
(
hidden_int4
*
sizeof
(
int4
)
+
num_scales
*
sizeof
(
float
)
+
num_topk_idx
*
sizeof
(
int
)
+
(
hidden_int4
*
sizeof
(
int4
)
+
num_scales
*
sizeof
(
float
)
+
num_topk_idx
*
sizeof
(
int
)
+
num_topk_weights
*
sizeof
(
float
)
+
sizeof
(
SourceMeta
))
*
num_topk_weights
*
sizeof
(
float
)
+
sizeof
(
SourceMeta
))
*
num_nvl_ranks
*
num_sms
)
/
num_nvl_ranks
*
num_sms
)
/
sizeof
(
int
),
sizeof
(
int
),
num_nvl_ranks
*
(
2
*
num_rdma_ranks
+
2
)
*
num_sms
,
num_nvl_ranks
*
(
2
*
num_rdma_ranks
+
2
)
*
num_sms
,
};
};
}
}
...
@@ -90,12 +86,10 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
...
@@ -90,12 +86,10 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
template
<
bool
kLowLatencyMode
>
template
<
bool
kLowLatencyMode
>
__forceinline__
__device__
void
__forceinline__
__device__
void
nvshmem_barrier_with_same_gpu_idx
(
const
rocshmem
::
roc
shmem_team_t
&
rdma_team
)
{
nvshmem_barrier_with_same_gpu_idx
(
const
shmem_team_t
&
rdma_team
)
{
// NOTE: shmem_device_barrier_all() might be an issue as
// NOTE: shmem_device_barrier_all() might be an issue as
// it doesn't follow OpenSHMEM specification on ROCm
// it doesn't follow OpenSHMEM specification on ROCm
kLowLatencyMode
kLowLatencyMode
?
shmem_barrier
(
rdma_team
)
:
shmem_device_barrier_all
();
?
void
(
rocshmem
::
rocshmem_ctx_barrier
(
rocshmem
::
ROCSHMEM_CTX_DEFAULT
,
rdma_team
))
:
rocshmem
::
rocshmem_barrier_all
();
}
}
template
<
bool
kLowLatencyMode
,
int
kNumRDMARanks
>
template
<
bool
kLowLatencyMode
,
int
kNumRDMARanks
>
...
@@ -109,7 +103,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
...
@@ -109,7 +103,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
int
*
rdma_channel_prefix_matrix
,
int
*
recv_rdma_rank_prefix_sum
,
int
*
rdma_channel_prefix_matrix
,
int
*
recv_rdma_rank_prefix_sum
,
int
*
gbl_channel_prefix_matrix
,
int
*
recv_gbl_rank_prefix_sum
,
int
*
gbl_channel_prefix_matrix
,
int
*
recv_gbl_rank_prefix_sum
,
void
*
rdma_buffer_ptr
,
void
**
buffer_ptrs
,
int
**
barrier_signal_ptrs
,
int
rank
,
void
*
rdma_buffer_ptr
,
void
**
buffer_ptrs
,
int
**
barrier_signal_ptrs
,
int
rank
,
const
rocshmem
::
roc
shmem_team_t
rdma_team
)
{
const
shmem_team_t
rdma_team
)
{
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
),
warp_id
=
thread_id
/
kWarpSize
,
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
),
warp_id
=
thread_id
/
kWarpSize
,
lane_id
=
get_lane_id
();
lane_id
=
get_lane_id
();
...
@@ -159,7 +153,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
...
@@ -159,7 +153,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
// TODO: more light fence or barrier or signaling
// TODO: more light fence or barrier or signaling
// TODO: overlap EP barrier and NVL cleaning
// TODO: overlap EP barrier and NVL cleaning
if
(
thread_id
<
kNumRDMARanks
)
{
if
(
thread_id
<
kNumRDMARanks
)
{
rocshmem
::
roc
shmem_int_put_nbi
(
shmem_int_put_nbi
(
rdma_recv_num_tokens_mixed
.
recv_buffer
(
rdma_rank
),
rdma_recv_num_tokens_mixed
.
recv_buffer
(
rdma_rank
),
rdma_recv_num_tokens_mixed
.
send_buffer
(
thread_id
),
rdma_recv_num_tokens_mixed
.
send_buffer
(
thread_id
),
NUM_MAX_NVL_PEERS
+
num_rdma_experts
+
1
,
NUM_MAX_NVL_PEERS
+
num_rdma_experts
+
1
,
...
@@ -405,9 +399,10 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
...
@@ -405,9 +399,10 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
kForwarderCoordinator
,
// 向远端RDMA确认接收
kForwarderCoordinator
,
// 向远端RDMA确认接收
kNVLReceivers
// 从nvl缓存写入到recv_x
kNVLReceivers
// 从nvl缓存写入到recv_x
};
};
#ifndef FORCE_NVSHMEM_API
__shared__
rocshmem
::
rocshmem_ctx_t
ctx
;
__shared__
shmem_ctx_t
ctx
;
rocshmem
::
rocshmem_wg_ctx_create
(
0
,
&
ctx
);
shmem_wg_ctx_create
(
&
ctx
);
#endif
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
),
num_warps
=
num_threads
/
kWarpSize
;
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
),
num_warps
=
num_threads
/
kWarpSize
;
...
@@ -521,13 +516,23 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
...
@@ -521,13 +516,23 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
syncwarp
();
syncwarp
();
if
(
dst_rdma_rank
!=
rdma_rank
)
{
if
(
dst_rdma_rank
!=
rdma_rank
)
{
rocshmem
::
rocshmem_ctx_int_put_nbi_wave
(
#ifndef FORCE_NVSHMEM_API
ctx
,
rdma_channel_meta
.
recv_buffer
(
rdma_rank
),
shmem_ctx_int_put_nbi_warp
(
ctx
,
#else
shmemx_int_put_nbi_warp
(
#endif
rdma_channel_meta
.
recv_buffer
(
rdma_rank
),
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
),
NUM_MAX_NVL_PEERS
*
2
+
2
,
rdma_channel_meta
.
send_buffer
(
dst_rdma_rank
),
NUM_MAX_NVL_PEERS
*
2
+
2
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
}
}
}
rocshmem
::
rocshmem_ctx_quiet
(
ctx
);
#ifndef FORCE_NVSHMEM_API
shmem_ctx_quiet
(
ctx
);
#else
shmem_fence
();
#endif
// sync_rdma_sender_smem();
// sync_rdma_sender_smem();
__syncthreads
();
__syncthreads
();
...
@@ -736,15 +741,22 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
...
@@ -736,15 +741,22 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
if
(
dst_rdma_rank
!=
rdma_rank
)
{
if
(
dst_rdma_rank
!=
rdma_rank
)
{
auto
dst_slot_idx
=
synced_last_issued_tail
%
num_max_rdma_chunked_recv_tokens
;
auto
dst_slot_idx
=
synced_last_issued_tail
%
num_max_rdma_chunked_recv_tokens
;
EP_DEVICE_ASSERT
(
dst_slot_idx
+
num_tokens_to_issue
<=
num_max_rdma_chunked_recv_tokens
);
EP_DEVICE_ASSERT
(
dst_slot_idx
+
num_tokens_to_issue
<=
num_max_rdma_chunked_recv_tokens
);
rocshmem
::
rocshmem_ctx_schar_put_nbi_wave
(
#ifndef FORCE_NVSHMEM_API
ctx
,
shmem_ctx_schar_put_nbi_warp
(
ctx
,
#else
shmemx_int8_put_nbi_warp
(
#endif
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
,
dst_slot_idx
*
num_bytes_per_rdma_token
,
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
,
dst_slot_idx
*
num_bytes_per_rdma_token
,
num_bytes_per_rdma_token
*
num_tokens_to_issue
,
num_bytes_per_rdma_token
*
num_tokens_to_issue
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
rocshmem
::
rocshmem_ctx_quiet
(
ctx
);
#ifndef FORCE_NVSHMEM_API
shmem_ctx_quiet
(
ctx
);
#else
shmem_fence
();
#endif
}
else
{
}
else
{
// 对于本地RDMA秩,使用较轻的内存屏障
// 对于本地RDMA秩,使用较轻的内存屏障
memory_fence
();
memory_fence
();
...
@@ -756,8 +768,12 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
...
@@ -756,8 +768,12 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
last_issued_tail
+=
num_tokens_to_issue
;
last_issued_tail
+=
num_tokens_to_issue
;
num_tokens_to_send
-=
num_tokens_to_issue
;
num_tokens_to_send
-=
num_tokens_to_issue
;
// 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
// 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
rocshmem
::
rocshmem_ctx_ulong_atomic_add
(
#ifndef FORCE_NVSHMEM_API
ctx
,
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_tokens_to_issue
,
shmem_ctx_ulong_atomic_add
(
ctx
,
#else
shmem_signal_op_add
(
#endif
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_tokens_to_issue
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
}
}
}
...
@@ -992,8 +1008,12 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
...
@@ -992,8 +1008,12 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
// 更新远程头部
// 更新远程头部
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
&&
min_head
>=
last_head
+
num_max_rdma_chunked_send_tokens
&&
lane_id
<
kNumRDMARanks
){
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
&&
min_head
>=
last_head
+
num_max_rdma_chunked_send_tokens
&&
lane_id
<
kNumRDMARanks
){
rocshmem
::
rocshmem_ctx_ulong_atomic_add
(
#ifndef FORCE_NVSHMEM_API
ctx
,
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_head
,
shmem_ctx_ulong_atomic_add
(
ctx
,
#else
shmem_signal_op_add
(
#endif
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_head
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
lane_id
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
lane_id
,
nvl_rank
));
last_head
=
min_head
;
last_head
=
min_head
;
}
}
...
@@ -1107,7 +1127,9 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
...
@@ -1107,7 +1127,9 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
}
}
}
// while(num_tokens_to_recv > 0)
}
// while(num_tokens_to_recv > 0)
}
}
rocshmem
::
rocshmem_wg_ctx_destroy
(
&
ctx
);
#ifndef FORCE_NVSHMEM_API
shmem_wg_ctx_destroy
(
&
ctx
);
#endif
}
}
void
dispatch
(
void
*
recv_x
,
float
*
recv_x_scales
,
int64_t
*
recv_topk_idx
,
float
*
recv_topk_weights
,
void
dispatch
(
void
*
recv_x
,
float
*
recv_x_scales
,
int64_t
*
recv_topk_idx
,
float
*
recv_topk_weights
,
...
@@ -1166,7 +1188,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
...
@@ -1166,7 +1188,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
int
num_channels
,
const
int
*
rdma_channel_prefix_matrix
,
int
num_channels
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
rdma_rank_prefix_sum
,
int
*
combined_nvl_head
,
void
*
rdma_buffer_ptr
,
const
int
*
rdma_rank_prefix_sum
,
int
*
combined_nvl_head
,
void
*
rdma_buffer_ptr
,
void
**
buffer_ptrs
,
int
**
barrier_signal_ptrs
,
int
rank
,
int
num_ranks
,
void
**
buffer_ptrs
,
int
**
barrier_signal_ptrs
,
int
rank
,
int
num_ranks
,
bool
is_cached_dispatch
,
const
rocshmem
::
roc
shmem_team_t
rdma_team
)
{
bool
is_cached_dispatch
,
const
shmem_team_t
rdma_team
)
{
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
);
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
);
...
@@ -1189,7 +1211,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
...
@@ -1189,7 +1211,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
auto
rdma_buffer_ptr_int
=
reinterpret_cast
<
int
*>
(
rdma_buffer_ptr
);
auto
rdma_buffer_ptr_int
=
reinterpret_cast
<
int
*>
(
rdma_buffer_ptr
);
for
(
int
i
=
thread_id
;
i
<
rdma_num_int_clean
;
i
+=
num_threads
)
for
(
int
i
=
thread_id
;
i
<
rdma_num_int_clean
;
i
+=
num_threads
)
rdma_buffer_ptr_int
[
rdma_clean_offset
+
i
]
=
0
;
rdma_buffer_ptr_int
[
rdma_clean_offset
+
i
]
=
0
;
rocshmem
::
roc
shmem_fence
();
shmem_fence
();
__syncthreads
();
__syncthreads
();
// Barrier again
// Barrier again
...
@@ -1395,9 +1417,10 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1395,9 +1417,10 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
kRDMACoordinator
,
kRDMACoordinator
,
kNVLCoordinator
kNVLCoordinator
};
};
#ifndef FORCE_NVSHMEM_API
__shared__
rocshmem
::
rocshmem_ctx_t
ctx
;
__shared__
shmem_ctx_t
ctx
;
rocshmem
::
rocshmem_wg_ctx_create
(
0
,
&
ctx
);
shmem_wg_ctx_create
(
&
ctx
);
#endif
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
),
num_warps
=
num_threads
/
kWarpSize
;
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
),
num_warps
=
num_threads
/
kWarpSize
;
...
@@ -1721,16 +1744,22 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1721,16 +1744,22 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
if
(
sub_warp_id
==
kNumWarpsPerForwarder
-
1
)
{
if
(
sub_warp_id
==
kNumWarpsPerForwarder
-
1
)
{
if
(
dst_rdma_rank
!=
rdma_rank
)
{
if
(
dst_rdma_rank
!=
rdma_rank
)
{
auto
rdma_slot_idx
=
token_start_idx
%
num_max_rdma_chunked_recv_tokens
;
auto
rdma_slot_idx
=
token_start_idx
%
num_max_rdma_chunked_recv_tokens
;
rocshmem
::
rocshmem_ctx_schar_put_nbi_wave
(
#ifndef FORCE_NVSHMEM_API
ctx
,
shmem_ctx_schar_put_nbi_warp
(
ctx
,
#else
shmemx_int8_put_nbi_warp
(
#endif
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
rdma_slot_idx
*
num_bytes_per_rdma_token
,
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
,
rdma_slot_idx
*
num_bytes_per_rdma_token
,
num_chunked_tokens
*
num_bytes_per_rdma_token
,
num_chunked_tokens
*
num_bytes_per_rdma_token
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
#ifndef FORCE_NVSHMEM_API
rocshmem
::
rocshmem_ctx_quiet
(
ctx
);
shmem_ctx_quiet
(
ctx
);
#else
shmem_fence
();
#endif
}
else
{
}
else
{
memory_fence
();
memory_fence
();
}
}
...
@@ -1738,8 +1767,12 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1738,8 +1767,12 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Write new RDMA tail
// Write new RDMA tail
syncwarp
();
syncwarp
();
if
(
lane_id
==
0
)
{
if
(
lane_id
==
0
)
{
rocshmem
::
rocshmem_ctx_ulong_atomic_add
(
#ifndef FORCE_NVSHMEM_API
ctx
,
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_chunked_tokens
,
shmem_ctx_ulong_atomic_add
(
ctx
,
#else
shmem_signal_op_add
(
#endif
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_chunked_tokens
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
}
}
}
}
...
@@ -1867,8 +1900,12 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1867,8 +1900,12 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
min_head
=
min
(
min_head
,
rdma_receiver_rdma_head
[
i
][
dst_rdma_rank
]);
min_head
=
min
(
min_head
,
rdma_receiver_rdma_head
[
i
][
dst_rdma_rank
]);
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>=
last_rdma_head
+
num_max_rdma_chunked_send_tokens
and
lane_id
<
kNumRDMARanks
)
{
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>=
last_rdma_head
+
num_max_rdma_chunked_send_tokens
and
lane_id
<
kNumRDMARanks
)
{
rocshmem
::
rocshmem_ctx_ulong_atomic_add
(
#ifndef FORCE_NVSHMEM_API
ctx
,
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_rdma_head
,
shmem_ctx_ulong_atomic_add
(
ctx
,
#else
shmem_signal_op_add
(
#endif
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_rdma_head
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
last_rdma_head
=
min_head
;
last_rdma_head
=
min_head
;
...
@@ -1880,7 +1917,9 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1880,7 +1917,9 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
}
}
}
}
}
}
rocshmem
::
rocshmem_wg_ctx_destroy
(
&
ctx
);
#ifndef FORCE_NVSHMEM_API
shmem_wg_ctx_destroy
(
&
ctx
);
#endif
}
}
void
combine
(
hipDataType
type
,
void
*
combined_x
,
float
*
combined_topk_weights
,
void
combine
(
hipDataType
type
,
void
*
combined_x
,
float
*
combined_topk_weights
,
...
...
csrc/kernels/internode_ll.cu
View file @
ee3551ab
...
@@ -11,9 +11,8 @@
...
@@ -11,9 +11,8 @@
// low latency+RocSHMEM has issue with CTX.
// low latency+RocSHMEM has issue with CTX.
#define ROCM_DISABLE_CTX
#define ROCM_DISABLE_CTX
#include
<roc
shmem
/rocshmem.hpp>
#include
"
shmem
_wrapper.cuh"
using
namespace
rocshmem
;
namespace
deep_ep
{
namespace
deep_ep
{
namespace
internode_ll
{
namespace
internode_ll
{
...
@@ -59,7 +58,7 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
...
@@ -59,7 +58,7 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t
*
clean_1
,
int
num_clean_int_1
)
{
int64_t
*
clean_1
,
int
num_clean_int_1
)
{
// Barrier before cleaning (in case of unfinished chunked EP)
// Barrier before cleaning (in case of unfinished chunked EP)
if
(
threadIdx
.
x
==
0
)
if
(
threadIdx
.
x
==
0
)
rocshmem
::
rocshmem
_barrier_all
();
internode
::
shmem_device
_barrier_all
();
// Clean
// Clean
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
...
@@ -72,7 +71,7 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
...
@@ -72,7 +71,7 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
// Barrier after cleaning (make sure low-latency mode work
// Barrier after cleaning (make sure low-latency mode work
if
(
threadIdx
.
x
==
0
)
if
(
threadIdx
.
x
==
0
)
rocshmem
::
rocshmem
_barrier_all
();
internode
::
shmem_device
_barrier_all
();
}
}
void
clean_low_latency_buffer
(
int64_t
*
clean_0
,
int
num_clean_int_0
,
void
clean_low_latency_buffer
(
int64_t
*
clean_0
,
int
num_clean_int_0
,
...
@@ -100,8 +99,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -100,8 +99,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int
num_warp_groups
,
int
num_warps_per_group
,
int
num_warp_groups
,
int
num_warps_per_group
,
bool
round_scale
,
int
phases
)
{
bool
round_scale
,
int
phases
)
{
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
__shared__
rocshmem
::
roc
shmem_ctx_t
ctx
;
__shared__
internode
::
shmem_ctx_t
ctx
;
rocshmem
::
roc
shmem_wg_ctx_create
(
0
,
&
ctx
);
internode
::
shmem_wg_ctx_create
(
&
ctx
);
#endif
#endif
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
...
@@ -221,9 +220,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -221,9 +220,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
slot_idx
*
num_bytes_per_msg
;
slot_idx
*
num_bytes_per_msg
;
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
rocshmem
::
rocshmem_schar
_put_nbi_wa
ve
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
internode
::
shmemx_int8
_put_nbi_wa
rp
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
dst_rank
);
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
dst_rank
);
rocshmem
::
roc
shmem_fence
();
internode
::
shmem_fence
();
}
else
{
}
else
{
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
...
@@ -288,7 +287,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -288,7 +287,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Wait local sends issued and send expert counts
// Wait local sends issued and send expert counts
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
rocshmem
::
roc
shmem_long_atomic_add
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
);
internode
::
shmem_long_atomic_add
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
);
}
else
{
}
else
{
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
),
-
num_tokens_sent
-
1
);
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
),
-
num_tokens_sent
-
1
);
}
}
...
@@ -396,7 +395,7 @@ LOW_LATENCY_DISPATCH_RECV:
...
@@ -396,7 +395,7 @@ LOW_LATENCY_DISPATCH_RECV:
}
}
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
rocshmem
::
roc
shmem_wg_ctx_destroy
(
&
ctx
);
internode
::
shmem_wg_ctx_destroy
(
&
ctx
);
#endif
#endif
}
}
...
@@ -467,8 +466,8 @@ combine(void* combined_x,
...
@@ -467,8 +466,8 @@ combine(void* combined_x,
int
phases
,
bool
zero_copy
)
{
int
phases
,
bool
zero_copy
)
{
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
__shared__
rocshmem
::
roc
shmem_ctx_t
ctx
;
__shared__
internode
::
shmem_ctx_t
ctx
;
rocshmem
::
roc
shmem_wg_ctx_create
(
0
,
&
ctx
);
internode
::
shmem_wg_ctx_create
(
&
ctx
);
#endif
#endif
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
);
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
);
...
@@ -539,7 +538,7 @@ combine(void* combined_x,
...
@@ -539,7 +538,7 @@ combine(void* combined_x,
const
auto
rdma_send_x_vec_row
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_type_row
+
4
);
const
auto
rdma_send_x_vec_row
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_type_row
+
4
);
// Copy directly to local rank, or copy to buffer and issue RDMA
// Copy directly to local rank, or copy to buffer and issue RDMA
const
auto
src_idx
=
shfl_sync
(
__ldg
(
local_src_info
+
token_idx
)
,
0
)
;
const
auto
src_idx
=
__ldg
(
local_src_info
+
token_idx
);
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
+
sizeof
(
int4
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
+
sizeof
(
int4
);
if
(
dst_rank
==
rank
)
{
if
(
dst_rank
==
rank
)
{
...
@@ -552,16 +551,16 @@ combine(void* combined_x,
...
@@ -552,16 +551,16 @@ combine(void* combined_x,
//nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
//nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
#if defined(ROCM_DISABLE_CTX)
#if defined(ROCM_DISABLE_CTX)
rocshmem
::
rocshmem_schar
_put_nbi_wa
ve
(
internode
::
shmemx_int8
_put_nbi_wa
rp
(
#else
#else
rocshmem
::
roc
shmem_ctx_schar_put_nbi_wa
ve
(
ctx
,
internode
::
shmem_ctx_schar_put_nbi_wa
rp
(
ctx
,
#endif
#endif
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
hidden
*
sizeof
(
hip_bfloat16
),
dst_rank
);
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
hidden
*
sizeof
(
hip_bfloat16
),
dst_rank
);
#if defined(ROCM_DISABLE_CTX)
#if defined(ROCM_DISABLE_CTX)
rocshmem
::
roc
shmem_fence
();
internode
::
shmem_fence
();
#else
#else
rocshmem
::
roc
shmem_ctx_quiet
(
ctx
);
internode
::
shmem_ctx_quiet
(
ctx
);
#endif
#endif
}
}
}
}
...
@@ -578,9 +577,9 @@ combine(void* combined_x,
...
@@ -578,9 +577,9 @@ combine(void* combined_x,
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
#if defined(ROCM_DISABLE_CTX)
#if defined(ROCM_DISABLE_CTX)
rocshmem
::
roc
shmem_long_atomic_add
(
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
internode
::
shmem_long_atomic_add
(
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
#else
#else
rocshmem
::
roc
shmem_ctx_long_atomic_add
(
ctx
,
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
internode
::
shmem_ctx_long_atomic_add
(
ctx
,
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
#endif
#endif
}
else
{
}
else
{
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
global_expert_idx
),
1
);
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
global_expert_idx
),
1
);
...
@@ -643,7 +642,7 @@ combine(void* combined_x,
...
@@ -643,7 +642,7 @@ combine(void* combined_x,
}
}
}
}
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
rocshmem
::
roc
shmem_wg_ctx_destroy
(
&
ctx
);
internode
::
shmem_wg_ctx_destroy
(
&
ctx
);
#endif
#endif
}
}
...
...
csrc/kernels/runtime.cu
View file @
ee3551ab
...
@@ -5,10 +5,8 @@
...
@@ -5,10 +5,8 @@
#include "exception.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "launch.cuh"
#include "utils.cuh"
#include "utils.cuh"
#include "shmem_wrapper.cuh"
#ifndef DISABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
#endif
namespace
deep_ep
{
namespace
deep_ep
{
namespace
intranode
{
namespace
intranode
{
...
@@ -33,60 +31,66 @@ void barrier(int **barrier_signal_ptrs, int rank, int num_ranks, hipStream_t str
...
@@ -33,60 +31,66 @@ void barrier(int **barrier_signal_ptrs, int rank, int num_ranks, hipStream_t str
namespace
internode
{
namespace
internode
{
#ifndef DISABLE_ROCSHMEM
#ifndef DISABLE_ROCSHMEM
rocshmem
::
roc
shmem_team_t
cpu_rdma_team
=
rocshmem
::
ROC
SHMEM_TEAM_INVALID
;
shmem_team_t
cpu_rdma_team
=
EP_
SHMEM_TEAM_INVALID
;
rocshmem
::
roc
shmem_team_config_t
cpu_rdma_team_config
;
shmem_team_config_t
cpu_rdma_team_config
;
std
::
vector
<
uint8_t
>
get_unique_id
()
{
std
::
vector
<
uint8_t
>
get_unique_id
()
{
roc
shmem
::
rocshmem
_uniqueid_t
unique_id
;
shmem
x
_uniqueid_t
unique_id
;
roc
shmem
::
rocshmem
_get_uniqueid
(
&
unique_id
);
shmem
x
_get_uniqueid
(
&
unique_id
);
std
::
vector
<
uint8_t
>
result
(
sizeof
(
roc
shmem
::
rocshmem
_uniqueid_t
));
std
::
vector
<
uint8_t
>
result
(
sizeof
(
shmem
x
_uniqueid_t
));
std
::
memcpy
(
result
.
data
(),
&
unique_id
,
sizeof
(
roc
shmem
::
rocshmem
_uniqueid_t
));
std
::
memcpy
(
result
.
data
(),
&
unique_id
,
sizeof
(
shmem
x
_uniqueid_t
));
return
result
;
return
result
;
}
}
int
init
(
const
std
::
vector
<
uint8_t
>
&
root_unique_id_val
,
int
rank
,
int
num_ranks
,
int
init
(
const
std
::
vector
<
uint8_t
>
&
root_unique_id_val
,
int
rank
,
int
num_ranks
,
bool
low_latency_mode
)
{
bool
low_latency_mode
)
{
shmemx_uniqueid_t
root_unique_id
;
rocshmem
::
rocshmem_uniqueid_t
root_unique_id
;
shmemx_init_attr_t
attr
;
rocshmem
::
rocshmem_init_attr_t
attr
;
std
::
memcpy
(
&
root_unique_id
,
root_unique_id_val
.
data
(),
sizeof
(
shmemx_uniqueid_t
));
std
::
memcpy
(
&
root_unique_id
,
root_unique_id_val
.
data
(),
sizeof
(
rocshmem
::
rocshmem_uniqueid_t
));
shmemx_set_attr_uniqueid_args
(
rank
,
num_ranks
,
&
root_unique_id
,
&
attr
);
rocshmem
::
rocshmem_set_attr_uniqueid_args
(
rank
,
num_ranks
,
&
root_unique_id
,
&
attr
);
shmemx_init_attr
(
EP_SHMEMX_INIT_WITH_UNIQUEID
,
&
attr
);
rocshmem
::
rocshmem_init_attr
(
rocshmem
::
ROCSHMEM_INIT_WITH_UNIQUEID
,
&
attr
);
// Create sub-RDMA teams
// Create sub-RDMA teams
// NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used
// NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used
if
(
low_latency_mode
and
num_ranks
>
NUM_MAX_NVL_PEERS
)
{
if
(
low_latency_mode
and
num_ranks
>
NUM_MAX_NVL_PEERS
)
{
EP_HOST_ASSERT
(
cpu_rdma_team
==
rocshmem
::
ROCSHMEM_TEAM_INVALID
);
shmem_barrier_all
();
EP_HOST_ASSERT
(
cpu_rdma_team
==
EP_SHMEM_TEAM_INVALID
);
EP_HOST_ASSERT
(
num_ranks
%
NUM_MAX_NVL_PEERS
==
0
);
EP_HOST_ASSERT
(
num_ranks
%
NUM_MAX_NVL_PEERS
==
0
);
EP_HOST_ASSERT
(
rocshmem
::
roc
shmem_team_split_strided
(
EP_HOST_ASSERT
(
shmem_team_split_strided
(
rocshmem
::
ROC
SHMEM_TEAM_WORLD
,
rank
%
NUM_MAX_NVL_PEERS
,
EP_
SHMEM_TEAM_WORLD
,
rank
%
NUM_MAX_NVL_PEERS
,
NUM_MAX_NVL_PEERS
,
num_ranks
/
NUM_MAX_NVL_PEERS
,
NUM_MAX_NVL_PEERS
,
num_ranks
/
NUM_MAX_NVL_PEERS
,
&
cpu_rdma_team_config
,
0
,
&
cpu_rdma_team
)
==
0
);
&
cpu_rdma_team_config
,
0
,
&
cpu_rdma_team
)
==
0
);
EP_HOST_ASSERT
(
cpu_rdma_team
!=
rocshmem
::
ROCSHMEM_TEAM_INVALID
);
EP_HOST_ASSERT
(
cpu_rdma_team
!=
EP_SHMEM_TEAM_INVALID
);
#ifdef FORCE_NVSHMEM_API
nvshmemi_device_host_state_t
*
dev_state_ptr
=
nullptr
;
CUDA_CHECK
(
hipGetSymbolAddress
(
reinterpret_cast
<
void
**>
(
&
dev_state_ptr
),
nvshmemi_device_state_d
));
bool
ibgda_is_initialized
=
false
;
CUDA_CHECK
(
hipMemcpy
(
&
dev_state_ptr
->
ibgda_is_initialized
,
&
ibgda_is_initialized
,
sizeof
(
bool
),
hipMemcpyHostToDevice
));
#endif
}
}
rocshmem
::
roc
shmem_barrier_all
();
shmem_barrier_all
();
return
rocshmem
::
roc
shmem_my_pe
();
return
shmem_my_pe
();
}
}
void
*
alloc
(
size_t
size
,
size_t
alignment
)
{
void
*
alloc
(
size_t
size
,
size_t
alignment
)
{
auto
alloc_size
=
ALIGN
(
size
,
alignment
);
return
shmem_align
(
size
,
alignment
);
return
rocshmem
::
rocshmem_malloc
(
alloc_size
);
}
}
void
free
(
void
*
ptr
)
{
void
free
(
void
*
ptr
)
{
rocshmem
::
roc
shmem_free
(
ptr
);
shmem_free
(
ptr
);
}
}
void
barrier
()
{
void
barrier
()
{
rocshmem
::
roc
shmem_barrier_all
();
shmem_barrier_all
();
}
}
void
finalize
()
{
void
finalize
()
{
if
(
cpu_rdma_team
!=
rocshmem
::
ROC
SHMEM_TEAM_INVALID
)
{
if
(
cpu_rdma_team
!=
EP_
SHMEM_TEAM_INVALID
)
{
rocshmem
::
roc
shmem_team_destroy
(
cpu_rdma_team
);
shmem_team_destroy
(
cpu_rdma_team
);
cpu_rdma_team
=
rocshmem
::
ROC
SHMEM_TEAM_INVALID
;
cpu_rdma_team
=
EP_
SHMEM_TEAM_INVALID
;
}
}
rocshmem
::
roc
shmem_finalize
();
shmem_finalize
();
}
}
#endif
#endif
...
...
csrc/kernels/shmem_wrapper.cuh
0 → 100644
View file @
ee3551ab
#pragma once
/*
* Temporary wrapper for for platform specific NVSHMEM and rocSHMEM functions.
* Once hipify or hipify-torch fully supports this mapping, this file has to be
* removed and according nvshmem* functions restored.
*/
#ifndef DISABLE_ROCSHMEM
#include "configs.cuh"
#ifndef FORCE_NVSHMEM_API
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp8.h>
#include <hip/hip_runtime.h>
#include <rocshmem/rocshmem.hpp>
#else
#include <device_host_transport/nvshmem_common_ibgda.h>
#include <infiniband/mlx5dv.h>
#include <nvshmem.h>
#include <nvshmemx.h>
#include <non_abi/device/threadgroup/nvshmemi_common_device_defines.cuh>
#endif
namespace
deep_ep
::
internode
{
// rocSHMEM wrapper
#ifndef FORCE_NVSHMEM_API
using
shmem_team_t
=
rocshmem
::
rocshmem_team_t
;
using
shmem_team_config_t
=
rocshmem
::
rocshmem_team_config_t
;
const
shmem_team_t
EP_SHMEM_TEAM_INVALID
=
rocshmem
::
ROCSHMEM_TEAM_INVALID
;
inline
shmem_team_t
&
EP_SHMEM_TEAM_WORLD
=
rocshmem
::
ROCSHMEM_TEAM_WORLD
;
using
shmemx_uniqueid_t
=
rocshmem
::
rocshmem_uniqueid_t
;
using
shmemx_init_attr_t
=
rocshmem
::
rocshmem_init_attr_t
;
constexpr
auto
EP_SHMEMX_INIT_WITH_UNIQUEID
=
rocshmem
::
ROCSHMEM_INIT_WITH_UNIQUEID
;
__host__
inline
int
shmemx_get_uniqueid
(
shmemx_uniqueid_t
*
uid
)
{
return
rocshmem
::
rocshmem_get_uniqueid
(
uid
);
}
__host__
inline
int
shmemx_set_attr_uniqueid_args
(
int
rank
,
int
nranks
,
shmemx_uniqueid_t
*
uid
,
shmemx_init_attr_t
*
attr
)
{
return
rocshmem
::
rocshmem_set_attr_uniqueid_args
(
rank
,
nranks
,
uid
,
attr
);
}
__host__
inline
int
shmemx_init_attr
(
unsigned
int
flags
,
shmemx_init_attr_t
*
attr
)
{
return
rocshmem
::
rocshmem_init_attr
(
flags
,
attr
);
}
__host__
inline
int
shmem_team_split_strided
(
shmem_team_t
parent_team
,
int
start
,
int
stride
,
int
size
,
const
shmem_team_config_t
*
config
,
long
config_mask
,
shmem_team_t
*
new_team
)
{
return
rocshmem
::
rocshmem_team_split_strided
(
parent_team
,
start
,
stride
,
size
,
config
,
config_mask
,
new_team
);
}
__host__
inline
void
shmem_barrier_all
()
{
rocshmem
::
rocshmem_barrier_all
();
}
__device__
inline
void
shmem_device_barrier_all
()
{
rocshmem
::
rocshmem_barrier_all
();
}
__device__
inline
void
shmem_barrier
(
shmem_team_t
team
)
{
rocshmem
::
rocshmem_ctx_barrier
(
rocshmem
::
ROCSHMEM_CTX_DEFAULT
,
team
);
}
__host__
inline
int
shmem_my_pe
(){
return
rocshmem
::
rocshmem_my_pe
();
}
__host__
inline
void
shmem_free
(
void
*
ptr
){
rocshmem
::
rocshmem_free
(
ptr
);
}
__host__
inline
void
*
shmem_align
(
const
size_t
alignment
,
const
size_t
size
)
{
auto
alloc_size
=
ALIGN
(
size
,
alignment
);
return
rocshmem
::
rocshmem_malloc
(
alloc_size
);
}
__host__
inline
void
shmem_finalize
()
{
rocshmem
::
rocshmem_finalize
();
}
__host__
inline
void
shmem_team_destroy
(
shmem_team_t
team
)
{
rocshmem
::
rocshmem_team_destroy
(
team
);
}
__device__
inline
void
shmem_fence
()
{
rocshmem
::
rocshmem_fence
();
}
__device__
inline
void
shmem_int_put_nbi
(
int
*
dest
,
const
int
*
source
,
size_t
nelems
,
int
pe
)
{
rocshmem
::
rocshmem_int_put_nbi
(
dest
,
source
,
nelems
,
pe
);
}
__device__
inline
void
shmemx_int_put_nbi_warp
(
int
*
dest
,
const
int
*
source
,
size_t
nelems
,
int
pe
)
{
rocshmem
::
rocshmem_int_put_nbi_wave
(
dest
,
source
,
nelems
,
pe
);
}
__device__
inline
void
shmemx_int8_put_nbi_warp
(
signed
char
*
dest
,
const
signed
char
*
source
,
size_t
nelems
,
int
pe
)
{
rocshmem
::
rocshmem_schar_put_nbi_wave
(
dest
,
source
,
nelems
,
pe
);
}
__device__
inline
void
shmem_long_atomic_add
(
long
*
dest
,
long
value
,
int
pe
)
{
rocshmem
::
rocshmem_long_atomic_add
(
dest
,
value
,
pe
);
}
#if !defined(ROCM_DISABLE_CTX)
using
shmem_ctx_t
=
rocshmem
::
rocshmem_ctx_t
;
__device__
inline
int
shmem_wg_ctx_create
(
shmem_ctx_t
*
ctx
)
{
return
rocshmem
::
rocshmem_wg_ctx_create
(
0
,
ctx
);
}
__device__
inline
void
shmem_wg_ctx_destroy
(
shmem_ctx_t
*
ctx
)
{
rocshmem
::
rocshmem_wg_ctx_destroy
(
ctx
);
}
__device__
inline
void
shmem_ctx_quiet
(
shmem_ctx_t
ctx
)
{
rocshmem
::
rocshmem_ctx_quiet
(
ctx
);
}
__device__
inline
void
shmem_ctx_ulong_atomic_add
(
shmem_ctx_t
ctx
,
uint64_t
*
dest
,
uint64_t
value
,
int
pe
)
{
rocshmem
::
rocshmem_ctx_ulong_atomic_add
(
ctx
,
dest
,
value
,
pe
);
}
__device__
inline
void
shmem_ctx_long_atomic_add
(
shmem_ctx_t
ctx
,
long
*
dest
,
long
value
,
int
pe
)
{
rocshmem
::
rocshmem_ctx_long_atomic_add
(
ctx
,
dest
,
value
,
pe
);
}
__device__
inline
void
shmem_ctx_schar_put_nbi_warp
(
shmem_ctx_t
ctx
,
signed
char
*
dest
,
const
signed
char
*
source
,
size_t
nelems
,
int
pe
)
{
rocshmem
::
rocshmem_ctx_schar_put_nbi_wave
(
ctx
,
dest
,
source
,
nelems
,
pe
);
}
__device__
inline
void
shmem_ctx_int_put_nbi_warp
(
shmem_ctx_t
ctx
,
int
*
dest
,
const
int
*
source
,
size_t
nelems
,
int
pe
)
{
rocshmem
::
rocshmem_ctx_int_put_nbi_wave
(
ctx
,
dest
,
source
,
nelems
,
pe
);
}
#endif
#else
// NVSHMEM wrapper
#ifndef ROCM_DISABLE_CTX
#define ROCM_DISABLE_CTX
#endif
using
shmem_team_t
=
nvshmem_team_t
;
using
shmem_team_config_t
=
nvshmem_team_config_t
;
using
shmemx_uniqueid_t
=
nvshmemx_uniqueid_t
;
using
shmemx_init_attr_t
=
nvshmemx_init_attr_t
;
const
shmem_team_t
EP_SHMEM_TEAM_INVALID
=
NVSHMEM_TEAM_INVALID
;
const
shmem_team_t
EP_SHMEM_TEAM_WORLD
=
NVSHMEM_TEAM_WORLD
;
constexpr
auto
EP_SHMEMX_INIT_WITH_UNIQUEID
=
NVSHMEMX_INIT_WITH_UNIQUEID
;
__host__
inline
int
shmemx_get_uniqueid
(
shmemx_uniqueid_t
*
uid
)
{
return
nvshmemx_get_uniqueid
(
uid
);
}
__host__
inline
int
shmemx_set_attr_uniqueid_args
(
int
rank
,
int
nranks
,
shmemx_uniqueid_t
*
uid
,
shmemx_init_attr_t
*
attr
)
{
return
nvshmemx_set_attr_uniqueid_args
(
rank
,
nranks
,
uid
,
attr
);
}
__host__
inline
int
shmemx_init_attr
(
unsigned
int
flags
,
shmemx_init_attr_t
*
attr
)
{
return
nvshmemx_init_attr
(
flags
,
attr
);
}
__host__
inline
int
shmem_team_split_strided
(
shmem_team_t
parent_team
,
int
start
,
int
stride
,
int
size
,
const
shmem_team_config_t
*
config
,
long
config_mask
,
shmem_team_t
*
new_team
)
{
return
nvshmem_team_split_strided
(
parent_team
,
start
,
stride
,
size
,
config
,
config_mask
,
new_team
);
}
__host__
inline
void
shmem_barrier_all
()
{
nvshmem_barrier_all
();
}
__device__
inline
void
shmem_device_barrier_all
()
{
nvshmem_barrier_all
();
}
__device__
inline
void
shmem_barrier
(
shmem_team_t
team
)
{
void
(
nvshmem_barrier
(
team
));
}
__host__
inline
int
shmem_my_pe
(){
return
nvshmem_my_pe
();
}
__host__
inline
void
shmem_free
(
void
*
ptr
){
nvshmem_free
(
ptr
);
}
__host__
inline
void
*
shmem_align
(
const
size_t
alignment
,
const
size_t
size
)
{
return
nvshmem_align
(
size
,
alignment
);
}
__host__
inline
void
shmem_finalize
()
{
nvshmem_finalize
();
}
__host__
inline
void
shmem_team_destroy
(
shmem_team_t
team
)
{
nvshmem_team_destroy
(
team
);
}
__device__
inline
void
shmem_fence
()
{
nvshmem_fence
();
}
__device__
inline
void
shmem_int_put_nbi
(
int
*
dest
,
const
int
*
source
,
size_t
nelems
,
int
pe
)
{
nvshmem_int_put_nbi
(
dest
,
source
,
nelems
,
pe
);
}
__device__
inline
void
shmemx_int_put_nbi_warp
(
int
*
dest
,
const
int
*
source
,
size_t
nelems
,
int
pe
)
{
nvshmemx_int_put_nbi_warp
(
dest
,
source
,
nelems
,
pe
);
}
__device__
inline
void
shmemx_int8_put_nbi_warp
(
signed
char
*
dest
,
const
signed
char
*
source
,
size_t
nelems
,
int
pe
)
{
nvshmemx_int8_put_nbi_warp
(
dest
,
source
,
nelems
,
pe
);
}
__device__
inline
void
shmem_signal_op_add
(
uint64_t
*
dest
,
uint64_t
value
,
int
pe
)
{
nvshmemx_signal_op
(
dest
,
value
,
NVSHMEM_SIGNAL_ADD
,
pe
);
}
__device__
inline
void
shmem_ulong_atomic_add
(
uint64_t
*
dest
,
uint64_t
value
,
int
pe
)
{
nvshmem_ulong_atomic_add
(
dest
,
value
,
pe
);
}
__device__
inline
void
shmem_long_atomic_add
(
long
*
dest
,
long
value
,
int
pe
)
{
// nvshmem_##Name##_atomic_add(dest, value, pe);
nvshmem_long_atomic_add
(
dest
,
value
,
pe
);
}
#endif
}
// namespace deep_ep::internode
#endif
csrc/kernels/utils.cuh
View file @
ee3551ab
...
@@ -342,7 +342,7 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
...
@@ -342,7 +342,7 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return
*
reinterpret_cast
<
dtype_t
*>
(
recv_int_values
);
return
*
reinterpret_cast
<
dtype_t
*>
(
recv_int_values
);
}
}
#ifdef
USE_ROCM
#if
n
def
FORCE_NVSHMEM_API
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFinfoAmaxE4M3
=
240.0
f
;
constexpr
float
kFinfoAmaxE4M3
=
240.0
f
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment