Commit a1382ed7 authored by lishen's avatar lishen
Browse files

接入ROCSHMEM的multiqp优化

parent 314d9021
......@@ -31,7 +31,7 @@ PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()[
USE_NVSHMEM=OFF
USE_ROCSHMEM=OFF
ROCM_DISABLE_CTX=OFF
ROCM_USE_MULTIQP=OFF
ROCM_DISABLE_MULTIQP=OFF
# 解析命令行参数
for arg in "$@"; do
case $arg in
......@@ -44,20 +44,33 @@ for arg in "$@"; do
ROCM_DISABLE_CTX=ON)
ROCM_DISABLE_CTX=ON
;;
ROCM_USE_MULTIQP=ON)
ROCM_USE_MULTIQP=ON
ROCM_DISABLE_MULTIQP=ON)
ROCM_DISABLE_MULTIQP=ON
;;
*)
echo "Usage: ./build.sh rocshmem [ROCM_DISABLE_CTX=ON] [ROCM_USE_MULTIQP=ON] / ./build.sh dushmem"
echo "Usage: ./build.sh rocshmem [ROCM_DISABLE_CTX=ON] [ROCM_DISABLE_MULTIQP=ON] / ./build.sh dushmem"
exit 1
;;
esac
done
#######################################
# Detect the GPU architecture name to hand to hipcc via --offload-arch.
# Runs rocm_agent_enumerator, keeps the lines that start with gfx<digits>,
# sorts them in descending lexicographic order and takes the first entry
# (treated as the "latest" architecture).
# Arguments: none
# Outputs:   the selected gfx name on stdout; a diagnostic on stderr when
#            no architecture could be detected
# Returns:   0 when an architecture was found, 1 otherwise
#######################################
detect_offload_arch() {
  # Keep the scratch variable out of the caller's global scope.
  local arch=""
  if command -v rocm_agent_enumerator >/dev/null 2>&1; then
    arch=$(rocm_agent_enumerator 2>/dev/null | grep -E '^gfx[0-9]+' | sort -r | head -n1)
  fi
  if [ -z "$arch" ]; then
    # Fail loudly: an empty result would otherwise be expanded into a bare
    # "--offload-arch=" flag by callers that interpolate the output.
    echo "detect_offload_arch: rocm_agent_enumerator is missing or reported no gfx architecture" >&2
    return 1
  fi
  echo "$arch"
}
DETECTED_ARCH=$(detect_offload_arch)
echo "Using --offload-arch=$DETECTED_ARCH"
echo "USE_NVSHMEM=$USE_NVSHMEM"
echo "USE_ROCSHMEM=$USE_ROCSHMEM"
echo "ROCM_DISABLE_CTX=$ROCM_DISABLE_CTX"
echo "ROCM_USE_MULTIQP=$ROCM_USE_MULTIQP"
echo "ROCM_DISABLE_MULTIQP=$ROCM_DISABLE_MULTIQP"
# -------------------------- With rocSHMEM -------------------------- #
build_rocshmem()
......@@ -72,7 +85,7 @@ build_rocshmem()
return 1
}
echo "cd third-party/rocshmem/build"
../scripts/build_configs/gda_mlx5
bash ../scripts/build_configs/gda_mlx5
echo "编译rocshmem成功"
cd "$src_path"
}
......@@ -89,12 +102,12 @@ if [ "$USE_ROCSHMEM" == "ON" ]; then
build_rocshmem
SHMEM_INSTALL_PREFIX=$(pwd)/third-party/rocshmem_install
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 --offload-arch=gfx938 -std=c++17 -Wno-return-type}
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=$DETECTED_ARCH -std=c++17 -Wno-return-type}
if [ "$ROCM_DISABLE_CTX" == "ON" ]; then
COMPILE_OPTIONS="-DROCM_DISABLE_CTX $COMPILE_OPTIONS"
fi
if [ "$ROCM_USE_MULTIQP" == "ON" ]; then
COMPILE_OPTIONS="-DROCM_USE_MULTIQP $COMPILE_OPTIONS"
if [ "$ROCM_DISABLE_MULTIQP" == "ON" ]; then
COMPILE_OPTIONS="-DROCM_DISABLE_MULTIQP $COMPILE_OPTIONS"
fi
SHMEM_LINK_OPTIONS=${SHMEM_LINK_OPTIONS:="-Wl,-rpath,${SHMEM_INSTALL_PREFIX}/lib/ -l:librocshmem.a"}
fi
......@@ -133,7 +146,7 @@ if [ "$USE_NVSHMEM" == "ON" ]; then
# build_dushmem
# SHMEM_INSTALL_PREFIX=$(pwd)/third-party/dushmem_install
SHMEM_INSTALL_PREFIX=${ROCM_PATH}/dushmem
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -DFORCE_DUSHMEM_API -DHIP_ENABLE_WARP_SYNC_BUILTINS -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 --offload-arch=gfx938 -std=c++17 -Wno-return-type}
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -DFORCE_DUSHMEM_API -DHIP_ENABLE_WARP_SYNC_BUILTINS -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=$DETECTED_ARCH -std=c++17 -Wno-return-type}
SHMEM_LINK_OPTIONS="-Wl,-rpath,${SHMEM_INSTALL_PREFIX}/lib/ -l:libdushmem_device.a -ldushmem_host"
fi
# -------------------------- duSHMEM END -------------------------- #
......@@ -147,7 +160,7 @@ hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode_ll.cu -o build_/internode_ll.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/deep_ep.cu -o build_/deep_ep.o ${COMPILE_OPTIONS}
hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o -L${SHMEM_INSTALL_PREFIX}/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 --offload-arch=gfx938 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -L"${llvm_path}/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so ${llvm_path}/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 ${SHMEM_LINK_OPTIONS} -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/opt/mpi/lib/ -libverbs -lmlx5
hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o -L${SHMEM_INSTALL_PREFIX}/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=$DETECTED_ARCH -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -L"${llvm_path}/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so ${llvm_path}/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 ${SHMEM_LINK_OPTIONS} -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/opt/mpi/lib/ -libverbs -lmlx5
# build whl
echo "Using Python: $(which python3)"
......
......@@ -80,6 +80,42 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
clean_0, num_clean_int_0, clean_1, num_clean_int_1);
}
// Forward a low-latency put of `msg_bytes` bytes from `src_ptr` to `dst_ptr`
// on rank `dst_rank` to the SHMEM backend selected at compile time.
//
//  * FORCE_NVSHMEM_API, FORCE_DUSHMEM_API or ROCM_DISABLE_MULTIQP defined:
//    calls internode::shmemx_int8_put_nbi_warp(dst, src, msg_bytes, dst_rank).
//  * otherwise: calls internode::shmemx_int8_put_nbi_warp_dp with the extra
//    index (expert_idx + 1) * num_ranks + dst_rank.
//    NOTE(review): that index presumably selects a per-expert queue pair -
//    confirm against the rocSHMEM multi-QP API.
//
// FORCE_DUSHMEM_API is accepted here so that this helper picks the same
// backend as internode_ll_long_atomic_add, which keys on that macro.
// NOTE(review): FORCE_NVSHMEM_API is kept in case it is defined elsewhere in
// the project; it is not set by anything visible in this change.
//
// `num_ranks` and `expert_idx` are only read on the multi-QP path.
__device__ __forceinline__ void
internode_ll_putmem_nbi(void* dst_ptr, void* src_ptr,
                        int num_ranks, int dst_rank, int expert_idx,
                        int msg_bytes) {
    auto* dst_bytes = reinterpret_cast<signed char*>(dst_ptr);
    auto* src_bytes = reinterpret_cast<signed char*>(src_ptr);
#if defined(FORCE_NVSHMEM_API) || defined(FORCE_DUSHMEM_API) || defined(ROCM_DISABLE_MULTIQP)
    internode::shmemx_int8_put_nbi_warp(dst_bytes, src_bytes, msg_bytes, dst_rank);
#else
    internode::shmemx_int8_put_nbi_warp_dp(
        dst_bytes, src_bytes,
        msg_bytes, (expert_idx + 1) * num_ranks + dst_rank, dst_rank);
#endif
}
// Forward an atomic add of `value` onto `dest` at rank `dst_rank` to the SHMEM
// backend selected at compile time.
//
//  * FORCE_DUSHMEM_API or ROCM_DISABLE_MULTIQP defined:
//    calls internode::shmem_long_atomic_add(dest, value, dst_rank).
//  * otherwise: calls internode::shmem_long_atomic_add_dp with the extra
//    index (expert_idx + 1) * num_ranks + dst_rank.
//    NOTE(review): that index presumably selects a per-expert queue pair -
//    confirm against the rocSHMEM multi-QP API.
//
// `num_ranks` and `expert_idx` are only read on the multi-QP path.
__device__ __forceinline__ void
internode_ll_long_atomic_add(long* dest, const long &value,
                             int num_ranks, int dst_rank, int expert_idx) {
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_MULTIQP)
    const int dp_idx = (expert_idx + 1) * num_ranks + dst_rank;
    internode::shmem_long_atomic_add_dp(dest, value, dp_idx, dst_rank);
#else
    internode::shmem_long_atomic_add(dest, value, dst_rank);
#endif
}
template <bool kUseFP8, bool kUseUE8M0, bool kUseInt8, int kHidden>
__global__ __launch_bounds__(16 * kWarpSize, 1) void
dispatch(void* packed_recv_x, void* packed_recv_x_scales,
......@@ -118,9 +154,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use
using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + kNumScales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
constexpr size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + kNumScales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
EP_STATIC_ASSERT(num_bytes_per_msg % sizeof(int4) == 0, "Invalid message size");
constexpr size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
// Expert counts
constexpr int kNumMaxWarpGroups = 1024 / kWarpSize;
......@@ -135,7 +171,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// 2. The last warp for reading `topk_idx` and count for per-expert information
if (warp_id < num_warps) {
constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(hip_bfloat16);
EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
// EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization");
const auto num_threads = (num_warps - 1) * kWarpSize;
constexpr int hidden_bf16_int4 = kHidden / kNumElemsPerRead;
......@@ -256,34 +292,17 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
slot_idx * num_bytes_per_msg;
if (dst_rank != rank) {
#if defined(FORCE_DUSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)dushmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) {
char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(dushmemi_device_state_d.heap_base));
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
} else {
internode::shmemx_int8_put_nbi_warp(
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
num_bytes_per_msg, dst_rank);
}
#else
#if !defined(ROCM_USE_MULTIQP)
internode::shmemx_int8_put_nbi_warp(
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
num_bytes_per_msg, dst_rank);
#else
internode::shmemx_int8_put_nbi_warp_dp(
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
num_bytes_per_msg, (dst_expert_local_idx + 1) * num_ranks + dst_rank, dst_rank);
#endif
#endif // defined(FORCE_DUSHMEM_API)
} else {
// 通过 shmem_get_p2p_ptr 获取 当前远程指针能否可达
uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
if (p2p_ptr == 0) { // RDMA
internode_ll_putmem_nbi((void*)dst_ptr, (void*)src_ptr,
num_ranks, dst_rank, dst_expert_local_idx,
num_bytes_per_msg);
} else { // 本地 GPU 和 同一计算节点的 其他 GPU 地址
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(p2p_ptr);
UNROLLED_WARP_COPY_LL(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
}
......@@ -294,7 +313,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
}
if (warp_id == num_warps - 1) {
EP_DEVICE_ASSERT(num_sms > 1);
// EP_DEVICE_ASSERT(num_sms > 1);
if (sm_id == 0) {
// The first SM is also responsible for checking QPs
// The first SM is also responsible for cleaning the next buffer
......@@ -341,29 +360,15 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Wait local sends issued and send expert counts
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
if (dst_rank != rank) {
#if defined(FORCE_DUSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)dushmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) { // P2P enabled
int *rptr_actual = (int *)((char *)(peer_base_addr) +
((char *)(rdma_recv_count + dst_expert_local_idx * num_ranks + rank) - (char *)(dushmemi_device_state_d.heap_base)));
st_na_release(rptr_actual, -num_tokens_sent - 1);
} else {
internode::shmem_long_atomic_add(
rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
}
#else
#if !defined(ROCM_USE_MULTIQP)
internode::shmem_long_atomic_add(
rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
#else
internode::shmem_long_atomic_add_dp(
rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1,
(dst_expert_local_idx + 1) * num_ranks + dst_rank, dst_rank);
#endif
#endif // defined(FORCE_DUSHMEM_API)
} else {
st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
auto dst_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
// 通过 shmem_get_p2p_ptr 获取 当前远程指针能否可达
uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
if (p2p_ptr == 0) { // RDMA
internode_ll_long_atomic_add(dst_ptr, -num_tokens_sent - 1,
num_ranks, dst_rank, dst_expert_local_idx);
} else { // 本地 GPU 和 同一计算节点的 其他 GPU 地址
st_na_release(reinterpret_cast<int *>(p2p_ptr), -num_tokens_sent - 1);
}
// Clean workspace for next use
......@@ -419,7 +424,7 @@ LOW_LATENCY_DISPATCH_RECV:
// Wait tokens to arrive
// NOTES: using sub-warp 1 to overlap with sub-warp 0
int num_recv_tokens, recv_token_begin_idx;
EP_DEVICE_ASSERT(num_warps_per_group > 1);
// EP_DEVICE_ASSERT(num_warps_per_group > 1);
if (sub_warp_id == 1 and lane_id == 0) {
while ((num_recv_tokens = ld_acquire_global(reinterpret_cast<int*>(rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == 0);
......@@ -430,12 +435,6 @@ LOW_LATENCY_DISPATCH_RECV:
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
}
#if defined(ROCM_USE_MULTIQP)
if (sub_warp_id == 2 and lane_id == 0) {
internode::shmem_qp_quiet(num_ranks + responsible_expert_idx);
}
#endif
// no needs to reset because there is no iteration
if (lane_id == 0){
volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
......@@ -447,7 +446,7 @@ LOW_LATENCY_DISPATCH_RECV:
recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
// Copy tokens
EP_DEVICE_ASSERT(kNumScales <= 64);
EP_STATIC_ASSERT(kNumScales <= 64, "Invalid hidden size");
for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
// Copy source info
const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
......@@ -632,41 +631,26 @@ combine(void* combined_x,
const auto src_idx = __ldg(local_src_info + token_idx);
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
if (dst_rank == rank) {
const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
} else {
uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
if (p2p_ptr == 0) { // RDMA
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy)
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
#if defined(FORCE_DUSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)dushmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) {
char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(dushmemi_device_state_d.heap_base));
const auto dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
} else {
internode::shmemx_int8_put_nbi_warp(
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr),
hidden * sizeof(hip_bfloat16), dst_rank);
}
#else
#if !defined(ROCM_USE_MULTIQP)
internode::shmemx_int8_put_nbi_warp(
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr),
hidden * sizeof(hip_bfloat16), dst_rank);
#else
internode::shmemx_int8_put_nbi_warp_dp(
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr),
hidden * sizeof(hip_bfloat16), (local_expert_idx + 1) * num_ranks + dst_rank, dst_rank);
#endif
#endif // defined(FORCE_DUSHMEM_API)
internode_ll_putmem_nbi((void*)dst_ptr, (void*)buf_ptr,
num_ranks, dst_rank, local_expert_idx,
hidden * sizeof(hip_bfloat16));
} else { // 本地 GPU 和 同一计算节点的 其他 GPU 地址
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(x_int4);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(p2p_ptr);
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
}
}
// Put finishing flag
EP_DEVICE_ASSERT(num_warps_per_group > 1);
// EP_DEVICE_ASSERT(num_warps_per_group > 1);
if (lane_id == 0){
volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1,__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
}
......@@ -675,30 +659,16 @@ combine(void* combined_x,
if (sub_warp_id == 1 and lane_id == 0) {
while (ld_acquire_global(atomic_clean_flag) == 0);
if (dst_rank != rank) {
#if defined(FORCE_DUSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)dushmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) {
int *req_rptr_actual = (int *)((char *)(peer_base_addr) +
((char *)(rdma_recv_flag + global_expert_idx) - (char *)(dushmemi_device_state_d.heap_base)));
st_na_release(req_rptr_actual, 1);
} else {
internode::shmem_long_atomic_add(
rdma_recv_flag + global_expert_idx, 1, dst_rank);
}
#else
#if !defined(ROCM_USE_MULTIQP)
internode::shmem_long_atomic_add(
rdma_recv_flag + global_expert_idx, 1, dst_rank);
#else
internode::shmem_long_atomic_add_dp(
rdma_recv_flag + global_expert_idx, 1,
(local_expert_idx + 1) * num_ranks + dst_rank, dst_rank);
#endif
#endif // defined(FORCE_DUSHMEM_API)
} else {
st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
auto dst_ptr = rdma_recv_flag + global_expert_idx;
// 通过 shmem_get_p2p_ptr 获取 当前远程指针能否可达
uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
if (p2p_ptr == 0) { // RDMA
internode_ll_long_atomic_add(dst_ptr, 1, num_ranks, dst_rank, local_expert_idx);
} else { // 本地 GPU 和 同一计算节点的 其他 GPU 地址
st_na_release(reinterpret_cast<int *>(p2p_ptr), 1);
}
atomic_add_release_global(atomic_clean_flag, -1);
}
syncwarp();
......@@ -711,7 +681,7 @@ LOW_LATENCY_COMBINE_RECV:
// Wait all ranks to arrive and notify PCIe usage
if (responsible_expert_idx < num_experts) {
EP_DEVICE_ASSERT(num_warps_per_group > 1);
// EP_DEVICE_ASSERT(num_warps_per_group > 1);
if (sub_warp_id == 0 and lane_id == 0) {
const auto src_rank = responsible_expert_idx / num_local_experts;
auto start_time = wall_clock64();
......@@ -730,16 +700,11 @@ LOW_LATENCY_COMBINE_RECV:
atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats + src_rank), wait_recv_cost);
}
}
#if defined(ROCM_USE_MULTIQP)
if (sub_warp_id == 2 and lane_id == 0) {
internode::shmem_qp_quiet(num_ranks + responsible_expert_idx);
}
#endif
}
grid_barrier(global_atomic_counter, num_sms);
// Reduce tokens with FP8 cast
EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
// EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization");
if (thread_id < hidden_bf16_int4) {
for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
......
......@@ -116,7 +116,11 @@ __device__ inline void shmem_long_atomic_add(
rocshmem::rocshmem_long_atomic_add(dest, value, pe);
}
#if defined(ROCM_USE_MULTIQP)
// Thin device-side wrapper: forwards `dest`, `rank` and `dst_rank` unchanged to
// rocshmem::rocshmem_get_p2p_ptr and returns its result as a uint64_t.
// NOTE(review): the low-latency kernels treat a return value of 0 as "not
// directly addressable, fall back to RDMA" and any other value as a pointer
// they can store through - confirm against the rocSHMEM implementation.
__device__ inline uint64_t shmem_get_p2p_ptr(void *dest, int rank, int dst_rank) {
return rocshmem::rocshmem_get_p2p_ptr(dest, rank, dst_rank);
}
#if !defined(ROCM_DISABLE_MULTIQP)
// Thin device-side wrapper: forwards `idx_qp` to rocshmem::rocshmem_quiet_dp.
// NOTE(review): presumably waits for outstanding operations on the queue pair
// identified by `idx_qp` - confirm against the rocSHMEM documentation.
__device__ inline void shmem_qp_quiet(int idx_qp) {
rocshmem::rocshmem_quiet_dp(idx_qp);
}
......@@ -273,6 +277,20 @@ __device__ inline void shmem_long_atomic_add(
dushmem_long_atomic_add(dest, value, pe);
}
// Translate `dest` into an address that `rank` can use to reach the same heap
// offset on `dst_rank`.
//
// Returns:
//  * `dest` itself (as an integer) when `dst_rank` is the calling rank;
//  * 0 when the peer-heap-base table entry for `dst_rank` is 0 (the original
//    comment labels this case "RDMA connected");
//  * otherwise the peer base plus the offset of `dest` from
//    dushmemi_device_state_d.heap_base (labelled "NVLink P2P is enabled").
__device__ __forceinline__ uint64_t shmem_get_p2p_ptr(void *dest, int rank, int dst_rank) {
    const uint64_t local_addr = reinterpret_cast<uint64_t>(dest);

    // Same rank: the address is already usable as-is.
    if (rank == dst_rank) {
        return local_addr;
    }

    // Look up the base address recorded for the destination rank.
    uint64_t* peer_base_table = reinterpret_cast<uint64_t*>(dushmemi_device_state_d.peer_heap_base_p2p);
    const uint64_t remote_base = __ldg(peer_base_table + dst_rank);

    // A zero entry means there is no direct mapping for this peer.
    if (remote_base == 0) {
        return 0;
    }

    // Re-base the address: keep the heap offset, swap in the peer's base.
    const uint64_t heap_offset = local_addr - reinterpret_cast<uint64_t>(dushmemi_device_state_d.heap_base);
    return remote_base + heap_offset;
}
#endif
} // namespace deep_ep::internode
......
# Resolve the package version. Built-in defaults are assigned first and then
# overridden by the values exported from the sibling `version` module. If
# anything in the block raises, a warning is emitted and the package falls
# back to a "dev" version whose tuple starts with (0, 0).
try:
    __version__ = "1.0.0"
    __version_tuple__ = (1, 0, 0)
    __hcu_version__ = f'1.0.0+das.opt1.dtk2504'
    from .version import __version__, __version_tuple__, __hcu_version__
except Exception as e:
    import warnings
    # Interpolate the exception itself; the message previously contained the
    # literal text " + str(e)" because the concatenation sat inside the string.
    warnings.warn(f"Failed to read commit hash:\n{e}",
                  RuntimeWarning, stacklevel=2)
    __version__ = "dev"
    __version_tuple__ = (0, 0, __version__)
def _prev_minor_version_was(version_str):
    """Tell whether ``version_str`` names the previous minor release.

    With a current version of 0.7.4, the string '0.6' yields True. In a dev
    tree (version tuple starting with ``(0, 0)``) every input yields True.

    Used for --show-hidden-metrics-for-version.

    Known limitation: the result is not meaningful across a major-version
    bump, because the minor number is simply decremented.
    """
    # A dev tree matches anything.
    if __version_tuple__[0:2] == (0, 0):
        return True

    major, minor = __version_tuple__[0], __version_tuple__[1]
    assert isinstance(minor, int)
    previous = f"{major}.{minor - 1}"
    return version_str == previous
def _prev_minor_version():
    """Return the previous minor version as a 'major.minor' string (test helper).

    In a dev tree this produces "0.-1", which is acceptable for its callers.
    """
    minor = __version_tuple__[1]
    assert isinstance(minor, int)
    major = __version_tuple__[0]
    return f"{major}.{minor - 1}"
......@@ -35,10 +35,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha != 'Unknown':
if sha is None:
sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=deepep_root).decode('ascii').strip()
if (major, minor) >= ('2', '5'):
if (major, minor) >= ('2', '4'):
version = 'das.opt1.' + sha[:7] + shmem
else:
if (major, minor) >= ('2', '5'):
if (major, minor) >= ('2', '4'):
version = 'das.opt1'
if os.getenv("ROCM_PATH"):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment