Commit d91f2eb9 authored by lishen
Browse files

Merge branch 'rocshmem_dushmem' into 'main'

Rocshmem dushmem

See merge request dcutoolkit/deeplearing/DeepEP!2
parents 09c4817e d1bf10d3
...@@ -8,10 +8,21 @@ fi ...@@ -8,10 +8,21 @@ fi
PYTHON_INCLUDE=$(python3 -c "from sysconfig import get_paths; print(get_paths()['include'])") PYTHON_INCLUDE=$(python3 -c "from sysconfig import get_paths; print(get_paths()['include'])")
PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()['platlib'])") PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()['platlib'])")
# --------------------------------------------------------------------- #
USE_NVSHMEM=${USE_NVSHMEM:=OFF}
ROCSHMEM_INSTALL_PREFIX=${ROCSHMEM_INSTALL_PREFIX:=$(pwd)/rocshmem_dir} ROCSHMEM_INSTALL_PREFIX=${ROCSHMEM_INSTALL_PREFIX:=$(pwd)/rocshmem_dir}
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type}
SHMEM_LINK_OPTIONS=${SHMEM_LINK_OPTIONS:="-Wl,-rpath,${ROCSHMEM_INSTALL_PREFIX}/lib/ -l:librocshmem.a"}
####
# 检查是否设置了USE_NVSHMEM环境变量
if [ "$USE_NVSHMEM" == "ON" ]; then
COMPILE_OPTIONS+=" -DFORCE_NVSHMEM_API"
ROCSHMEM_INSTALL_PREFIX=???/dushmem_dir
SHMEM_LINK_OPTIONS="-Wl,-rpath,${ROCSHMEM_INSTALL_PREFIX}/lib/ -l:libnvshmem_device.a -lnvshmem_host"
fi
INCLUDE_PATHS=${INCLUDE_PATHS:=-Icsrc/ -I${ROCSHMEM_INSTALL_PREFIX}/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE}} INCLUDE_PATHS=${INCLUDE_PATHS:=-Icsrc/ -I${ROCSHMEM_INSTALL_PREFIX}/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE}}
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/runtime.cu -o build_/runtime.o ${COMPILE_OPTIONS} hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/runtime.cu -o build_/runtime.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/layout.cu -o build_/layout.o ${COMPILE_OPTIONS} hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/layout.cu -o build_/layout.o ${COMPILE_OPTIONS}
...@@ -20,7 +31,7 @@ hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o ...@@ -20,7 +31,7 @@ hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode_ll.cu -o build_/internode_ll.o ${COMPILE_OPTIONS} hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode_ll.cu -o build_/internode_ll.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/deep_ep.cu -o build_/deep_ep.o ${COMPILE_OPTIONS} hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/deep_ep.cu -o build_/deep_ep.o ${COMPILE_OPTIONS}
hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o -L${ROCSHMEM_INSTALL_PREFIX}/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,$${ROCSHMEM_INSTALL_PREFIX}/lib/ -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 -l:librocshmem.a -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/opt/mpi/lib/ -libverbs -lmlx5 hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o -L${ROCSHMEM_INSTALL_PREFIX}/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a 
/opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 ${SHMEM_LINK_OPTIONS} -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/opt/mpi/lib/ -libverbs -lmlx5
# build whl # build whl
echo "Using Python: $(which python3)" echo "Using Python: $(which python3)"
......
...@@ -3,11 +3,10 @@ ...@@ -3,11 +3,10 @@
#include "configs.cuh" #include "configs.cuh"
#include "launch.cuh" #include "launch.cuh"
#include "utils.cuh" #include "utils.cuh"
#include "shmem_wrapper.cuh"
#ifndef DISABLE_ROCSHMEM #ifndef DISABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
// TODO: fix unroll warnings // TODO: fix unroll warnings
// #ifdef __clang__ // #ifdef __clang__
// #pragma clang diagnostic push // #pragma clang diagnostic push
...@@ -19,7 +18,7 @@ namespace deep_ep { ...@@ -19,7 +18,7 @@ namespace deep_ep {
namespace internode { namespace internode {
extern rocshmem::rocshmem_team_t cpu_rdma_team; extern shmem_team_t cpu_rdma_team;
struct SourceMeta { struct SourceMeta {
int src_rdma_rank, is_token_in_nvl_rank_bits; int src_rdma_rank, is_token_in_nvl_rank_bits;
...@@ -51,9 +50,8 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_ ...@@ -51,9 +50,8 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_
int num_topk_idx, int num_topk_idx,
int num_topk_weights) { int num_topk_weights) {
return static_cast<int>(ALIGN(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + return static_cast<int>(ALIGN(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) +
num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) +
num_topk_weights * sizeof(float), num_topk_weights * sizeof(float), sizeof(int4)));
sizeof(int4)));
} }
__host__ __device__ __forceinline__ std::pair<int, int> __host__ __device__ __forceinline__ std::pair<int, int>
...@@ -61,9 +59,8 @@ get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_t ...@@ -61,9 +59,8 @@ get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_t
int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) { int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) {
// Return `int32_t` offset and count to clean // Return `int32_t` offset and count to clean
return {(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * return {(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) *
num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) / num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) / sizeof(int),
sizeof(int), (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms};
(NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms};
} }
__host__ __device__ __forceinline__ std::pair<int, int> __host__ __device__ __forceinline__ std::pair<int, int>
get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
...@@ -74,10 +71,9 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to ...@@ -74,10 +71,9 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to
"Invalid size of `SourceMeta`"); "Invalid size of `SourceMeta`");
return { return {
(num_nvl_recv_buffer_tokens * (num_nvl_recv_buffer_tokens *
(hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + (hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) +
num_topk_weights * sizeof(float) + sizeof(SourceMeta)) * num_topk_weights * sizeof(float) + sizeof(SourceMeta)) *
num_nvl_ranks * num_sms) / num_nvl_ranks * num_sms) / sizeof(int),
sizeof(int),
num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms, num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms,
}; };
} }
...@@ -90,12 +86,10 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank, ...@@ -90,12 +86,10 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
template <bool kLowLatencyMode> template <bool kLowLatencyMode>
__forceinline__ __device__ void __forceinline__ __device__ void
nvshmem_barrier_with_same_gpu_idx(const rocshmem::rocshmem_team_t &rdma_team) { nvshmem_barrier_with_same_gpu_idx(const shmem_team_t &rdma_team) {
// NOTE: shmem_device_barrier_all() might be an issue as // NOTE: shmem_device_barrier_all() might be an issue as
// it doesn't follow OpenSHMEM specification on ROCm // it doesn't follow OpenSHMEM specification on ROCm
kLowLatencyMode kLowLatencyMode ? shmem_barrier(rdma_team) : shmem_device_barrier_all();
? void(rocshmem::rocshmem_ctx_barrier(rocshmem::ROCSHMEM_CTX_DEFAULT, rdma_team))
: rocshmem::rocshmem_barrier_all();
} }
template <bool kLowLatencyMode, int kNumRDMARanks> template <bool kLowLatencyMode, int kNumRDMARanks>
...@@ -109,7 +103,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in ...@@ -109,7 +103,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
int *rdma_channel_prefix_matrix, int *recv_rdma_rank_prefix_sum, int *rdma_channel_prefix_matrix, int *recv_rdma_rank_prefix_sum,
int *gbl_channel_prefix_matrix, int *recv_gbl_rank_prefix_sum, int *gbl_channel_prefix_matrix, int *recv_gbl_rank_prefix_sum,
void *rdma_buffer_ptr, void **buffer_ptrs, int **barrier_signal_ptrs, int rank, void *rdma_buffer_ptr, void **buffer_ptrs, int **barrier_signal_ptrs, int rank,
const rocshmem::rocshmem_team_t rdma_team) { const shmem_team_t rdma_team) {
auto sm_id = static_cast<int>(blockIdx.x); auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize,
lane_id = get_lane_id(); lane_id = get_lane_id();
...@@ -159,7 +153,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in ...@@ -159,7 +153,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
// TODO: more light fence or barrier or signaling // TODO: more light fence or barrier or signaling
// TODO: overlap EP barrier and NVL cleaning // TODO: overlap EP barrier and NVL cleaning
if (thread_id < kNumRDMARanks) { if (thread_id < kNumRDMARanks) {
rocshmem::rocshmem_int_put_nbi( shmem_int_put_nbi(
rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank), rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
rdma_recv_num_tokens_mixed.send_buffer(thread_id), rdma_recv_num_tokens_mixed.send_buffer(thread_id),
NUM_MAX_NVL_PEERS + num_rdma_experts + 1, NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
...@@ -405,9 +399,10 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv ...@@ -405,9 +399,10 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
kForwarderCoordinator, // 向远端RDMA确认接收 kForwarderCoordinator, // 向远端RDMA确认接收
kNVLReceivers // 从nvl缓存写入到recv_x kNVLReceivers // 从nvl缓存写入到recv_x
}; };
#ifndef FORCE_NVSHMEM_API
__shared__ rocshmem::rocshmem_ctx_t ctx; __shared__ shmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx); shmem_wg_ctx_create(&ctx);
#endif
const auto sm_id = static_cast<int>(blockIdx.x); const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize; const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
...@@ -521,13 +516,23 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv ...@@ -521,13 +516,23 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
syncwarp(); syncwarp();
if (dst_rdma_rank != rdma_rank) { if (dst_rdma_rank != rdma_rank) {
rocshmem::rocshmem_ctx_int_put_nbi_wave( #ifndef FORCE_NVSHMEM_API
ctx, rdma_channel_meta.recv_buffer(rdma_rank), shmem_ctx_int_put_nbi_warp(ctx,
#else
shmemx_int_put_nbi_warp(
#endif
rdma_channel_meta.recv_buffer(rdma_rank),
rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2, rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
} }
} }
rocshmem::rocshmem_ctx_quiet(ctx);
#ifndef FORCE_NVSHMEM_API
shmem_ctx_quiet(ctx);
#else
shmem_fence();
#endif
// sync_rdma_sender_smem(); // sync_rdma_sender_smem();
__syncthreads(); __syncthreads();
...@@ -736,15 +741,22 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv ...@@ -736,15 +741,22 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
if(dst_rdma_rank != rdma_rank) { if(dst_rdma_rank != rdma_rank) {
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens; auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens); EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
rocshmem::rocshmem_ctx_schar_put_nbi_wave( #ifndef FORCE_NVSHMEM_API
ctx, shmem_ctx_schar_put_nbi_warp(ctx,
#else
shmemx_int8_put_nbi_warp(
#endif
rdma_channel_data.recv_buffer(rdma_rank) + rdma_channel_data.recv_buffer(rdma_rank) +
dst_slot_idx * num_bytes_per_rdma_token, dst_slot_idx * num_bytes_per_rdma_token,
rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_channel_data.send_buffer(dst_rdma_rank) +
dst_slot_idx * num_bytes_per_rdma_token, dst_slot_idx * num_bytes_per_rdma_token,
num_bytes_per_rdma_token * num_tokens_to_issue, num_bytes_per_rdma_token * num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
rocshmem::rocshmem_ctx_quiet(ctx); #ifndef FORCE_NVSHMEM_API
shmem_ctx_quiet(ctx);
#else
shmem_fence();
#endif
} else { } else {
// 对于本地RDMA秩,使用较轻的内存屏障 // 对于本地RDMA秩,使用较轻的内存屏障
memory_fence(); memory_fence();
...@@ -756,8 +768,12 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv ...@@ -756,8 +768,12 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
last_issued_tail += num_tokens_to_issue; last_issued_tail += num_tokens_to_issue;
num_tokens_to_send -= num_tokens_to_issue; num_tokens_to_send -= num_tokens_to_issue;
// 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信 // 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
rocshmem::rocshmem_ctx_ulong_atomic_add( #ifndef FORCE_NVSHMEM_API
ctx, rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
#endif
rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
} }
} }
...@@ -992,8 +1008,12 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv ...@@ -992,8 +1008,12 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
// 更新远程头部 // 更新远程头部
if(min_head != std::numeric_limits<int>::max() && min_head >= last_head + num_max_rdma_chunked_send_tokens && lane_id < kNumRDMARanks){ if(min_head != std::numeric_limits<int>::max() && min_head >= last_head + num_max_rdma_chunked_send_tokens && lane_id < kNumRDMARanks){
rocshmem::rocshmem_ctx_ulong_atomic_add( #ifndef FORCE_NVSHMEM_API
ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_head, shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
#endif
rdma_channel_head.buffer(rdma_rank), min_head - last_head,
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
last_head = min_head; last_head = min_head;
} }
...@@ -1107,7 +1127,9 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv ...@@ -1107,7 +1127,9 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
} }
} // while(num_tokens_to_recv > 0) } // while(num_tokens_to_recv > 0)
} }
rocshmem::rocshmem_wg_ctx_destroy(&ctx); #ifndef FORCE_NVSHMEM_API
shmem_wg_ctx_destroy(&ctx);
#endif
} }
void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights, void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
...@@ -1166,7 +1188,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i ...@@ -1166,7 +1188,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
int num_channels, const int *rdma_channel_prefix_matrix, int num_channels, const int *rdma_channel_prefix_matrix,
const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr, const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr,
void **buffer_ptrs, int **barrier_signal_ptrs, int rank, int num_ranks, void **buffer_ptrs, int **barrier_signal_ptrs, int rank, int num_ranks,
bool is_cached_dispatch, const rocshmem::rocshmem_team_t rdma_team) { bool is_cached_dispatch, const shmem_team_t rdma_team) {
auto sm_id = static_cast<int>(blockIdx.x); auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x); auto thread_id = static_cast<int>(threadIdx.x);
auto num_threads = static_cast<int>(blockDim.x); auto num_threads = static_cast<int>(blockDim.x);
...@@ -1189,7 +1211,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i ...@@ -1189,7 +1211,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
auto rdma_buffer_ptr_int = reinterpret_cast<int *>(rdma_buffer_ptr); auto rdma_buffer_ptr_int = reinterpret_cast<int *>(rdma_buffer_ptr);
for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;
rocshmem::rocshmem_fence(); shmem_fence();
__syncthreads(); __syncthreads();
// Barrier again // Barrier again
...@@ -1395,9 +1417,10 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1395,9 +1417,10 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
kRDMACoordinator, kRDMACoordinator,
kNVLCoordinator kNVLCoordinator
}; };
#ifndef FORCE_NVSHMEM_API
__shared__ rocshmem::rocshmem_ctx_t ctx; __shared__ shmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx); shmem_wg_ctx_create(&ctx);
#endif
const auto sm_id = static_cast<int>(blockIdx.x); const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize; const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
...@@ -1721,16 +1744,22 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1721,16 +1744,22 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
if(sub_warp_id == kNumWarpsPerForwarder - 1) { if(sub_warp_id == kNumWarpsPerForwarder - 1) {
if(dst_rdma_rank != rdma_rank) { if(dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
rocshmem::rocshmem_ctx_schar_put_nbi_wave( #ifndef FORCE_NVSHMEM_API
ctx, shmem_ctx_schar_put_nbi_warp(ctx,
#else
shmemx_int8_put_nbi_warp(
#endif
rdma_channel_data.recv_buffer(rdma_rank) + rdma_channel_data.recv_buffer(rdma_rank) +
rdma_slot_idx * num_bytes_per_rdma_token, rdma_slot_idx * num_bytes_per_rdma_token,
rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_channel_data.send_buffer(dst_rdma_rank) +
rdma_slot_idx * num_bytes_per_rdma_token, rdma_slot_idx * num_bytes_per_rdma_token,
num_chunked_tokens * num_bytes_per_rdma_token, num_chunked_tokens * num_bytes_per_rdma_token,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#ifndef FORCE_NVSHMEM_API
rocshmem::rocshmem_ctx_quiet(ctx); shmem_ctx_quiet(ctx);
#else
shmem_fence();
#endif
} else { } else {
memory_fence(); memory_fence();
} }
...@@ -1738,8 +1767,12 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1738,8 +1767,12 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Write new RDMA tail // Write new RDMA tail
syncwarp(); syncwarp();
if(lane_id == 0) { if(lane_id == 0) {
rocshmem::rocshmem_ctx_ulong_atomic_add( #ifndef FORCE_NVSHMEM_API
ctx, rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
#endif
rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
} }
} }
...@@ -1867,8 +1900,12 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1867,8 +1900,12 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]); min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
rocshmem::rocshmem_ctx_ulong_atomic_add( #ifndef FORCE_NVSHMEM_API
ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
#endif
rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
last_rdma_head = min_head; last_rdma_head = min_head;
...@@ -1880,7 +1917,9 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1880,7 +1917,9 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
} }
} }
} }
rocshmem::rocshmem_wg_ctx_destroy(&ctx); #ifndef FORCE_NVSHMEM_API
shmem_wg_ctx_destroy(&ctx);
#endif
} }
void combine(hipDataType type, void *combined_x, float *combined_topk_weights, void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
......
...@@ -11,9 +11,8 @@ ...@@ -11,9 +11,8 @@
// Low-latency mode with rocSHMEM has issues with CTX. // Low-latency mode with rocSHMEM has issues with CTX.
#define ROCM_DISABLE_CTX #define ROCM_DISABLE_CTX
#include <rocshmem/rocshmem.hpp> #include "shmem_wrapper.cuh"
using namespace rocshmem;
namespace deep_ep { namespace deep_ep {
namespace internode_ll { namespace internode_ll {
...@@ -37,7 +36,7 @@ __device__ void grid_barrier(int* global_counter, int num_blocks) { ...@@ -37,7 +36,7 @@ __device__ void grid_barrier(int* global_counter, int num_blocks) {
} }
__syncthreads(); __syncthreads();
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
while (__hip_atomic_load(global_counter, __ATOMIC_RELAXED,__HIP_MEMORY_SCOPE_AGENT) != num_blocks); while (__hip_atomic_load(global_counter, __ATOMIC_RELAXED,__HIP_MEMORY_SCOPE_AGENT) != num_blocks);
} }
__syncthreads(); __syncthreads();
} }
...@@ -59,7 +58,7 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, ...@@ -59,7 +58,7 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t* clean_1, int num_clean_int_1) { int64_t* clean_1, int num_clean_int_1) {
// Barrier before cleaning (in case of unfinished chunked EP) // Barrier before cleaning (in case of unfinished chunked EP)
if (threadIdx.x == 0) if (threadIdx.x == 0)
rocshmem::rocshmem_barrier_all(); internode::shmem_device_barrier_all();
// Clean // Clean
auto thread_id = static_cast<int>(threadIdx.x); auto thread_id = static_cast<int>(threadIdx.x);
...@@ -70,9 +69,9 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, ...@@ -70,9 +69,9 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
for (int i = thread_id; i < num_clean_int_1; i += kNumThreads) for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
clean_1[i] = 0; clean_1[i] = 0;
// Barrier after cleaning (make sure low-latency mode works) // Barrier after cleaning (make sure low-latency mode works)
if (threadIdx.x == 0) if (threadIdx.x == 0)
rocshmem::rocshmem_barrier_all(); internode::shmem_device_barrier_all();
} }
void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
...@@ -97,13 +96,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -97,13 +96,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank, int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group, int num_warp_groups, int num_warps_per_group,
bool round_scale, int phases) { bool round_scale, int phases) {
#if !defined(ROCM_DISABLE_CTX)
__shared__ rocshmem::rocshmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx);
#endif
const auto sm_id = static_cast<int>(blockIdx.x); const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x); const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id(); const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
...@@ -132,17 +126,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -132,17 +126,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4); const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0); EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
// 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize;
constexpr int num_sync_large_iteration = kMaxNumWarps ;
__shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];
#pragma unroll
for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
sync_large_warp_counters[i] = 0;
}
__syncthreads();
// Expert counts // Expert counts
constexpr int kNumMaxWarpGroups = 1024 / kWarpSize; constexpr int kNumMaxWarpGroups = 1024 / kWarpSize;
__shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups]; __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
...@@ -151,6 +134,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -151,6 +134,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
if ((phases & LOW_LATENCY_SEND_PHASE) == 0) if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_DISPATCH_RECV; goto LOW_LATENCY_DISPATCH_RECV;
#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
// There are 2 kinds of warps in this part: // There are 2 kinds of warps in this part:
// 1. The first-kind warps for FP8 cast and sending top-k tokens // 1. The first-kind warps for FP8 cast and sending top-k tokens
// 2. The last warp for reading `topk_idx` and count for per-expert information // 2. The last warp for reading `topk_idx` and count for per-expert information
...@@ -221,9 +209,18 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -221,9 +209,18 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
slot_idx * num_bytes_per_msg; slot_idx * num_bytes_per_msg;
if (dst_rank != rank) { if (dst_rank != rank) {
rocshmem::rocshmem_schar_put_nbi_wave(reinterpret_cast<signed char*>(dst_ptr), #if !defined(ROCM_DISABLE_CTX)
reinterpret_cast<signed char*>(src_ptr), num_bytes_per_msg, dst_rank); internode::shmem_ctx_schar_put_nbi_warp(ctx,
rocshmem::rocshmem_fence(); #else
internode::shmemx_int8_put_nbi_warp(
#endif
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
num_bytes_per_msg, dst_rank);
// #if !defined(ROCM_DISABLE_CTX)
// internode::shmem_ctx_quiet(ctx);
// #else
// internode::shmem_fence();
// #endif
} else { } else {
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr); const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
...@@ -275,8 +272,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -275,8 +272,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
} }
} }
} }
//revert sync_large_warp_counters to 0 for next sync
__syncthreads(); __syncthreads();
// Issue count sends // Issue count sends
...@@ -288,7 +283,12 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -288,7 +283,12 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Wait local sends issued and send expert counts // Wait local sends issued and send expert counts
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2); while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
if (dst_rank != rank) { if (dst_rank != rank) {
rocshmem::rocshmem_long_atomic_add( rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); #if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_long_atomic_add(ctx,
#else
internode::shmem_long_atomic_add(
#endif
rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
} else { } else {
st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1); st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
} }
...@@ -303,6 +303,10 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -303,6 +303,10 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
} }
syncwarp(); syncwarp();
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
// Receiving phase // Receiving phase
LOW_LATENCY_DISPATCH_RECV: LOW_LATENCY_DISPATCH_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0) if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
...@@ -313,20 +317,31 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -313,20 +317,31 @@ LOW_LATENCY_DISPATCH_RECV:
grid_barrier(global_atomic_counter, num_sms); grid_barrier(global_atomic_counter, num_sms);
} }
// 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize;
constexpr int num_sync_large_iteration = kMaxNumWarps ;
__shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];
#pragma unroll
for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
sync_large_warp_counters[i] = 0;
}
__syncthreads();
// Receiving and packing // Receiving and packing
if (responsible_expert_idx < num_experts) { if (responsible_expert_idx < num_experts) {
const auto src_rank = responsible_expert_idx / num_local_experts; const auto src_rank = responsible_expert_idx / num_local_experts;
const auto local_expert_idx = responsible_expert_idx % num_local_experts; const auto local_expert_idx = responsible_expert_idx % num_local_experts;
const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) + const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg; src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) + const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4; local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks; const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
const auto num_aligned_scales = ALIGN<int>(num_scales, sizeof(float) / sizeof(scale_t)); const auto num_aligned_scales = ALIGN<int>(num_scales, sizeof(float) / sizeof(scale_t));
const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) + const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales; local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
// Shared between sub-warps in warp groups // Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups]; __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
...@@ -394,10 +409,6 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -394,10 +409,6 @@ LOW_LATENCY_DISPATCH_RECV:
} }
} }
} }
#if !defined(ROCM_DISABLE_CTX)
rocshmem::rocshmem_wg_ctx_destroy(&ctx);
#endif
} }
void dispatch(void* packed_recv_x, void* packed_recv_x_scales, void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...@@ -408,9 +419,9 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -408,9 +419,9 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const void* x, const int64_t* topk_idx, const void* x, const int64_t* topk_idx,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int num_device_sms, void* workspace, int num_device_sms,
hipStream_t stream, int phases) { hipStream_t stream, int phases) {
constexpr int kNumMaxTopK = 11; constexpr int kNumMaxTopK = 11;
const int num_warp_groups = ceil_div(num_experts, num_device_sms); const int num_warp_groups = ceil_div(num_experts, num_device_sms);
...@@ -465,11 +476,6 @@ combine(void* combined_x, ...@@ -465,11 +476,6 @@ combine(void* combined_x,
int num_experts, int rank, int num_ranks, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group, int num_warp_groups, int num_warps_per_group,
int phases, bool zero_copy) { int phases, bool zero_copy) {
#if !defined(ROCM_DISABLE_CTX)
__shared__ rocshmem::rocshmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx);
#endif
const auto sm_id = static_cast<int>(blockIdx.x); const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_sms = static_cast<int>(gridDim.x); const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x); const auto thread_id = static_cast<int>(threadIdx.x);
...@@ -489,7 +495,7 @@ combine(void* combined_x, ...@@ -489,7 +495,7 @@ combine(void* combined_x,
constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(hip_bfloat16); constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(hip_bfloat16);
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization"); EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// 16 is the max possible number of warps in AMD GPUs // 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize; constexpr int kMaxNumWarps = 1024 / kWarpSize;
__shared__ volatile int sync_large_warp_counters[kMaxNumWarps]; __shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
if (threadIdx.x==0){ if (threadIdx.x==0){
...@@ -504,6 +510,11 @@ combine(void* combined_x, ...@@ -504,6 +510,11 @@ combine(void* combined_x,
if ((phases & LOW_LATENCY_SEND_PHASE) == 0) if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_COMBINE_RECV; goto LOW_LATENCY_COMBINE_RECV;
#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
// Clean up next buffer // Clean up next buffer
if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) { if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
#pragma unroll #pragma unroll
...@@ -523,10 +534,10 @@ combine(void* combined_x, ...@@ -523,10 +534,10 @@ combine(void* combined_x,
const auto global_expert_idx = rank * num_local_experts + local_expert_idx; const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank); const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
const auto local_x = reinterpret_cast<const int4*>(x) + const auto local_x = reinterpret_cast<const int4*>(x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4; local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) + const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot; local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// Unpack layout // Unpack layout
int offset, num_tokens_to_send; int offset, num_tokens_to_send;
...@@ -539,7 +550,7 @@ combine(void* combined_x, ...@@ -539,7 +550,7 @@ combine(void* combined_x,
const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4); const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);
// Copy directly to local rank, or copy to buffer and issue RDMA // Copy directly to local rank, or copy to buffer and issue RDMA
const auto src_idx = shfl_sync(__ldg(local_src_info + token_idx), 0); const auto src_idx = __ldg(local_src_info + token_idx);
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row); const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4); const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
if (dst_rank == rank) { if (dst_rank == rank) {
...@@ -549,21 +560,16 @@ combine(void* combined_x, ...@@ -549,21 +560,16 @@ combine(void* combined_x,
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr); const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy) if (not zero_copy)
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global); UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
//nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
#if defined(ROCM_DISABLE_CTX)
rocshmem::rocshmem_schar_put_nbi_wave(
#else
rocshmem::rocshmem_ctx_schar_put_nbi_wave(ctx,
#endif
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), hidden * sizeof(hip_bfloat16), dst_rank);
#if defined(ROCM_DISABLE_CTX) //nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
rocshmem::rocshmem_fence(); #if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_schar_put_nbi_warp(ctx,
#else #else
rocshmem::rocshmem_ctx_quiet(ctx); internode::shmemx_int8_put_nbi_warp(
#endif #endif
} reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr),
hidden * sizeof(hip_bfloat16), dst_rank);
}
} }
// Put finishing flag // Put finishing flag
...@@ -574,27 +580,49 @@ combine(void* combined_x, ...@@ -574,27 +580,49 @@ combine(void* combined_x,
} }
syncwarp(); syncwarp();
while (sync_large_warp_counters[warp_group_id] < num_warps_per_group); while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
if (sub_warp_id == 1 and lane_id == 0) { if (sub_warp_id == 1 and lane_id == 0) {
while (ld_acquire_global(atomic_clean_flag) == 0); while (ld_acquire_global(atomic_clean_flag) == 0);
if (dst_rank != rank) { if (dst_rank != rank) {
#if defined(ROCM_DISABLE_CTX) #if !defined(ROCM_DISABLE_CTX)
rocshmem::rocshmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank); internode::shmem_ctx_long_atomic_add(ctx,
#else #else
rocshmem::rocshmem_ctx_long_atomic_add(ctx, rdma_recv_flag + global_expert_idx, 1, dst_rank); internode::shmem_long_atomic_add(
#endif #endif
rdma_recv_flag + global_expert_idx, 1, dst_rank);
} else { } else {
st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1); st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
} }
atomic_add_release_global(atomic_clean_flag, -1); atomic_add_release_global(atomic_clean_flag, -1);
} }
syncwarp(); syncwarp();
if (num_ranks > 8){
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_quiet(ctx);
#else
internode::shmem_fence();
#endif
}
} }
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
// Receiving phase // Receiving phase
LOW_LATENCY_COMBINE_RECV: LOW_LATENCY_COMBINE_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0) if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return; return;
// if (num_ranks > 8){
// #if !defined(ROCM_DISABLE_CTX)
// internode::shmem_ctx_quiet(ctx);
// #else
// internode::shmem_fence();
// #endif
// }
// Wait all ranks to arrive and notify PCIe usage // Wait all ranks to arrive and notify PCIe usage
if (responsible_expert_idx < num_experts) { if (responsible_expert_idx < num_experts) {
EP_DEVICE_ASSERT(num_warps_per_group > 1); EP_DEVICE_ASSERT(num_warps_per_group > 1);
...@@ -642,9 +670,6 @@ combine(void* combined_x, ...@@ -642,9 +670,6 @@ combine(void* combined_x,
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4; (reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
} }
} }
#if !defined(ROCM_DISABLE_CTX)
rocshmem::rocshmem_wg_ctx_destroy(&ctx);
#endif
} }
void combine(void* combined_x, void combine(void* combined_x,
......
...@@ -5,10 +5,8 @@ ...@@ -5,10 +5,8 @@
#include "exception.cuh" #include "exception.cuh"
#include "launch.cuh" #include "launch.cuh"
#include "utils.cuh" #include "utils.cuh"
#include "shmem_wrapper.cuh"
#ifndef DISABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
#endif
namespace deep_ep { namespace deep_ep {
namespace intranode { namespace intranode {
...@@ -33,60 +31,66 @@ void barrier(int **barrier_signal_ptrs, int rank, int num_ranks, hipStream_t str ...@@ -33,60 +31,66 @@ void barrier(int **barrier_signal_ptrs, int rank, int num_ranks, hipStream_t str
namespace internode { namespace internode {
#ifndef DISABLE_ROCSHMEM #ifndef DISABLE_ROCSHMEM
rocshmem::rocshmem_team_t cpu_rdma_team = rocshmem::ROCSHMEM_TEAM_INVALID; shmem_team_t cpu_rdma_team = EP_SHMEM_TEAM_INVALID;
rocshmem::rocshmem_team_config_t cpu_rdma_team_config; shmem_team_config_t cpu_rdma_team_config;
std::vector<uint8_t> get_unique_id() { std::vector<uint8_t> get_unique_id() {
rocshmem::rocshmem_uniqueid_t unique_id; shmemx_uniqueid_t unique_id;
rocshmem::rocshmem_get_uniqueid(&unique_id); shmemx_get_uniqueid(&unique_id);
std::vector<uint8_t> result(sizeof(rocshmem::rocshmem_uniqueid_t)); std::vector<uint8_t> result(sizeof(shmemx_uniqueid_t));
std::memcpy(result.data(), &unique_id, sizeof(rocshmem::rocshmem_uniqueid_t)); std::memcpy(result.data(), &unique_id, sizeof(shmemx_uniqueid_t));
return result; return result;
} }
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode) {
bool low_latency_mode) { shmemx_uniqueid_t root_unique_id;
rocshmem::rocshmem_uniqueid_t root_unique_id; shmemx_init_attr_t attr;
rocshmem::rocshmem_init_attr_t attr; std::memcpy(&root_unique_id, root_unique_id_val.data(), sizeof(shmemx_uniqueid_t));
std::memcpy(&root_unique_id, root_unique_id_val.data(), sizeof(rocshmem::rocshmem_uniqueid_t)); shmemx_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr);
rocshmem::rocshmem_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr); shmemx_init_attr(EP_SHMEMX_INIT_WITH_UNIQUEID, &attr);
rocshmem::rocshmem_init_attr(rocshmem::ROCSHMEM_INIT_WITH_UNIQUEID, &attr);
// Create sub-RDMA teams // Create sub-RDMA teams
// NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used // NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used
if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) { if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) {
EP_HOST_ASSERT(cpu_rdma_team == rocshmem::ROCSHMEM_TEAM_INVALID); shmem_barrier_all();
EP_HOST_ASSERT(cpu_rdma_team == EP_SHMEM_TEAM_INVALID);
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0); EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(rocshmem::rocshmem_team_split_strided( EP_HOST_ASSERT(shmem_team_split_strided(
rocshmem::ROCSHMEM_TEAM_WORLD, rank % NUM_MAX_NVL_PEERS, EP_SHMEM_TEAM_WORLD, rank % NUM_MAX_NVL_PEERS,
NUM_MAX_NVL_PEERS, num_ranks / NUM_MAX_NVL_PEERS, NUM_MAX_NVL_PEERS, num_ranks / NUM_MAX_NVL_PEERS,
&cpu_rdma_team_config, 0, &cpu_rdma_team) == 0); &cpu_rdma_team_config, 0, &cpu_rdma_team) == 0);
EP_HOST_ASSERT(cpu_rdma_team != rocshmem::ROCSHMEM_TEAM_INVALID); EP_HOST_ASSERT(cpu_rdma_team != EP_SHMEM_TEAM_INVALID);
#ifdef FORCE_NVSHMEM_API
nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
CUDA_CHECK(hipGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
bool ibgda_is_initialized = false;
CUDA_CHECK(hipMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), hipMemcpyHostToDevice));
#endif
} }
rocshmem::rocshmem_barrier_all(); shmem_barrier_all();
return rocshmem::rocshmem_my_pe(); return shmem_my_pe();
} }
void *alloc(size_t size, size_t alignment) { void *alloc(size_t size, size_t alignment) {
auto alloc_size = ALIGN(size, alignment); return shmem_align(size, alignment);
return rocshmem::rocshmem_malloc(alloc_size);
} }
void free(void *ptr) { void free(void *ptr) {
rocshmem::rocshmem_free(ptr); shmem_free(ptr);
} }
void barrier() { void barrier() {
rocshmem::rocshmem_barrier_all(); shmem_barrier_all();
} }
void finalize() { void finalize() {
if (cpu_rdma_team != rocshmem::ROCSHMEM_TEAM_INVALID) { if (cpu_rdma_team != EP_SHMEM_TEAM_INVALID) {
rocshmem::rocshmem_team_destroy(cpu_rdma_team); shmem_team_destroy(cpu_rdma_team);
cpu_rdma_team = rocshmem::ROCSHMEM_TEAM_INVALID; cpu_rdma_team = EP_SHMEM_TEAM_INVALID;
} }
rocshmem::rocshmem_finalize(); shmem_finalize();
} }
#endif #endif
......
#pragma once
/*
* Temporary wrapper for platform-specific NVSHMEM and rocSHMEM functions.
* Once hipify or hipify-torch fully supports this mapping, this file should be
* removed and the corresponding nvshmem* calls restored.
*/
#ifndef DISABLE_ROCSHMEM
#include "configs.cuh"
#ifndef FORCE_NVSHMEM_API
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp8.h>
#include <hip/hip_runtime.h>
#include <rocshmem/rocshmem.hpp>
#else
#include <device_host_transport/nvshmem_common_ibgda.h>
#include <infiniband/mlx5dv.h>
#include <nvshmem.h>
#include <nvshmemx.h>
#include <non_abi/device/threadgroup/nvshmemi_common_device_defines.cuh>
#endif
namespace deep_ep::internode {

// Backend-neutral SHMEM facade.
//
// Exactly one of the two branches below is compiled:
//   * default                -> every name forwards to rocSHMEM (`rocshmem::...`)
//   * FORCE_NVSHMEM_API set  -> every name forwards to the NVSHMEM C API
// Both branches expose the same `shmem_*` / `shmemx_*` / `EP_SHMEM*` names so the
// rest of the code base can be written once against this header.

#ifndef FORCE_NVSHMEM_API
// ------------------------------------------------------------------------- //
// rocSHMEM wrapper
// ------------------------------------------------------------------------- //

// Type and constant aliases.
using shmem_team_t = rocshmem::rocshmem_team_t;
using shmem_team_config_t = rocshmem::rocshmem_team_config_t;
const shmem_team_t EP_SHMEM_TEAM_INVALID = rocshmem::ROCSHMEM_TEAM_INVALID;
// Bound by reference: this aliases rocSHMEM's own world-team object instead of copying it.
inline shmem_team_t& EP_SHMEM_TEAM_WORLD = rocshmem::ROCSHMEM_TEAM_WORLD;
using shmemx_uniqueid_t = rocshmem::rocshmem_uniqueid_t;
using shmemx_init_attr_t = rocshmem::rocshmem_init_attr_t;
constexpr auto EP_SHMEMX_INIT_WITH_UNIQUEID = rocshmem::ROCSHMEM_INIT_WITH_UNIQUEID;

// ---- Host-side bootstrap ------------------------------------------------- //

// Forwards to rocshmem_get_uniqueid and returns its status code.
__host__ inline int shmemx_get_uniqueid(shmemx_uniqueid_t *uid) {
    return rocshmem::rocshmem_get_uniqueid(uid);
}

// Forwards to rocshmem_set_attr_uniqueid_args and returns its status code.
__host__ inline int shmemx_set_attr_uniqueid_args(int rank, int nranks,
                                                  shmemx_uniqueid_t *uid, shmemx_init_attr_t *attr) {
    return rocshmem::rocshmem_set_attr_uniqueid_args(rank, nranks, uid, attr);
}

// Forwards to rocshmem_init_attr and returns its status code.
__host__ inline int shmemx_init_attr(unsigned int flags, shmemx_init_attr_t *attr) {
    return rocshmem::rocshmem_init_attr(flags, attr);
}

// Forwards to rocshmem_team_split_strided and returns its status code.
__host__ inline int shmem_team_split_strided(shmem_team_t parent_team,
                                             int start, int stride, int size,
                                             const shmem_team_config_t *config,
                                             long config_mask, shmem_team_t *new_team) {
    return rocshmem::rocshmem_team_split_strided(parent_team, start, stride, size, config, config_mask, new_team);
}

// ---- Barriers ------------------------------------------------------------ //

// Host-callable barrier; forwards to rocshmem_barrier_all.
__host__ inline void shmem_barrier_all() {
    rocshmem::rocshmem_barrier_all();
}

// Device-callable counterpart of shmem_barrier_all.
__device__ inline void shmem_device_barrier_all() {
    rocshmem::rocshmem_barrier_all();
}

// Team barrier issued on rocSHMEM's default context.
__device__ inline void shmem_barrier(shmem_team_t team) {
    rocshmem::rocshmem_ctx_barrier(rocshmem::ROCSHMEM_CTX_DEFAULT, team);
}

// ---- Host-side queries, memory and teardown ------------------------------ //

// Forwards to rocshmem_my_pe.
__host__ inline int shmem_my_pe() {
    return rocshmem::rocshmem_my_pe();
}

// Forwards to rocshmem_free.
__host__ inline void shmem_free(void *ptr) {
    rocshmem::rocshmem_free(ptr);
}

// Allocates `size` bytes, rounded up to a multiple of `alignment`, via rocshmem_malloc.
//
// NOTE: the argument order is (size, alignment) because that is the order in which
// this wrapper is invoked (`shmem_align(size, alignment)`). The parameters used to be
// named the other way round, which made the round-up below operate on swapped values.
__host__ inline void* shmem_align(const size_t size, const size_t alignment) {
    auto alloc_size = ALIGN(size, alignment);
    return rocshmem::rocshmem_malloc(alloc_size);
}

// Forwards to rocshmem_finalize.
__host__ inline void shmem_finalize() {
    rocshmem::rocshmem_finalize();
}

// Forwards to rocshmem_team_destroy.
__host__ inline void shmem_team_destroy(shmem_team_t team) {
    rocshmem::rocshmem_team_destroy(team);
}

// ---- Device-side operations on the default context ----------------------- //

// Forwards to rocshmem_fence.
__device__ inline void shmem_fence() {
    rocshmem::rocshmem_fence();
}

// Forwards to rocshmem_int_put_nbi.
__device__ inline void shmem_int_put_nbi(
    int *dest, const int *source, size_t nelems, int pe) {
    rocshmem::rocshmem_int_put_nbi(dest, source, nelems, pe);
}

// "warp" in the wrapper name maps onto rocSHMEM's "_wave" variant.
__device__ inline void shmemx_int_put_nbi_warp(
    int *dest, const int *source, size_t nelems, int pe) {
    rocshmem::rocshmem_int_put_nbi_wave(dest, source, nelems, pe);
}

// int8 payloads are sent through rocSHMEM's signed-char ("schar") "_wave" put.
__device__ inline void shmemx_int8_put_nbi_warp(
    signed char *dest, const signed char *source, size_t nelems, int pe) {
    rocshmem::rocshmem_schar_put_nbi_wave(dest, source, nelems, pe);
}

// Forwards to rocshmem_long_atomic_add.
__device__ inline void shmem_long_atomic_add(
    long *dest, long value, int pe) {
    rocshmem::rocshmem_long_atomic_add(dest, value, pe);
}

#if !defined(ROCM_DISABLE_CTX)
// ---- Device-side operations on an explicit context ----------------------- //
// Only compiled when ROCM_DISABLE_CTX is not defined.

using shmem_ctx_t = rocshmem::rocshmem_ctx_t;

// Creates a context through rocshmem_wg_ctx_create with an options value of 0.
__device__ inline int shmem_wg_ctx_create(shmem_ctx_t *ctx) {
    return rocshmem::rocshmem_wg_ctx_create(0, ctx);
}

// Forwards to rocshmem_wg_ctx_destroy.
__device__ inline void shmem_wg_ctx_destroy(shmem_ctx_t *ctx) {
    rocshmem::rocshmem_wg_ctx_destroy(ctx);
}

// Forwards to rocshmem_ctx_quiet.
__device__ inline void shmem_ctx_quiet(shmem_ctx_t ctx) {
    rocshmem::rocshmem_ctx_quiet(ctx);
}

// Forwards to rocshmem_ctx_ulong_atomic_add.
__device__ inline void shmem_ctx_ulong_atomic_add(
    shmem_ctx_t ctx, uint64_t *dest, uint64_t value, int pe) {
    rocshmem::rocshmem_ctx_ulong_atomic_add(ctx, dest, value, pe);
}

// Forwards to rocshmem_ctx_long_atomic_add.
__device__ inline void shmem_ctx_long_atomic_add(
    shmem_ctx_t ctx, long *dest, long value, int pe) {
    rocshmem::rocshmem_ctx_long_atomic_add(ctx, dest, value, pe);
}

// Context-bound signed-char put; "warp" maps onto rocSHMEM's "_wave" variant.
__device__ inline void shmem_ctx_schar_put_nbi_warp(
    shmem_ctx_t ctx, signed char *dest, const signed char *source, size_t nelems, int pe) {
    rocshmem::rocshmem_ctx_schar_put_nbi_wave(ctx, dest, source, nelems, pe);
}

// Context-bound int put; "warp" maps onto rocSHMEM's "_wave" variant.
__device__ inline void shmem_ctx_int_put_nbi_warp(
    shmem_ctx_t ctx, int *dest, const int *source, size_t nelems, int pe) {
    rocshmem::rocshmem_ctx_int_put_nbi_wave(ctx, dest, source, nelems, pe);
}
#endif

#else
// ------------------------------------------------------------------------- //
// NVSHMEM wrapper
// ------------------------------------------------------------------------- //

// This branch declares no `shmem_ctx_*` wrappers, so force ROCM_DISABLE_CTX on:
// code guarded by `#if !defined(ROCM_DISABLE_CTX)` is then compiled out.
#ifndef ROCM_DISABLE_CTX
#define ROCM_DISABLE_CTX
#endif

// Type and constant aliases.
using shmem_team_t = nvshmem_team_t;
using shmem_team_config_t = nvshmem_team_config_t;
using shmemx_uniqueid_t = nvshmemx_uniqueid_t;
using shmemx_init_attr_t = nvshmemx_init_attr_t;
const shmem_team_t EP_SHMEM_TEAM_INVALID = NVSHMEM_TEAM_INVALID;
const shmem_team_t EP_SHMEM_TEAM_WORLD = NVSHMEM_TEAM_WORLD;
constexpr auto EP_SHMEMX_INIT_WITH_UNIQUEID = NVSHMEMX_INIT_WITH_UNIQUEID;

// ---- Host-side bootstrap ------------------------------------------------- //

// Forwards to nvshmemx_get_uniqueid and returns its status code.
__host__ inline int shmemx_get_uniqueid(shmemx_uniqueid_t *uid) {
    return nvshmemx_get_uniqueid(uid);
}

// Forwards to nvshmemx_set_attr_uniqueid_args and returns its status code.
__host__ inline int shmemx_set_attr_uniqueid_args(int rank, int nranks,
                                                  shmemx_uniqueid_t *uid, shmemx_init_attr_t *attr) {
    return nvshmemx_set_attr_uniqueid_args(rank, nranks, uid, attr);
}

// Forwards to nvshmemx_init_attr and returns its status code.
__host__ inline int shmemx_init_attr(unsigned int flags, shmemx_init_attr_t *attr) {
    return nvshmemx_init_attr(flags, attr);
}

// Forwards to nvshmem_team_split_strided and returns its status code.
__host__ inline int shmem_team_split_strided(shmem_team_t parent_team,
                                             int start, int stride, int size,
                                             const shmem_team_config_t *config,
                                             long config_mask, shmem_team_t *new_team) {
    return nvshmem_team_split_strided(parent_team, start, stride, size, config, config_mask, new_team);
}

// ---- Barriers ------------------------------------------------------------ //

// Host-callable barrier; forwards to nvshmem_barrier_all.
__host__ inline void shmem_barrier_all() {
    nvshmem_barrier_all();
}

// Device-callable counterpart of shmem_barrier_all.
__device__ inline void shmem_device_barrier_all() {
    nvshmem_barrier_all();
}

// Team barrier; the value returned by nvshmem_barrier is deliberately discarded.
__device__ inline void shmem_barrier(shmem_team_t team) {
    void(nvshmem_barrier(team));
}

// ---- Host-side queries, memory and teardown ------------------------------ //

// Forwards to nvshmem_my_pe.
__host__ inline int shmem_my_pe() {
    return nvshmem_my_pe();
}

// Forwards to nvshmem_free.
__host__ inline void shmem_free(void *ptr) {
    nvshmem_free(ptr);
}

// Allocates `size` bytes with the requested `alignment` via nvshmem_align.
//
// NOTE: the argument order is (size, alignment) because that is the order in which
// this wrapper is invoked; nvshmem_align itself takes (alignment, size), hence the
// swap in the forwarded call. The values handed to NVSHMEM are the same as before;
// only the misleading parameter names were corrected.
__host__ inline void* shmem_align(const size_t size, const size_t alignment) {
    return nvshmem_align(alignment, size);
}

// Forwards to nvshmem_finalize.
__host__ inline void shmem_finalize() {
    nvshmem_finalize();
}

// Forwards to nvshmem_team_destroy.
__host__ inline void shmem_team_destroy(shmem_team_t team) {
    nvshmem_team_destroy(team);
}

// ---- Device-side operations ---------------------------------------------- //

// Forwards to nvshmem_fence.
__device__ inline void shmem_fence() {
    nvshmem_fence();
}

// Forwards to nvshmem_int_put_nbi.
__device__ inline void shmem_int_put_nbi(
    int *dest, const int *source, size_t nelems, int pe) {
    nvshmem_int_put_nbi(dest, source, nelems, pe);
}

// Forwards to nvshmemx_int_put_nbi_warp.
__device__ inline void shmemx_int_put_nbi_warp(
    int *dest, const int *source, size_t nelems, int pe) {
    nvshmemx_int_put_nbi_warp(dest, source, nelems, pe);
}

// Forwards to nvshmemx_int8_put_nbi_warp.
__device__ inline void shmemx_int8_put_nbi_warp(
    signed char *dest, const signed char *source, size_t nelems, int pe) {
    nvshmemx_int8_put_nbi_warp(dest, source, nelems, pe);
}

// Issues nvshmemx_signal_op with the NVSHMEM_SIGNAL_ADD operation.
__device__ inline void shmem_signal_op_add(
    uint64_t *dest, uint64_t value, int pe) {
    nvshmemx_signal_op(dest, value, NVSHMEM_SIGNAL_ADD, pe);
}

// Forwards to nvshmem_ulong_atomic_add.
__device__ inline void shmem_ulong_atomic_add(
    uint64_t *dest, uint64_t value, int pe) {
    nvshmem_ulong_atomic_add(dest, value, pe);
}

// Forwards to nvshmem_long_atomic_add.
__device__ inline void shmem_long_atomic_add(
    long *dest, long value, int pe) {
    // nvshmem_##Name##_atomic_add(dest, value, pe);
    nvshmem_long_atomic_add(dest, value, pe);
}
#endif

} // namespace deep_ep::internode
#endif
...@@ -342,7 +342,7 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) { ...@@ -342,7 +342,7 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return *reinterpret_cast<dtype_t *>(recv_int_values); return *reinterpret_cast<dtype_t *>(recv_int_values);
} }
#ifdef USE_ROCM #ifndef FORCE_NVSHMEM_API
constexpr float kFP8Margin = 1e-4; constexpr float kFP8Margin = 1e-4;
constexpr float kFinfoAmaxE4M3 = 240.0f; constexpr float kFinfoAmaxE4M3 = 240.0f;
constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3; constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment