Commit 1b497233 authored by lishen's avatar lishen
Browse files

Merge branch 'updates' into 'main'

Updates

See merge request dcutoolkit/deeplearing/DeepEP!12
parents 1b00b9d8 94694314
......@@ -6,7 +6,8 @@ export OMPI_MCA_coll_hcoll_enable=0
export UCX_TLS=rc,rocm
# export ROCSHMEM_UNIQUEID_WITH_MPI=1
export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
export ROCSHMEM_MAX_NUM_CONTEXTS=32
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48
export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384
export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
......@@ -15,5 +16,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export ROCSHMEM_HEAP_SIZE=10737418240
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py --pressure-test
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py --test-ll-compatibility
......@@ -6,7 +6,8 @@ export OMPI_MCA_coll_hcoll_enable=0
export UCX_TLS=rc,rocm
# export ROCSHMEM_UNIQUEID_WITH_MPI=1
export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
export ROCSHMEM_MAX_NUM_CONTEXTS=32
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48
export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384
export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
......@@ -15,5 +16,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export ROCSHMEM_HEAP_SIZE=10737418240
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py --pressure-test
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py --test-ll-compatibility
......@@ -31,24 +31,33 @@ PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()[
USE_NVSHMEM=OFF
USE_ROCSHMEM=OFF
ROCM_DISABLE_CTX=OFF
case "$1" in
ROCM_USE_MULTIQP=OFF
# 解析命令行参数
for arg in "$@"; do
case $arg in
rocshmem)
USE_ROCSHMEM=ON
;;
nvshmem|dushmem)
USE_NVSHMEM=ON
;;
ROCM_DISABLE_CTX=ON)
ROCM_DISABLE_CTX=ON
;;
ROCM_USE_MULTIQP=ON)
ROCM_USE_MULTIQP=ON
;;
*)
echo "Usage: ./build.sh rocshmem [ROCM_DISABLE_CTX] / ./build.sh nvshmem"
echo "Usage: ./build.sh rocshmem [ROCM_DISABLE_CTX=ON] [ROCM_USE_MULTIQP=ON] / ./build.sh nvshmem"
exit 1
;;
esac
if [ "${2:-}" = "ROCM_DISABLE_CTX" ]; then
ROCM_DISABLE_CTX=ON
fi
esac
done
echo "USE_NVSHMEM=$USE_NVSHMEM"
echo "USE_ROCSHMEM=$USE_ROCSHMEM"
echo "ROCM_DISABLE_CTX=$ROCM_DISABLE_CTX"
echo "ROCM_USE_MULTIQP=$ROCM_USE_MULTIQP"
# -------------------------- With rocSHMEM -------------------------- #
build_rocshmem()
......@@ -84,6 +93,9 @@ if [ "$USE_ROCSHMEM" == "ON" ]; then
if [ "$ROCM_DISABLE_CTX" == "ON" ]; then
COMPILE_OPTIONS="-DROCM_DISABLE_CTX $COMPILE_OPTIONS"
fi
if [ "$ROCM_USE_MULTIQP" == "ON" ]; then
COMPILE_OPTIONS="-DROCM_USE_MULTIQP $COMPILE_OPTIONS"
fi
SHMEM_LINK_OPTIONS=${SHMEM_LINK_OPTIONS:="-Wl,-rpath,${SHMEM_INSTALL_PREFIX}/lib/ -l:librocshmem.a"}
fi
# -------------------------- rocSHMEM END -------------------------- #
......
......@@ -44,10 +44,10 @@ struct Config {
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % (2 * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL) == 0);
const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
const int num_channels = num_sms / 2;
const int num_channels = num_sms;
size_t num_bytes = 0;
num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
......@@ -77,9 +77,9 @@ struct Config {
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_sms % 2 == 0);
EP_HOST_ASSERT(num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0);
const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
const int num_channels = num_sms / 2;
const int num_channels = num_sms;
size_t num_bytes = 0;
num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
......
......@@ -25,8 +25,6 @@
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define FP8_QUANTIZATION_NUM_PER_CHANNEL 128
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define DEFAULT_NUM_CU 20
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_RECV_TOKENS 256
......
......@@ -8,9 +8,6 @@
#include "hip/hip_runtime.h"
// low latency+RocSHMEM has issue with CTX.
#define ROCM_DISABLE_CTX
#include "shmem_wrapper.cuh"
namespace deep_ep {
......@@ -133,11 +130,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_DISPATCH_RECV;
#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
// There are 2 kinds of warps in this part:
// 1. The first-kind warps for FP8 cast and sending top-k tokens
// 2. The last warp for reading `topk_idx` and count for per-expert information
......@@ -265,24 +257,29 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
slot_idx * num_bytes_per_msg;
if (dst_rank != rank) {
#if defined(FORCE_NVSHMEM_API)
#if defined(FORCE_NVSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) {
char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
} else
#endif
{
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_schar_put_nbi_warp(ctx,
#else
} else {
internode::shmemx_int8_put_nbi_warp(
#endif
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
num_bytes_per_msg, dst_rank);
}
#else
#if !defined(ROCM_USE_MULTIQP)
internode::shmemx_int8_put_nbi_warp(
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
num_bytes_per_msg, dst_rank);
#else
internode::shmemx_int8_put_nbi_warp_dp(
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
num_bytes_per_msg, (dst_expert_local_idx + 1) * num_ranks + dst_rank, dst_rank);
#endif
#endif // defined(FORCE_NVSHMEM_API)
} else {
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
......@@ -345,22 +342,26 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Wait local sends issued and send expert counts
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
if (dst_rank != rank) {
#if defined(FORCE_NVSHMEM_API)
#if defined(FORCE_NVSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) { // P2P enabled
int *rptr_actual = (int *)((char *)(peer_base_addr) +
((char *)(rdma_recv_count + dst_expert_local_idx * num_ranks + rank) - (char *)(nvshmemi_device_state_d.heap_base)));
st_na_release(rptr_actual, -num_tokens_sent - 1);
} else
#endif
{
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_long_atomic_add(ctx,
#else
} else {
internode::shmem_long_atomic_add(
#endif
rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
}
#else
#if !defined(ROCM_USE_MULTIQP)
internode::shmem_long_atomic_add(
rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
#else
internode::shmem_long_atomic_add_dp(
rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1,
(dst_expert_local_idx + 1) * num_ranks + dst_rank, dst_rank);
#endif
#endif // defined(FORCE_NVSHMEM_API)
} else {
st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
}
......@@ -375,10 +376,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
syncwarp();
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
// Receiving phase
LOW_LATENCY_DISPATCH_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
......@@ -433,6 +430,12 @@ LOW_LATENCY_DISPATCH_RECV:
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
}
#if defined(ROCM_USE_MULTIQP)
if (sub_warp_id == 2 and lane_id == 0) {
internode::shmem_qp_quiet(num_ranks + responsible_expert_idx);
}
#endif
// no needs to reset because there is no iteration
if (lane_id == 0){
volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
......@@ -591,11 +594,6 @@ combine(void* combined_x,
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_COMBINE_RECV;
#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
// Clean up next buffer
if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
#pragma unroll
......@@ -642,23 +640,28 @@ combine(void* combined_x,
if (not zero_copy)
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
#if defined(FORCE_NVSHMEM_API)
#if defined(FORCE_NVSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) {
char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
const auto dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
} else
#endif
{
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_schar_put_nbi_warp(ctx,
#else
} else {
internode::shmemx_int8_put_nbi_warp(
#endif
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr),
hidden * sizeof(hip_bfloat16), dst_rank);
}
#else
#if !defined(ROCM_USE_MULTIQP)
internode::shmemx_int8_put_nbi_warp(
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr),
hidden * sizeof(hip_bfloat16), dst_rank);
#else
internode::shmemx_int8_put_nbi_warp_dp(
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr),
hidden * sizeof(hip_bfloat16), (local_expert_idx + 1) * num_ranks + dst_rank, dst_rank);
#endif
#endif // defined(FORCE_NVSHMEM_API)
}
}
......@@ -673,55 +676,39 @@ combine(void* combined_x,
if (sub_warp_id == 1 and lane_id == 0) {
while (ld_acquire_global(atomic_clean_flag) == 0);
if (dst_rank != rank) {
#if defined(FORCE_NVSHMEM_API)
#if defined(FORCE_NVSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) {
int *req_rptr_actual = (int *)((char *)(peer_base_addr) +
((char *)(rdma_recv_flag + global_expert_idx) - (char *)(nvshmemi_device_state_d.heap_base)));
st_na_release(req_rptr_actual, 1);
} else
#endif
{
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_long_atomic_add(ctx,
#else
} else {
internode::shmem_long_atomic_add(
#endif
rdma_recv_flag + global_expert_idx, 1, dst_rank);
}
#else
#if !defined(ROCM_USE_MULTIQP)
internode::shmem_long_atomic_add(
rdma_recv_flag + global_expert_idx, 1, dst_rank);
#else
internode::shmem_long_atomic_add_dp(
rdma_recv_flag + global_expert_idx, 1,
(local_expert_idx + 1) * num_ranks + dst_rank, dst_rank);
#endif
#endif // defined(FORCE_NVSHMEM_API)
} else {
st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
}
atomic_add_release_global(atomic_clean_flag, -1);
}
syncwarp();
if (num_ranks > 8){
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_quiet(ctx);
#else
internode::shmem_fence();
#endif
}
}
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
// Receiving phase
LOW_LATENCY_COMBINE_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return;
// if (num_ranks > 8){
// #if !defined(ROCM_DISABLE_CTX)
// internode::shmem_ctx_quiet(ctx);
// #else
// internode::shmem_fence();
// #endif
// }
// Wait all ranks to arrive and notify PCIe usage
if (responsible_expert_idx < num_experts) {
EP_DEVICE_ASSERT(num_warps_per_group > 1);
......@@ -743,6 +730,11 @@ LOW_LATENCY_COMBINE_RECV:
atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats + src_rank), wait_recv_cost);
}
}
#if defined(ROCM_USE_MULTIQP)
if (sub_warp_id == 2 and lane_id == 0) {
internode::shmem_qp_quiet(num_ranks + responsible_expert_idx);
}
#endif
}
grid_barrier(global_atomic_counter, num_sms);
......
......@@ -116,6 +116,22 @@ __device__ inline void shmem_long_atomic_add(
rocshmem::rocshmem_long_atomic_add(dest, value, pe);
}
#if defined(ROCM_USE_MULTIQP)
__device__ inline void shmem_qp_quiet(int idx_qp) {
rocshmem::rocshmem_quiet_dp(idx_qp);
}
__device__ inline void shmemx_int8_put_nbi_warp_dp(
signed char *dest, const signed char *source, size_t nelems, int qp_idx, int pe) {
rocshmem::rocshmem_schar_put_nbi_wave_dp(dest, source, nelems, qp_idx, pe);
}
__device__ inline void shmem_long_atomic_add_dp(
long *dest, long value, int qp_idx, int pe) {
rocshmem::rocshmem_long_atomic_add_dp(dest, value, qp_idx, pe);
}
#endif
#if !defined(ROCM_DISABLE_CTX)
using shmem_ctx_t = rocshmem::rocshmem_ctx_t;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment