Commit 4e06dc48 authored by lijian6's avatar lijian6
Browse files

1. Add env for internode ctx.


2. Fix rocshmem internode hang on 508.
Signed-off-by: lijian6's avatarlijian <lijian6@sugon.com>
parent eeaf98b0
......@@ -30,6 +30,7 @@ PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()[
USE_NVSHMEM=OFF
USE_ROCSHMEM=OFF
ROCM_DISABLE_CTX=OFF
case "$1" in
rocshmem)
USE_ROCSHMEM=ON
......@@ -38,12 +39,16 @@ case "$1" in
USE_NVSHMEM=ON
;;
*)
echo "Usage: $0 [rocshmem|nvshmem]"
echo "Usage: ./build.sh rocshmem [ROCM_DISABLE_CTX] / ./build.sh nvshmem"
exit 1
;;
esac
if [ "${2:-}" = "ROCM_DISABLE_CTX" ]; then
ROCM_DISABLE_CTX=ON
fi
echo "USE_NVSHMEM=$USE_NVSHMEM"
echo "USE_ROCSHMEM=$USE_ROCSHMEM"
echo "ROCM_DISABLE_CTX=$ROCM_DISABLE_CTX"
# -------------------------- With rocSHMEM -------------------------- #
build_rocshmem()
......@@ -75,7 +80,10 @@ if [ "$USE_ROCSHMEM" == "ON" ]; then
build_rocshmem
SHMEM_INSTALL_PREFIX=$(pwd)/third-party/rocshmem_install
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type}
if [ "$ROCM_DISABLE_CTX" == "ON" ]; then
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -DROCM_DISABLE_CTX -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type}
else
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type}
SHMEM_LINK_OPTIONS=${SHMEM_LINK_OPTIONS:="-Wl,-rpath,${SHMEM_INSTALL_PREFIX}/lib/ -l:librocshmem.a"}
fi
# -------------------------- rocSHMEM END -------------------------- #
......@@ -112,7 +120,7 @@ if [ "$USE_NVSHMEM" == "ON" ]; then
fi
build_dushmem
SHMEM_INSTALL_PREFIX=$(pwd)/third-party/dushmem_install
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -DFORCE_NVSHMEM_API -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type}
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -DFORCE_NVSHMEM_API -DHIP_ENABLE_WARP_SYNC_BUILTINS -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type}
SHMEM_LINK_OPTIONS="-Wl,-rpath,${SHMEM_INSTALL_PREFIX}/lib/ -l:libnvshmem_device.a -lnvshmem_host"
fi
# -------------------------- duSHMEM END -------------------------- #
......
......@@ -399,7 +399,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
kForwarderCoordinator, // 向远端RDMA确认接收
kNVLReceivers // 从nvl缓存写入到recv_x
};
#ifndef FORCE_NVSHMEM_API
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX)
__shared__ shmem_ctx_t ctx;
shmem_wg_ctx_create(&ctx);
#endif
......@@ -516,7 +516,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
syncwarp();
if (dst_rdma_rank != rdma_rank) {
#ifndef FORCE_NVSHMEM_API
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX)
shmem_ctx_int_put_nbi_warp(ctx,
#else
shmemx_int_put_nbi_warp(
......@@ -527,7 +527,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
}
}
#ifndef FORCE_NVSHMEM_API
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet(ctx);
#else
shmem_fence();
......@@ -741,7 +741,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
if(dst_rdma_rank != rdma_rank) {
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
#ifndef FORCE_NVSHMEM_API
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX)
shmem_ctx_schar_put_nbi_warp(ctx,
#else
shmemx_int8_put_nbi_warp(
......@@ -752,7 +752,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
dst_slot_idx * num_bytes_per_rdma_token,
num_bytes_per_rdma_token * num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#ifndef FORCE_NVSHMEM_API
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet(ctx);
#else
shmem_fence();
......@@ -768,7 +768,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
last_issued_tail += num_tokens_to_issue;
num_tokens_to_send -= num_tokens_to_issue;
// 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
#ifndef FORCE_NVSHMEM_API
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
......@@ -1008,7 +1008,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
// 更新远程头部
if(min_head != std::numeric_limits<int>::max() && min_head >= last_head + num_max_rdma_chunked_send_tokens && lane_id < kNumRDMARanks){
#ifndef FORCE_NVSHMEM_API
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
......@@ -1127,7 +1127,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
}
} // while(num_tokens_to_recv > 0)
}
#ifndef FORCE_NVSHMEM_API
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX)
shmem_wg_ctx_destroy(&ctx);
#endif
}
......@@ -1417,7 +1417,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
kRDMACoordinator,
kNVLCoordinator
};
#ifndef FORCE_NVSHMEM_API
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX)
__shared__ shmem_ctx_t ctx;
shmem_wg_ctx_create(&ctx);
#endif
......@@ -1744,7 +1744,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
if(sub_warp_id == kNumWarpsPerForwarder - 1) {
if(dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
#ifndef FORCE_NVSHMEM_API
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX)
shmem_ctx_schar_put_nbi_warp(ctx,
#else
shmemx_int8_put_nbi_warp(
......@@ -1755,7 +1755,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
rdma_slot_idx * num_bytes_per_rdma_token,
num_chunked_tokens * num_bytes_per_rdma_token,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#ifndef FORCE_NVSHMEM_API
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet(ctx);
#else
shmem_fence();
......@@ -1767,7 +1767,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Write new RDMA tail
syncwarp();
if(lane_id == 0) {
#ifndef FORCE_NVSHMEM_API
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
......@@ -1900,7 +1900,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
#ifndef FORCE_NVSHMEM_API
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
......@@ -1917,7 +1917,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
}
}
}
#ifndef FORCE_NVSHMEM_API
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX)
shmem_wg_ctx_destroy(&ctx);
#endif
}
......
......@@ -107,6 +107,10 @@ __device__ inline void shmemx_int8_put_nbi_warp(
rocshmem::rocshmem_schar_put_nbi_wave(dest, source, nelems, pe);
}
__device__ inline void shmem_signal_op_add(uint64_t *dest, uint64_t value, int pe) {}
rocshmem::rocshmem_ulong_atomic_add(dest, value, pe);
}
__device__ inline void shmem_long_atomic_add(
long *dest, long value, int pe) {
rocshmem::rocshmem_long_atomic_add(dest, value, pe);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment