Commit ee3551ab authored by lishen's avatar lishen
Browse files

修改为兼容rocSHMEM和nvSHMEM的代码

parent e18f726a
......@@ -8,10 +8,21 @@ fi
PYTHON_INCLUDE=$(python3 -c "from sysconfig import get_paths; print(get_paths()['include'])")
PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()['platlib'])")
# --------------------------------------------------------------------- #
USE_NVSHMEM=${USE_NVSHMEM:=OFF}
ROCSHMEM_INSTALL_PREFIX=${ROCSHMEM_INSTALL_PREFIX:=$(pwd)/rocshmem_dir}
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type}
SHMEM_LINK_OPTIONS=${SHMEM_LINK_OPTIONS:="-Wl,-rpath,${ROCSHMEM_INSTALL_PREFIX}/lib/ -l:librocshmem.a"}
####
# 检查是否设置了USE_NVSHMEM环境变量
if [ "$USE_NVSHMEM" == "ON" ]; then
COMPILE_OPTIONS+=" -DFORCE_NVSHMEM_API"
ROCSHMEM_INSTALL_PREFIX=???/dushmem_dir
SHMEM_LINK_OPTIONS="-Wl,-rpath,${ROCSHMEM_INSTALL_PREFIX}/lib/ -l:libnvshmem_device.a -lnvshmem_host"
fi
INCLUDE_PATHS=${INCLUDE_PATHS:=-Icsrc/ -I${ROCSHMEM_INSTALL_PREFIX}/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE}}
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/runtime.cu -o build_/runtime.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/layout.cu -o build_/layout.o ${COMPILE_OPTIONS}
......@@ -20,7 +31,7 @@ hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode_ll.cu -o build_/internode_ll.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/deep_ep.cu -o build_/deep_ep.o ${COMPILE_OPTIONS}
hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o -L${ROCSHMEM_INSTALL_PREFIX}/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,$${ROCSHMEM_INSTALL_PREFIX}/lib/ -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 -l:librocshmem.a -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/opt/mpi/lib/ -libverbs -lmlx5
hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o -L${ROCSHMEM_INSTALL_PREFIX}/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 ${SHMEM_LINK_OPTIONS} -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/opt/mpi/lib/ -libverbs -lmlx5
# build whl
echo "Using Python: $(which python3)"
......
......@@ -3,11 +3,10 @@
#include "configs.cuh"
#include "launch.cuh"
#include "utils.cuh"
#include "shmem_wrapper.cuh"
#ifndef DISABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
// TODO: fix unroll warnings
// #ifdef __clang__
// #pragma clang diagnostic push
......@@ -19,7 +18,7 @@ namespace deep_ep {
namespace internode {
extern rocshmem::rocshmem_team_t cpu_rdma_team;
extern shmem_team_t cpu_rdma_team;
struct SourceMeta {
int src_rdma_rank, is_token_in_nvl_rank_bits;
......@@ -51,9 +50,8 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_
int num_topk_idx,
int num_topk_weights) {
return static_cast<int>(ALIGN(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) +
num_scales * sizeof(float) + num_topk_idx * sizeof(int) +
num_topk_weights * sizeof(float),
sizeof(int4)));
num_scales * sizeof(float) + num_topk_idx * sizeof(int) +
num_topk_weights * sizeof(float), sizeof(int4)));
}
__host__ __device__ __forceinline__ std::pair<int, int>
......@@ -61,9 +59,8 @@ get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_t
int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) {
// Return `int32_t` offset and count to clean
return {(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) *
num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) /
sizeof(int),
(NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms};
num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) / sizeof(int),
(NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms};
}
__host__ __device__ __forceinline__ std::pair<int, int>
get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
......@@ -74,10 +71,9 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to
"Invalid size of `SourceMeta`");
return {
(num_nvl_recv_buffer_tokens *
(hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) +
num_topk_weights * sizeof(float) + sizeof(SourceMeta)) *
num_nvl_ranks * num_sms) /
sizeof(int),
(hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) +
num_topk_weights * sizeof(float) + sizeof(SourceMeta)) *
num_nvl_ranks * num_sms) / sizeof(int),
num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms,
};
}
......@@ -90,12 +86,10 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
template <bool kLowLatencyMode>
__forceinline__ __device__ void
nvshmem_barrier_with_same_gpu_idx(const rocshmem::rocshmem_team_t &rdma_team) {
nvshmem_barrier_with_same_gpu_idx(const shmem_team_t &rdma_team) {
// NOTE: shmem_device_barrier_all() might be an issue as
// it doesn't follow OpenSHMEM specification on ROCm
kLowLatencyMode
? void(rocshmem::rocshmem_ctx_barrier(rocshmem::ROCSHMEM_CTX_DEFAULT, rdma_team))
: rocshmem::rocshmem_barrier_all();
kLowLatencyMode ? shmem_barrier(rdma_team) : shmem_device_barrier_all();
}
template <bool kLowLatencyMode, int kNumRDMARanks>
......@@ -109,7 +103,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
int *rdma_channel_prefix_matrix, int *recv_rdma_rank_prefix_sum,
int *gbl_channel_prefix_matrix, int *recv_gbl_rank_prefix_sum,
void *rdma_buffer_ptr, void **buffer_ptrs, int **barrier_signal_ptrs, int rank,
const rocshmem::rocshmem_team_t rdma_team) {
const shmem_team_t rdma_team) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize,
lane_id = get_lane_id();
......@@ -159,7 +153,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
// TODO: more light fence or barrier or signaling
// TODO: overlap EP barrier and NVL cleaning
if (thread_id < kNumRDMARanks) {
rocshmem::rocshmem_int_put_nbi(
shmem_int_put_nbi(
rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
rdma_recv_num_tokens_mixed.send_buffer(thread_id),
NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
......@@ -405,9 +399,10 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
kForwarderCoordinator, // 向远端RDMA确认接收
kNVLReceivers // 从nvl缓存写入到recv_x
};
__shared__ rocshmem::rocshmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx);
#ifndef FORCE_NVSHMEM_API
__shared__ shmem_ctx_t ctx;
shmem_wg_ctx_create(&ctx);
#endif
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
......@@ -521,13 +516,23 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
syncwarp();
if (dst_rdma_rank != rdma_rank) {
rocshmem::rocshmem_ctx_int_put_nbi_wave(
ctx, rdma_channel_meta.recv_buffer(rdma_rank),
#ifndef FORCE_NVSHMEM_API
shmem_ctx_int_put_nbi_warp(ctx,
#else
shmemx_int_put_nbi_warp(
#endif
rdma_channel_meta.recv_buffer(rdma_rank),
rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
}
rocshmem::rocshmem_ctx_quiet(ctx);
#ifndef FORCE_NVSHMEM_API
shmem_ctx_quiet(ctx);
#else
shmem_fence();
#endif
// sync_rdma_sender_smem();
__syncthreads();
......@@ -736,15 +741,22 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
if(dst_rdma_rank != rdma_rank) {
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
rocshmem::rocshmem_ctx_schar_put_nbi_wave(
ctx,
#ifndef FORCE_NVSHMEM_API
shmem_ctx_schar_put_nbi_warp(ctx,
#else
shmemx_int8_put_nbi_warp(
#endif
rdma_channel_data.recv_buffer(rdma_rank) +
dst_slot_idx * num_bytes_per_rdma_token,
rdma_channel_data.send_buffer(dst_rdma_rank) +
dst_slot_idx * num_bytes_per_rdma_token,
num_bytes_per_rdma_token * num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
rocshmem::rocshmem_ctx_quiet(ctx);
#ifndef FORCE_NVSHMEM_API
shmem_ctx_quiet(ctx);
#else
shmem_fence();
#endif
} else {
// 对于本地RDMA秩,使用较轻的内存屏障
memory_fence();
......@@ -756,8 +768,12 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
last_issued_tail += num_tokens_to_issue;
num_tokens_to_send -= num_tokens_to_issue;
// 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
rocshmem::rocshmem_ctx_ulong_atomic_add(
ctx, rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
#ifndef FORCE_NVSHMEM_API
shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
#endif
rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
}
......@@ -992,8 +1008,12 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
// 更新远程头部
if(min_head != std::numeric_limits<int>::max() && min_head >= last_head + num_max_rdma_chunked_send_tokens && lane_id < kNumRDMARanks){
rocshmem::rocshmem_ctx_ulong_atomic_add(
ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_head,
#ifndef FORCE_NVSHMEM_API
shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
#endif
rdma_channel_head.buffer(rdma_rank), min_head - last_head,
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
last_head = min_head;
}
......@@ -1107,7 +1127,9 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
}
} // while(num_tokens_to_recv > 0)
}
rocshmem::rocshmem_wg_ctx_destroy(&ctx);
#ifndef FORCE_NVSHMEM_API
shmem_wg_ctx_destroy(&ctx);
#endif
}
void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
......@@ -1166,7 +1188,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
int num_channels, const int *rdma_channel_prefix_matrix,
const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr,
void **buffer_ptrs, int **barrier_signal_ptrs, int rank, int num_ranks,
bool is_cached_dispatch, const rocshmem::rocshmem_team_t rdma_team) {
bool is_cached_dispatch, const shmem_team_t rdma_team) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x);
auto num_threads = static_cast<int>(blockDim.x);
......@@ -1189,7 +1211,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
auto rdma_buffer_ptr_int = reinterpret_cast<int *>(rdma_buffer_ptr);
for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;
rocshmem::rocshmem_fence();
shmem_fence();
__syncthreads();
// Barrier again
......@@ -1395,9 +1417,10 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
kRDMACoordinator,
kNVLCoordinator
};
__shared__ rocshmem::rocshmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx);
#ifndef FORCE_NVSHMEM_API
__shared__ shmem_ctx_t ctx;
shmem_wg_ctx_create(&ctx);
#endif
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
......@@ -1721,16 +1744,22 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
if(sub_warp_id == kNumWarpsPerForwarder - 1) {
if(dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
rocshmem::rocshmem_ctx_schar_put_nbi_wave(
ctx,
#ifndef FORCE_NVSHMEM_API
shmem_ctx_schar_put_nbi_warp(ctx,
#else
shmemx_int8_put_nbi_warp(
#endif
rdma_channel_data.recv_buffer(rdma_rank) +
rdma_slot_idx * num_bytes_per_rdma_token,
rdma_channel_data.send_buffer(dst_rdma_rank) +
rdma_slot_idx * num_bytes_per_rdma_token,
num_chunked_tokens * num_bytes_per_rdma_token,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
rocshmem::rocshmem_ctx_quiet(ctx);
#ifndef FORCE_NVSHMEM_API
shmem_ctx_quiet(ctx);
#else
shmem_fence();
#endif
} else {
memory_fence();
}
......@@ -1738,8 +1767,12 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Write new RDMA tail
syncwarp();
if(lane_id == 0) {
rocshmem::rocshmem_ctx_ulong_atomic_add(
ctx, rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
#ifndef FORCE_NVSHMEM_API
shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
#endif
rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
}
......@@ -1867,8 +1900,12 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
rocshmem::rocshmem_ctx_ulong_atomic_add(
ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
#ifndef FORCE_NVSHMEM_API
shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
#endif
rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
last_rdma_head = min_head;
......@@ -1880,7 +1917,9 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
}
}
}
rocshmem::rocshmem_wg_ctx_destroy(&ctx);
#ifndef FORCE_NVSHMEM_API
shmem_wg_ctx_destroy(&ctx);
#endif
}
void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
......
......@@ -11,9 +11,8 @@
// low latency+RocSHMEM has issue with CTX.
#define ROCM_DISABLE_CTX
#include <rocshmem/rocshmem.hpp>
#include "shmem_wrapper.cuh"
using namespace rocshmem;
namespace deep_ep {
namespace internode_ll {
......@@ -59,7 +58,7 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t* clean_1, int num_clean_int_1) {
// Barrier before cleaning (in case of unfinished chunked EP)
if (threadIdx.x == 0)
rocshmem::rocshmem_barrier_all();
internode::shmem_device_barrier_all();
// Clean
auto thread_id = static_cast<int>(threadIdx.x);
......@@ -72,7 +71,7 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
// Barrier after cleaning (make sure low-latency mode work
if (threadIdx.x == 0)
rocshmem::rocshmem_barrier_all();
internode::shmem_device_barrier_all();
}
void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
......@@ -100,8 +99,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int num_warp_groups, int num_warps_per_group,
bool round_scale, int phases) {
#if !defined(ROCM_DISABLE_CTX)
__shared__ rocshmem::rocshmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx);
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
const auto sm_id = static_cast<int>(blockIdx.x);
......@@ -221,9 +220,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
slot_idx * num_bytes_per_msg;
if (dst_rank != rank) {
rocshmem::rocshmem_schar_put_nbi_wave(reinterpret_cast<signed char*>(dst_ptr),
internode::shmemx_int8_put_nbi_warp(reinterpret_cast<signed char*>(dst_ptr),
reinterpret_cast<signed char*>(src_ptr), num_bytes_per_msg, dst_rank);
rocshmem::rocshmem_fence();
internode::shmem_fence();
} else {
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
......@@ -288,7 +287,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Wait local sends issued and send expert counts
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
if (dst_rank != rank) {
rocshmem::rocshmem_long_atomic_add( rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
internode::shmem_long_atomic_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
} else {
st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
}
......@@ -396,7 +395,7 @@ LOW_LATENCY_DISPATCH_RECV:
}
#if !defined(ROCM_DISABLE_CTX)
rocshmem::rocshmem_wg_ctx_destroy(&ctx);
internode::shmem_wg_ctx_destroy(&ctx);
#endif
}
......@@ -467,8 +466,8 @@ combine(void* combined_x,
int phases, bool zero_copy) {
#if !defined(ROCM_DISABLE_CTX)
__shared__ rocshmem::rocshmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx);
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_sms = static_cast<int>(gridDim.x);
......@@ -539,7 +538,7 @@ combine(void* combined_x,
const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);
// Copy directly to local rank, or copy to buffer and issue RDMA
const auto src_idx = shfl_sync(__ldg(local_src_info + token_idx), 0);
const auto src_idx = __ldg(local_src_info + token_idx);
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
if (dst_rank == rank) {
......@@ -552,16 +551,16 @@ combine(void* combined_x,
//nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
#if defined(ROCM_DISABLE_CTX)
rocshmem::rocshmem_schar_put_nbi_wave(
internode::shmemx_int8_put_nbi_warp(
#else
rocshmem::rocshmem_ctx_schar_put_nbi_wave(ctx,
internode::shmem_ctx_schar_put_nbi_warp(ctx,
#endif
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), hidden * sizeof(hip_bfloat16), dst_rank);
#if defined(ROCM_DISABLE_CTX)
rocshmem::rocshmem_fence();
internode::shmem_fence();
#else
rocshmem::rocshmem_ctx_quiet(ctx);
internode::shmem_ctx_quiet(ctx);
#endif
}
}
......@@ -578,9 +577,9 @@ combine(void* combined_x,
while (ld_acquire_global(atomic_clean_flag) == 0);
if (dst_rank != rank) {
#if defined(ROCM_DISABLE_CTX)
rocshmem::rocshmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank);
internode::shmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank);
#else
rocshmem::rocshmem_ctx_long_atomic_add(ctx, rdma_recv_flag + global_expert_idx, 1, dst_rank);
internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_flag + global_expert_idx, 1, dst_rank);
#endif
} else {
st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
......@@ -643,7 +642,7 @@ combine(void* combined_x,
}
}
#if !defined(ROCM_DISABLE_CTX)
rocshmem::rocshmem_wg_ctx_destroy(&ctx);
internode::shmem_wg_ctx_destroy(&ctx);
#endif
}
......
......@@ -5,10 +5,8 @@
#include "exception.cuh"
#include "launch.cuh"
#include "utils.cuh"
#include "shmem_wrapper.cuh"
#ifndef DISABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
#endif
namespace deep_ep {
namespace intranode {
......@@ -33,60 +31,66 @@ void barrier(int **barrier_signal_ptrs, int rank, int num_ranks, hipStream_t str
namespace internode {
#ifndef DISABLE_ROCSHMEM
rocshmem::rocshmem_team_t cpu_rdma_team = rocshmem::ROCSHMEM_TEAM_INVALID;
rocshmem::rocshmem_team_config_t cpu_rdma_team_config;
shmem_team_t cpu_rdma_team = EP_SHMEM_TEAM_INVALID;
shmem_team_config_t cpu_rdma_team_config;
std::vector<uint8_t> get_unique_id() {
rocshmem::rocshmem_uniqueid_t unique_id;
rocshmem::rocshmem_get_uniqueid(&unique_id);
std::vector<uint8_t> result(sizeof(rocshmem::rocshmem_uniqueid_t));
std::memcpy(result.data(), &unique_id, sizeof(rocshmem::rocshmem_uniqueid_t));
shmemx_uniqueid_t unique_id;
shmemx_get_uniqueid(&unique_id);
std::vector<uint8_t> result(sizeof(shmemx_uniqueid_t));
std::memcpy(result.data(), &unique_id, sizeof(shmemx_uniqueid_t));
return result;
}
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks,
bool low_latency_mode) {
rocshmem::rocshmem_uniqueid_t root_unique_id;
rocshmem::rocshmem_init_attr_t attr;
std::memcpy(&root_unique_id, root_unique_id_val.data(), sizeof(rocshmem::rocshmem_uniqueid_t));
rocshmem::rocshmem_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr);
rocshmem::rocshmem_init_attr(rocshmem::ROCSHMEM_INIT_WITH_UNIQUEID, &attr);
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode) {
shmemx_uniqueid_t root_unique_id;
shmemx_init_attr_t attr;
std::memcpy(&root_unique_id, root_unique_id_val.data(), sizeof(shmemx_uniqueid_t));
shmemx_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr);
shmemx_init_attr(EP_SHMEMX_INIT_WITH_UNIQUEID, &attr);
// Create sub-RDMA teams
// NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used
if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) {
EP_HOST_ASSERT(cpu_rdma_team == rocshmem::ROCSHMEM_TEAM_INVALID);
shmem_barrier_all();
EP_HOST_ASSERT(cpu_rdma_team == EP_SHMEM_TEAM_INVALID);
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(rocshmem::rocshmem_team_split_strided(
rocshmem::ROCSHMEM_TEAM_WORLD, rank % NUM_MAX_NVL_PEERS,
EP_HOST_ASSERT(shmem_team_split_strided(
EP_SHMEM_TEAM_WORLD, rank % NUM_MAX_NVL_PEERS,
NUM_MAX_NVL_PEERS, num_ranks / NUM_MAX_NVL_PEERS,
&cpu_rdma_team_config, 0, &cpu_rdma_team) == 0);
EP_HOST_ASSERT(cpu_rdma_team != rocshmem::ROCSHMEM_TEAM_INVALID);
EP_HOST_ASSERT(cpu_rdma_team != EP_SHMEM_TEAM_INVALID);
#ifdef FORCE_NVSHMEM_API
nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
CUDA_CHECK(hipGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
bool ibgda_is_initialized = false;
CUDA_CHECK(hipMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), hipMemcpyHostToDevice));
#endif
}
rocshmem::rocshmem_barrier_all();
return rocshmem::rocshmem_my_pe();
shmem_barrier_all();
return shmem_my_pe();
}
void *alloc(size_t size, size_t alignment) {
auto alloc_size = ALIGN(size, alignment);
return rocshmem::rocshmem_malloc(alloc_size);
return shmem_align(size, alignment);
}
void free(void *ptr) {
rocshmem::rocshmem_free(ptr);
shmem_free(ptr);
}
void barrier() {
rocshmem::rocshmem_barrier_all();
shmem_barrier_all();
}
void finalize() {
if (cpu_rdma_team != rocshmem::ROCSHMEM_TEAM_INVALID) {
rocshmem::rocshmem_team_destroy(cpu_rdma_team);
cpu_rdma_team = rocshmem::ROCSHMEM_TEAM_INVALID;
if (cpu_rdma_team != EP_SHMEM_TEAM_INVALID) {
shmem_team_destroy(cpu_rdma_team);
cpu_rdma_team = EP_SHMEM_TEAM_INVALID;
}
rocshmem::rocshmem_finalize();
shmem_finalize();
}
#endif
......
#pragma once
/*
* Temporary wrapper for for platform specific NVSHMEM and rocSHMEM functions.
* Once hipify or hipify-torch fully supports this mapping, this file has to be
* removed and according nvshmem* functions restored.
*/
#ifndef DISABLE_ROCSHMEM
#include "configs.cuh"
#ifndef FORCE_NVSHMEM_API
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp8.h>
#include <hip/hip_runtime.h>
#include <rocshmem/rocshmem.hpp>
#else
#include <device_host_transport/nvshmem_common_ibgda.h>
#include <infiniband/mlx5dv.h>
#include <nvshmem.h>
#include <nvshmemx.h>
#include <non_abi/device/threadgroup/nvshmemi_common_device_defines.cuh>
#endif
namespace deep_ep::internode {
// rocSHMEM wrapper
#ifndef FORCE_NVSHMEM_API
using shmem_team_t = rocshmem::rocshmem_team_t;
using shmem_team_config_t = rocshmem::rocshmem_team_config_t;
const shmem_team_t EP_SHMEM_TEAM_INVALID = rocshmem::ROCSHMEM_TEAM_INVALID;
inline shmem_team_t& EP_SHMEM_TEAM_WORLD = rocshmem::ROCSHMEM_TEAM_WORLD;
using shmemx_uniqueid_t = rocshmem::rocshmem_uniqueid_t;
using shmemx_init_attr_t = rocshmem::rocshmem_init_attr_t;
constexpr auto EP_SHMEMX_INIT_WITH_UNIQUEID = rocshmem::ROCSHMEM_INIT_WITH_UNIQUEID;
__host__ inline int shmemx_get_uniqueid(shmemx_uniqueid_t *uid) {
return rocshmem::rocshmem_get_uniqueid(uid);
}
__host__ inline int shmemx_set_attr_uniqueid_args(int rank, int nranks,
shmemx_uniqueid_t *uid, shmemx_init_attr_t *attr) {
return rocshmem::rocshmem_set_attr_uniqueid_args(rank, nranks, uid, attr);
}
__host__ inline int shmemx_init_attr(unsigned int flags, shmemx_init_attr_t *attr) {
return rocshmem::rocshmem_init_attr(flags, attr);
}
__host__ inline int shmem_team_split_strided(shmem_team_t parent_team,
int start, int stride, int size,
const shmem_team_config_t *config,
long config_mask, shmem_team_t *new_team) {
return rocshmem::rocshmem_team_split_strided(parent_team, start, stride, size, config, config_mask, new_team);
}
__host__ inline void shmem_barrier_all() {
rocshmem::rocshmem_barrier_all();
}
__device__ inline void shmem_device_barrier_all() {
rocshmem::rocshmem_barrier_all();
}
__device__ inline void shmem_barrier(shmem_team_t team) {
rocshmem::rocshmem_ctx_barrier(rocshmem::ROCSHMEM_CTX_DEFAULT, team);
}
__host__ inline int shmem_my_pe(){
return rocshmem::rocshmem_my_pe();
}
__host__ inline void shmem_free(void *ptr){
rocshmem::rocshmem_free(ptr);
}
__host__ inline void* shmem_align(const size_t alignment, const size_t size) {
auto alloc_size = ALIGN(size, alignment);
return rocshmem::rocshmem_malloc(alloc_size);
}
__host__ inline void shmem_finalize() {
rocshmem::rocshmem_finalize();
}
__host__ inline void shmem_team_destroy(shmem_team_t team) {
rocshmem::rocshmem_team_destroy(team);
}
__device__ inline void shmem_fence() {
rocshmem::rocshmem_fence();
}
__device__ inline void shmem_int_put_nbi(
int *dest, const int *source, size_t nelems, int pe) {
rocshmem::rocshmem_int_put_nbi(dest, source, nelems, pe);
}
__device__ inline void shmemx_int_put_nbi_warp(
int *dest, const int *source, size_t nelems, int pe) {
rocshmem::rocshmem_int_put_nbi_wave(dest, source, nelems, pe);
}
__device__ inline void shmemx_int8_put_nbi_warp(
signed char *dest, const signed char *source, size_t nelems, int pe) {
rocshmem::rocshmem_schar_put_nbi_wave(dest, source, nelems, pe);
}
__device__ inline void shmem_long_atomic_add(
long *dest, long value, int pe) {
rocshmem::rocshmem_long_atomic_add(dest, value, pe);
}
#if !defined(ROCM_DISABLE_CTX)
using shmem_ctx_t = rocshmem::rocshmem_ctx_t;
__device__ inline int shmem_wg_ctx_create(shmem_ctx_t *ctx) {
return rocshmem::rocshmem_wg_ctx_create(0, ctx);
}
__device__ inline void shmem_wg_ctx_destroy(shmem_ctx_t *ctx) {
rocshmem::rocshmem_wg_ctx_destroy(ctx);
}
__device__ inline void shmem_ctx_quiet(shmem_ctx_t ctx) {
rocshmem::rocshmem_ctx_quiet(ctx);
}
__device__ inline void shmem_ctx_ulong_atomic_add(
shmem_ctx_t ctx, uint64_t *dest, uint64_t value, int pe) {
rocshmem::rocshmem_ctx_ulong_atomic_add(ctx, dest, value, pe);
}
__device__ inline void shmem_ctx_long_atomic_add(
shmem_ctx_t ctx, long *dest, long value, int pe) {
rocshmem::rocshmem_ctx_long_atomic_add(ctx, dest, value, pe);
}
__device__ inline void shmem_ctx_schar_put_nbi_warp(
shmem_ctx_t ctx, signed char *dest, const signed char *source, size_t nelems, int pe) {
rocshmem::rocshmem_ctx_schar_put_nbi_wave(ctx, dest, source, nelems, pe);
}
__device__ inline void shmem_ctx_int_put_nbi_warp(
shmem_ctx_t ctx, int *dest, const int *source, size_t nelems, int pe) {
rocshmem::rocshmem_ctx_int_put_nbi_wave(ctx, dest, source, nelems, pe);
}
#endif
#else
// NVSHMEM wrapper
#ifndef ROCM_DISABLE_CTX
#define ROCM_DISABLE_CTX
#endif
using shmem_team_t = nvshmem_team_t;
using shmem_team_config_t = nvshmem_team_config_t;
using shmemx_uniqueid_t = nvshmemx_uniqueid_t;
using shmemx_init_attr_t = nvshmemx_init_attr_t;
const shmem_team_t EP_SHMEM_TEAM_INVALID = NVSHMEM_TEAM_INVALID;
const shmem_team_t EP_SHMEM_TEAM_WORLD = NVSHMEM_TEAM_WORLD;
constexpr auto EP_SHMEMX_INIT_WITH_UNIQUEID = NVSHMEMX_INIT_WITH_UNIQUEID;
__host__ inline int shmemx_get_uniqueid(shmemx_uniqueid_t *uid) {
return nvshmemx_get_uniqueid(uid);
}
__host__ inline int shmemx_set_attr_uniqueid_args(int rank, int nranks,
shmemx_uniqueid_t *uid, shmemx_init_attr_t *attr) {
return nvshmemx_set_attr_uniqueid_args(rank, nranks, uid, attr);
}
__host__ inline int shmemx_init_attr(unsigned int flags, shmemx_init_attr_t *attr) {
return nvshmemx_init_attr(flags, attr);
}
__host__ inline int shmem_team_split_strided(shmem_team_t parent_team,
int start, int stride, int size,
const shmem_team_config_t *config,
long config_mask, shmem_team_t *new_team) {
return nvshmem_team_split_strided(parent_team, start, stride, size, config, config_mask, new_team);
}
__host__ inline void shmem_barrier_all() {
nvshmem_barrier_all();
}
__device__ inline void shmem_device_barrier_all() {
nvshmem_barrier_all();
}
__device__ inline void shmem_barrier(shmem_team_t team) {
void(nvshmem_barrier(team));
}
__host__ inline int shmem_my_pe(){
return nvshmem_my_pe();
}
__host__ inline void shmem_free(void *ptr){
nvshmem_free(ptr);
}
__host__ inline void* shmem_align(const size_t alignment, const size_t size) {
return nvshmem_align(size, alignment);
}
__host__ inline void shmem_finalize() {
nvshmem_finalize();
}
__host__ inline void shmem_team_destroy(shmem_team_t team) {
nvshmem_team_destroy(team);
}
__device__ inline void shmem_fence() {
nvshmem_fence();
}
__device__ inline void shmem_int_put_nbi(
int *dest, const int *source, size_t nelems, int pe) {
nvshmem_int_put_nbi(dest, source, nelems, pe);
}
__device__ inline void shmemx_int_put_nbi_warp(
int *dest, const int *source, size_t nelems, int pe) {
nvshmemx_int_put_nbi_warp(dest, source, nelems, pe);
}
__device__ inline void shmemx_int8_put_nbi_warp(
signed char *dest, const signed char *source, size_t nelems, int pe) {
nvshmemx_int8_put_nbi_warp(dest, source, nelems, pe);
}
__device__ inline void shmem_signal_op_add(
uint64_t *dest, uint64_t value, int pe) {
nvshmemx_signal_op(dest, value, NVSHMEM_SIGNAL_ADD, pe);
}
__device__ inline void shmem_ulong_atomic_add(
uint64_t *dest, uint64_t value, int pe) {
nvshmem_ulong_atomic_add(dest, value, pe);
}
__device__ inline void shmem_long_atomic_add(
long *dest, long value, int pe) {
// nvshmem_##Name##_atomic_add(dest, value, pe);
nvshmem_long_atomic_add(dest, value, pe);
}
#endif
} // namespace deep_ep::internode
#endif
......@@ -342,7 +342,7 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return *reinterpret_cast<dtype_t *>(recv_int_values);
}
#ifdef USE_ROCM
#ifndef FORCE_NVSHMEM_API
constexpr float kFP8Margin = 1e-4;
constexpr float kFinfoAmaxE4M3 = 240.0f;
constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment