Commit 7906c308 authored by lijian6
Browse files

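Fix compile: set the default COMPILE_OPTIONS once and prepend -DROCM_DISABLE_CTX when ROCM_DISABLE_CTX is ON; change the shmem ctx preprocessor guards from `||` to `&&`; remove the stray `}` that closed shmem_signal_op_add before its body; link against libhsa-runtime64.so instead of the versioned libhsa-runtime64.so.1.11.0; and update the rocshmem submodule.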


Signed-off-by: lijian <lijian6@sugon.com>
parent 4e06dc48
...@@ -80,10 +80,10 @@ if [ "$USE_ROCSHMEM" == "ON" ]; then ...@@ -80,10 +80,10 @@ if [ "$USE_ROCSHMEM" == "ON" ]; then
build_rocshmem build_rocshmem
SHMEM_INSTALL_PREFIX=$(pwd)/third-party/rocshmem_install SHMEM_INSTALL_PREFIX=$(pwd)/third-party/rocshmem_install
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type}
if [ "$ROCM_DISABLE_CTX" == "ON" ]; then if [ "$ROCM_DISABLE_CTX" == "ON" ]; then
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -DROCM_DISABLE_CTX -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type} COMPILE_OPTIONS="-DROCM_DISABLE_CTX $COMPILE_OPTIONS"
else fi
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type}
SHMEM_LINK_OPTIONS=${SHMEM_LINK_OPTIONS:="-Wl,-rpath,${SHMEM_INSTALL_PREFIX}/lib/ -l:librocshmem.a"} SHMEM_LINK_OPTIONS=${SHMEM_LINK_OPTIONS:="-Wl,-rpath,${SHMEM_INSTALL_PREFIX}/lib/ -l:librocshmem.a"}
fi fi
# -------------------------- rocSHMEM END -------------------------- # # -------------------------- rocSHMEM END -------------------------- #
...@@ -134,7 +134,7 @@ hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o ...@@ -134,7 +134,7 @@ hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode_ll.cu -o build_/internode_ll.o ${COMPILE_OPTIONS} hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode_ll.cu -o build_/internode_ll.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/deep_ep.cu -o build_/deep_ep.o ${COMPILE_OPTIONS} hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/deep_ep.cu -o build_/deep_ep.o ${COMPILE_OPTIONS}
hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o -L${SHMEM_INSTALL_PREFIX}/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -L"${llvm_path}/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so ${llvm_path}/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 ${SHMEM_LINK_OPTIONS} -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/opt/mpi/lib/ -libverbs -lmlx5 hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o -L${SHMEM_INSTALL_PREFIX}/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -L"${llvm_path}/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so ${llvm_path}/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch 
-ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 ${SHMEM_LINK_OPTIONS} -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/opt/mpi/lib/ -libverbs -lmlx5
# build whl # build whl
echo "Using Python: $(which python3)" echo "Using Python: $(which python3)"
......
...@@ -399,7 +399,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv ...@@ -399,7 +399,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
kForwarderCoordinator, // 向远端RDMA确认接收 kForwarderCoordinator, // 向远端RDMA确认接收
kNVLReceivers // 从nvl缓存写入到recv_x kNVLReceivers // 从nvl缓存写入到recv_x
}; };
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
__shared__ shmem_ctx_t ctx; __shared__ shmem_ctx_t ctx;
shmem_wg_ctx_create(&ctx); shmem_wg_ctx_create(&ctx);
#endif #endif
...@@ -516,7 +516,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv ...@@ -516,7 +516,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
syncwarp(); syncwarp();
if (dst_rdma_rank != rdma_rank) { if (dst_rdma_rank != rdma_rank) {
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_int_put_nbi_warp(ctx, shmem_ctx_int_put_nbi_warp(ctx,
#else #else
shmemx_int_put_nbi_warp( shmemx_int_put_nbi_warp(
...@@ -527,7 +527,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv ...@@ -527,7 +527,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
} }
} }
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet(ctx); shmem_ctx_quiet(ctx);
#else #else
shmem_fence(); shmem_fence();
...@@ -741,7 +741,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv ...@@ -741,7 +741,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
if(dst_rdma_rank != rdma_rank) { if(dst_rdma_rank != rdma_rank) {
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens; auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens); EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_schar_put_nbi_warp(ctx, shmem_ctx_schar_put_nbi_warp(ctx,
#else #else
shmemx_int8_put_nbi_warp( shmemx_int8_put_nbi_warp(
...@@ -752,7 +752,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv ...@@ -752,7 +752,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
dst_slot_idx * num_bytes_per_rdma_token, dst_slot_idx * num_bytes_per_rdma_token,
num_bytes_per_rdma_token * num_tokens_to_issue, num_bytes_per_rdma_token * num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet(ctx); shmem_ctx_quiet(ctx);
#else #else
shmem_fence(); shmem_fence();
...@@ -768,7 +768,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv ...@@ -768,7 +768,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
last_issued_tail += num_tokens_to_issue; last_issued_tail += num_tokens_to_issue;
num_tokens_to_send -= num_tokens_to_issue; num_tokens_to_send -= num_tokens_to_issue;
// 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信 // 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx, shmem_ctx_ulong_atomic_add(ctx,
#else #else
shmem_signal_op_add( shmem_signal_op_add(
...@@ -1008,7 +1008,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv ...@@ -1008,7 +1008,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
// 更新远程头部 // 更新远程头部
if(min_head != std::numeric_limits<int>::max() && min_head >= last_head + num_max_rdma_chunked_send_tokens && lane_id < kNumRDMARanks){ if(min_head != std::numeric_limits<int>::max() && min_head >= last_head + num_max_rdma_chunked_send_tokens && lane_id < kNumRDMARanks){
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx, shmem_ctx_ulong_atomic_add(ctx,
#else #else
shmem_signal_op_add( shmem_signal_op_add(
...@@ -1127,7 +1127,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv ...@@ -1127,7 +1127,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
} }
} // while(num_tokens_to_recv > 0) } // while(num_tokens_to_recv > 0)
} }
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_wg_ctx_destroy(&ctx); shmem_wg_ctx_destroy(&ctx);
#endif #endif
} }
...@@ -1417,7 +1417,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1417,7 +1417,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
kRDMACoordinator, kRDMACoordinator,
kNVLCoordinator kNVLCoordinator
}; };
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
__shared__ shmem_ctx_t ctx; __shared__ shmem_ctx_t ctx;
shmem_wg_ctx_create(&ctx); shmem_wg_ctx_create(&ctx);
#endif #endif
...@@ -1744,7 +1744,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1744,7 +1744,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
if(sub_warp_id == kNumWarpsPerForwarder - 1) { if(sub_warp_id == kNumWarpsPerForwarder - 1) {
if(dst_rdma_rank != rdma_rank) { if(dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_schar_put_nbi_warp(ctx, shmem_ctx_schar_put_nbi_warp(ctx,
#else #else
shmemx_int8_put_nbi_warp( shmemx_int8_put_nbi_warp(
...@@ -1755,7 +1755,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1755,7 +1755,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
rdma_slot_idx * num_bytes_per_rdma_token, rdma_slot_idx * num_bytes_per_rdma_token,
num_chunked_tokens * num_bytes_per_rdma_token, num_chunked_tokens * num_bytes_per_rdma_token,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet(ctx); shmem_ctx_quiet(ctx);
#else #else
shmem_fence(); shmem_fence();
...@@ -1767,7 +1767,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1767,7 +1767,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Write new RDMA tail // Write new RDMA tail
syncwarp(); syncwarp();
if(lane_id == 0) { if(lane_id == 0) {
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx, shmem_ctx_ulong_atomic_add(ctx,
#else #else
shmem_signal_op_add( shmem_signal_op_add(
...@@ -1900,7 +1900,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1900,7 +1900,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]); min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx, shmem_ctx_ulong_atomic_add(ctx,
#else #else
shmem_signal_op_add( shmem_signal_op_add(
...@@ -1917,7 +1917,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1917,7 +1917,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
} }
} }
} }
#if !defined(FORCE_NVSHMEM_API) || !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_wg_ctx_destroy(&ctx); shmem_wg_ctx_destroy(&ctx);
#endif #endif
} }
......
...@@ -107,7 +107,7 @@ __device__ inline void shmemx_int8_put_nbi_warp( ...@@ -107,7 +107,7 @@ __device__ inline void shmemx_int8_put_nbi_warp(
rocshmem::rocshmem_schar_put_nbi_wave(dest, source, nelems, pe); rocshmem::rocshmem_schar_put_nbi_wave(dest, source, nelems, pe);
} }
__device__ inline void shmem_signal_op_add(uint64_t *dest, uint64_t value, int pe) {} __device__ inline void shmem_signal_op_add(uint64_t *dest, uint64_t value, int pe) {
rocshmem::rocshmem_ulong_atomic_add(dest, value, pe); rocshmem::rocshmem_ulong_atomic_add(dest, value, pe);
} }
......
rocshmem @ 98e9363f
Subproject commit 56da12324dbc61989ead10f3063c6f6fdbfacaae Subproject commit 98e9363fa748000cf3010997e50cd21bcdcde342
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment