Commit d7f41337 authored by lijian6's avatar lijian6
Browse files

Modify nvshmem to dushmem.


Signed-off-by: lijian6's avatarlijian <lijian6@sugon.com>
parent 1a2f45fc
......@@ -48,7 +48,7 @@ for arg in "$@"; do
ROCM_USE_MULTIQP=ON
;;
*)
echo "Usage: ./build.sh rocshmem [ROCM_DISABLE_CTX=ON] [ROCM_USE_MULTIQP=ON] / ./build.sh nvshmem"
echo "Usage: ./build.sh rocshmem [ROCM_DISABLE_CTX=ON] [ROCM_USE_MULTIQP=ON] / ./build.sh dushmem"
exit 1
;;
esac
......@@ -133,8 +133,8 @@ if [ "$USE_NVSHMEM" == "ON" ]; then
# build_dushmem
# SHMEM_INSTALL_PREFIX=$(pwd)/third-party/dushmem_install
SHMEM_INSTALL_PREFIX=${ROCM_PATH}/dushmem
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -DFORCE_NVSHMEM_API -DHIP_ENABLE_WARP_SYNC_BUILTINS -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 --offload-arch=gfx938 -std=c++17 -Wno-return-type}
SHMEM_LINK_OPTIONS="-Wl,-rpath,${SHMEM_INSTALL_PREFIX}/lib/ -l:libnvshmem_device.a -lnvshmem_host"
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -DFORCE_DUSHMEM_API -DHIP_ENABLE_WARP_SYNC_BUILTINS -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 --offload-arch=gfx938 -std=c++17 -Wno-return-type}
SHMEM_LINK_OPTIONS="-Wl,-rpath,${SHMEM_INSTALL_PREFIX}/lib/ -l:libdushmem_device.a -ldushmem_host"
fi
# -------------------------- duSHMEM END -------------------------- #
......
......@@ -143,7 +143,7 @@ pybind11::bytearray Buffer::get_local_ipc_handle() const {
return {ipc_handles[nvl_rank].reserved, HIP_IPC_HANDLE_SIZE};
}
pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
pybind11::bytearray Buffer::get_local_dushmem_unique_id() const {
#ifndef DISABLE_ROCSHMEM
EP_HOST_ASSERT(rdma_rank == 0 and "Only RDMA rank 0 can get ROCSHMEM unique ID");
auto unique_id = internode::get_unique_id();
......@@ -260,9 +260,9 @@ void Buffer::sync(const std::vector<int> &device_
std::vector<uint8_t> root_unique_id(root_unique_id_opt->size());
auto root_unique_id_str = root_unique_id_opt->cast<std::string>();
std::memcpy(root_unique_id.data(), root_unique_id_str.c_str(), root_unique_id_opt->size());
auto nvshmem_rank = low_latency_mode ? rank : rdma_rank;
auto num_nvshmem_ranks = low_latency_mode ? num_ranks : num_rdma_ranks;
EP_HOST_ASSERT(nvshmem_rank == internode::init(root_unique_id, nvshmem_rank, num_nvshmem_ranks, low_latency_mode));
auto dushmem_rank = low_latency_mode ? rank : rdma_rank;
auto num_dushmem_ranks = low_latency_mode ? num_ranks : num_rdma_ranks;
EP_HOST_ASSERT(dushmem_rank == internode::init(root_unique_id, dushmem_rank, num_dushmem_ranks, low_latency_mode));
internode::barrier();
// Allocate
......@@ -1531,7 +1531,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("get_root_rdma_rank", &deep_ep::Buffer::get_root_rdma_rank)
.def("get_local_device_id", &deep_ep::Buffer::get_local_device_id)
.def("get_local_ipc_handle", &deep_ep::Buffer::get_local_ipc_handle)
.def("get_local_nvshmem_unique_id", &deep_ep::Buffer::get_local_nvshmem_unique_id)
.def("get_local_dushmem_unique_id", &deep_ep::Buffer::get_local_dushmem_unique_id)
.def("get_local_buffer_tensor", &deep_ep::Buffer::get_local_buffer_tensor)
.def("get_comm_stream", &deep_ep::Buffer::get_comm_stream)
.def("sync", &deep_ep::Buffer::sync)
......
......@@ -29,7 +29,7 @@ private:
void* nvl_buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
void** nvl_buffer_ptrs_gpu = nullptr;
// NVSHMEM Buffer
// DUSHMEM Buffer
int64_t num_rdma_bytes;
void *rdma_buffer_ptr = nullptr;
......@@ -48,7 +48,7 @@ private:
// Stream for communication
at::hip::HIPStreamMasqueradingAsCUDA comm_stream;
// After IPC/NVSHMEM synchronization, this flag will be true
// After IPC/DUSHMEM synchronization, this flag will be true
bool available = false;
// Whether explicit `destroy()` is required.
......@@ -95,7 +95,7 @@ public:
pybind11::bytearray get_local_ipc_handle() const;
pybind11::bytearray get_local_nvshmem_unique_id() const;
pybind11::bytearray get_local_dushmem_unique_id() const;
torch::Tensor get_local_buffer_tensor(const pybind11::object &dtype, int64_t offset,
bool use_rdma_buffer) const;
......
......@@ -86,7 +86,7 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
template <bool kLowLatencyMode>
__forceinline__ __device__ void
nvshmem_barrier_with_same_gpu_idx(const shmem_team_t &rdma_team) {
dushmem_barrier_with_same_gpu_idx(const shmem_team_t &rdma_team) {
// NOTE: shmem_device_barrier_all() might be an issue as
// it doesn't follow OpenSHMEM specification on ROCm
kLowLatencyMode ? shmem_barrier(rdma_team) : shmem_device_barrier_all();
......@@ -119,7 +119,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
EP_DEVICE_ASSERT(num_warps > 1);
EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
if (thread_id == kWarpSize)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
__syncthreads();
......@@ -161,7 +161,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
}
__syncthreads();
if (thread_id == 0)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
__syncthreads();
......@@ -189,7 +189,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
// Reduce number of tokens per expert into the NVL send buffer
// TODO: may use NVSHMEM reduction
// TODO: may use DUSHMEM reduction
EP_DEVICE_ASSERT(num_rdma_experts <= num_threads);
if (thread_id < num_rdma_experts) {
int sum = 0;
......@@ -257,7 +257,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
// Finally barrier
__syncthreads();
if (thread_id == kWarpSize)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
} else {
......@@ -399,7 +399,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
kForwarderCoordinator, // 向远端RDMA确认接收
kNVLReceivers // 从nvl缓存写入到recv_x
};
#if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
__shared__ shmem_ctx_t ctx;
shmem_wg_ctx_create(&ctx);
#endif
......@@ -516,7 +516,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
syncwarp();
if (dst_rdma_rank != rdma_rank) {
#if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_int_put_nbi_warp(ctx,
#else
shmemx_int_put_nbi_warp(
......@@ -527,7 +527,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
}
}
#if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet(ctx);
#else
shmem_fence();
......@@ -690,7 +690,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
最后,它更新相关的尾部位置,以便下次循环时可以正确地计算需要发送的令牌数。
kRDMASenderCoordinator使用了同sm内存一致性(ld.acquire.cta.s32),
nvshmem内存一致性(nvshmem_fence)和原子操作(nvshmemx_signal_op),减少硬同步,提升整体效率。
dushmem内存一致性(dushmem_fence)和原子操作(dushmemx_signal_op),减少硬同步,提升整体效率。
*/
if(warp_id > kNumDispatchRDMASenderWarps) {
return;
......@@ -741,7 +741,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
if(dst_rdma_rank != rdma_rank) {
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
#if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_schar_put_nbi_warp(ctx,
#else
shmemx_int8_put_nbi_warp(
......@@ -752,7 +752,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
dst_slot_idx * num_bytes_per_rdma_token,
num_bytes_per_rdma_token * num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet(ctx);
#else
shmem_fence();
......@@ -768,7 +768,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
last_issued_tail += num_tokens_to_issue;
num_tokens_to_send -= num_tokens_to_issue;
// 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
#if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
......@@ -1008,7 +1008,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
// 更新远程头部
if(min_head != std::numeric_limits<int>::max() && min_head >= last_head + num_max_rdma_chunked_send_tokens && lane_id < kNumRDMARanks){
#if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
......@@ -1127,7 +1127,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
}
} // while(num_tokens_to_recv > 0)
}
#if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_wg_ctx_destroy(&ctx);
#endif
}
......@@ -1203,7 +1203,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
if (sm_id == 0) {
// Barrier for RDMA
if (thread_id == 0)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
__syncthreads();
......@@ -1216,7 +1216,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
// Barrier again
if (thread_id == 0)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
} else if (sm_id == 1) {
// Barrier for NVL
......@@ -1417,7 +1417,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
kRDMACoordinator,
kNVLCoordinator
};
#if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
__shared__ shmem_ctx_t ctx;
shmem_wg_ctx_create(&ctx);
#endif
......@@ -1744,7 +1744,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
if(sub_warp_id == kNumWarpsPerForwarder - 1) {
if(dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
#if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_schar_put_nbi_warp(ctx,
#else
shmemx_int8_put_nbi_warp(
......@@ -1755,7 +1755,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
rdma_slot_idx * num_bytes_per_rdma_token,
num_chunked_tokens * num_bytes_per_rdma_token,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet(ctx);
#else
shmem_fence();
......@@ -1767,7 +1767,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Write new RDMA tail
syncwarp();
if(lane_id == 0) {
#if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
......@@ -1900,7 +1900,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
#if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx,
#else
shmem_signal_op_add(
......@@ -1917,7 +1917,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
}
}
}
#if !defined(FORCE_NVSHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_wg_ctx_destroy(&ctx);
#endif
}
......
......@@ -257,10 +257,10 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
slot_idx * num_bytes_per_msg;
if (dst_rank != rank) {
#if defined(FORCE_NVSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
#if defined(FORCE_DUSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)dushmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) {
char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(dushmemi_device_state_d.heap_base));
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
......@@ -279,7 +279,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
num_bytes_per_msg, (dst_expert_local_idx + 1) * num_ranks + dst_rank, dst_rank);
#endif
#endif // defined(FORCE_NVSHMEM_API)
#endif // defined(FORCE_DUSHMEM_API)
} else {
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
......@@ -342,11 +342,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Wait local sends issued and send expert counts
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
if (dst_rank != rank) {
#if defined(FORCE_NVSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
#if defined(FORCE_DUSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)dushmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) { // P2P enabled
int *rptr_actual = (int *)((char *)(peer_base_addr) +
((char *)(rdma_recv_count + dst_expert_local_idx * num_ranks + rank) - (char *)(nvshmemi_device_state_d.heap_base)));
((char *)(rdma_recv_count + dst_expert_local_idx * num_ranks + rank) - (char *)(dushmemi_device_state_d.heap_base)));
st_na_release(rptr_actual, -num_tokens_sent - 1);
} else {
internode::shmem_long_atomic_add(
......@@ -361,7 +361,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1,
(dst_expert_local_idx + 1) * num_ranks + dst_rank, dst_rank);
#endif
#endif // defined(FORCE_NVSHMEM_API)
#endif // defined(FORCE_DUSHMEM_API)
} else {
st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
}
......@@ -640,10 +640,10 @@ combine(void* combined_x,
if (not zero_copy)
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
#if defined(FORCE_NVSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
#if defined(FORCE_DUSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)dushmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) {
char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(dushmemi_device_state_d.heap_base));
const auto dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
} else {
......@@ -661,7 +661,7 @@ combine(void* combined_x,
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr),
hidden * sizeof(hip_bfloat16), (local_expert_idx + 1) * num_ranks + dst_rank, dst_rank);
#endif
#endif // defined(FORCE_NVSHMEM_API)
#endif // defined(FORCE_DUSHMEM_API)
}
}
......@@ -676,11 +676,11 @@ combine(void* combined_x,
if (sub_warp_id == 1 and lane_id == 0) {
while (ld_acquire_global(atomic_clean_flag) == 0);
if (dst_rank != rank) {
#if defined(FORCE_NVSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
#if defined(FORCE_DUSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)dushmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) {
int *req_rptr_actual = (int *)((char *)(peer_base_addr) +
((char *)(rdma_recv_flag + global_expert_idx) - (char *)(nvshmemi_device_state_d.heap_base)));
((char *)(rdma_recv_flag + global_expert_idx) - (char *)(dushmemi_device_state_d.heap_base)));
st_na_release(req_rptr_actual, 1);
} else {
internode::shmem_long_atomic_add(
......@@ -695,7 +695,7 @@ combine(void* combined_x,
rdma_recv_flag + global_expert_idx, 1,
(local_expert_idx + 1) * num_ranks + dst_rank, dst_rank);
#endif
#endif // defined(FORCE_NVSHMEM_API)
#endif // defined(FORCE_DUSHMEM_API)
} else {
st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
}
......
......@@ -61,9 +61,9 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
&cpu_rdma_team_config, 0, &cpu_rdma_team) == 0);
EP_HOST_ASSERT(cpu_rdma_team != EP_SHMEM_TEAM_INVALID);
#ifdef FORCE_NVSHMEM_API
nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
CUDA_CHECK(hipGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
#ifdef FORCE_DUSHMEM_API
dushmemi_device_host_state_t* dev_state_ptr = nullptr;
CUDA_CHECK(hipGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), dushmemi_device_state_d));
bool ibgda_is_initialized = false;
CUDA_CHECK(hipMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), hipMemcpyHostToDevice));
#endif
......
#pragma once
/*
* Temporary wrapper for for platform specific NVSHMEM and rocSHMEM functions.
* Temporary wrapper for for platform specific DUSHMEM and rocSHMEM functions.
* Once hipify or hipify-torch fully supports this mapping, this file has to be
* removed and according nvshmem* functions restored.
* removed and according dushmem* functions restored.
*/
#ifndef DISABLE_ROCSHMEM
#include "configs.cuh"
#ifndef FORCE_NVSHMEM_API
#ifndef FORCE_DUSHMEM_API
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp8.h>
#include <hip/hip_runtime.h>
#include <rocshmem/rocshmem.hpp>
#else
#include <device_host_transport/nvshmem_common_ibgda.h>
#include <device_host_transport/dushmem_common_ibgda.h>
#include <infiniband/mlx5dv.h>
#include <nvshmem.h>
#include <nvshmemx.h>
#include <non_abi/device/threadgroup/nvshmemi_common_device_defines.cuh>
#include <dushmem.h>
#include <dushmemx.h>
#include <non_abi/device/threadgroup/dushmemi_common_device_defines.cuh>
#endif
namespace deep_ep::internode {
// rocSHMEM wrapper
#ifndef FORCE_NVSHMEM_API
#ifndef FORCE_DUSHMEM_API
using shmem_team_t = rocshmem::rocshmem_team_t;
using shmem_team_config_t = rocshmem::rocshmem_team_config_t;
const shmem_team_t EP_SHMEM_TEAM_INVALID = rocshmem::ROCSHMEM_TEAM_INVALID;
......@@ -171,106 +171,106 @@ __device__ inline void shmem_ctx_int_put_nbi_warp(
#else
// NVSHMEM wrapper
// DUSHMEM wrapper
#ifndef ROCM_DISABLE_CTX
#define ROCM_DISABLE_CTX
#endif
using shmem_team_t = nvshmem_team_t;
using shmem_team_config_t = nvshmem_team_config_t;
using shmemx_uniqueid_t = nvshmemx_uniqueid_t;
using shmemx_init_attr_t = nvshmemx_init_attr_t;
const shmem_team_t EP_SHMEM_TEAM_INVALID = NVSHMEM_TEAM_INVALID;
const shmem_team_t EP_SHMEM_TEAM_WORLD = NVSHMEM_TEAM_WORLD;
constexpr auto EP_SHMEMX_INIT_WITH_UNIQUEID = NVSHMEMX_INIT_WITH_UNIQUEID;
using shmem_team_t = dushmem_team_t;
using shmem_team_config_t = dushmem_team_config_t;
using shmemx_uniqueid_t = dushmemx_uniqueid_t;
using shmemx_init_attr_t = dushmemx_init_attr_t;
const shmem_team_t EP_SHMEM_TEAM_INVALID = DUSHMEM_TEAM_INVALID;
const shmem_team_t EP_SHMEM_TEAM_WORLD = DUSHMEM_TEAM_WORLD;
constexpr auto EP_SHMEMX_INIT_WITH_UNIQUEID = DUSHMEMX_INIT_WITH_UNIQUEID;
__host__ inline int shmemx_get_uniqueid(shmemx_uniqueid_t *uid) {
return nvshmemx_get_uniqueid(uid);
return dushmemx_get_uniqueid(uid);
}
__host__ inline int shmemx_set_attr_uniqueid_args(int rank, int nranks,
shmemx_uniqueid_t *uid, shmemx_init_attr_t *attr) {
return nvshmemx_set_attr_uniqueid_args(rank, nranks, uid, attr);
return dushmemx_set_attr_uniqueid_args(rank, nranks, uid, attr);
}
__host__ inline int shmemx_init_attr(unsigned int flags, shmemx_init_attr_t *attr) {
return nvshmemx_init_attr(flags, attr);
return dushmemx_init_attr(flags, attr);
}
__host__ inline int shmem_team_split_strided(shmem_team_t parent_team,
int start, int stride, int size,
const shmem_team_config_t *config,
long config_mask, shmem_team_t *new_team) {
return nvshmem_team_split_strided(parent_team, start, stride, size, config, config_mask, new_team);
return dushmem_team_split_strided(parent_team, start, stride, size, config, config_mask, new_team);
}
__host__ inline void shmem_barrier_all() {
nvshmem_barrier_all();
dushmem_barrier_all();
}
__device__ inline void shmem_device_barrier_all() {
nvshmem_barrier_all();
dushmem_barrier_all();
}
__device__ inline void shmem_barrier(shmem_team_t team) {
void(nvshmem_barrier(team));
void(dushmem_barrier(team));
}
__host__ inline int shmem_my_pe(){
return nvshmem_my_pe();
return dushmem_my_pe();
}
__host__ inline void shmem_free(void *ptr){
nvshmem_free(ptr);
dushmem_free(ptr);
}
__host__ inline void* shmem_align(const size_t alignment, const size_t size) {
return nvshmem_align(size, alignment);
return dushmem_align(size, alignment);
}
__host__ inline void shmem_finalize() {
nvshmem_finalize();
dushmem_finalize();
}
__host__ inline void shmem_team_destroy(shmem_team_t team) {
nvshmem_team_destroy(team);
dushmem_team_destroy(team);
}
__device__ inline void shmem_fence() {
nvshmem_fence();
dushmem_fence();
}
__device__ inline void shmem_int_put_nbi(
int *dest, const int *source, size_t nelems, int pe) {
nvshmem_int_put_nbi(dest, source, nelems, pe);
dushmem_int_put_nbi(dest, source, nelems, pe);
}
__device__ inline void shmemx_int_put_nbi_warp(
int *dest, const int *source, size_t nelems, int pe) {
nvshmemx_int_put_nbi_warp(dest, source, nelems, pe);
dushmemx_int_put_nbi_warp(dest, source, nelems, pe);
}
__device__ inline void shmemx_int8_put_nbi_warp(
signed char *dest, const signed char *source, size_t nelems, int pe) {
nvshmemx_int8_put_nbi_warp(dest, source, nelems, pe);
dushmemx_int8_put_nbi_warp(dest, source, nelems, pe);
}
__device__ inline void shmem_signal_op_add(
uint64_t *dest, uint64_t value, int pe) {
nvshmemx_signal_op(dest, value, NVSHMEM_SIGNAL_ADD, pe);
dushmemx_signal_op(dest, value, DUSHMEM_SIGNAL_ADD, pe);
}
__device__ inline void shmem_ulong_atomic_add(
uint64_t *dest, uint64_t value, int pe) {
nvshmem_ulong_atomic_add(dest, value, pe);
dushmem_ulong_atomic_add(dest, value, pe);
}
__device__ inline void shmem_long_atomic_add(
long *dest, long value, int pe) {
// nvshmem_##Name##_atomic_add(dest, value, pe);
nvshmem_long_atomic_add(dest, value, pe);
// dushmem_##Name##_atomic_add(dest, value, pe);
dushmem_long_atomic_add(dest, value, pe);
}
#endif
......
......@@ -96,46 +96,46 @@ class Buffer:
local_ipc_handle = self.runtime.get_local_ipc_handle()
dist.all_gather_object(ipc_handles, local_ipc_handle, group)
# Synchronize NVSHMEM unique IDs
# Synchronize DUSHMEM unique IDs
root_unique_id = None
if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
# Enable IBGDA
self._setup_device_hca_mapping()
assert num_qps_per_rank > 0
os.environ["NVSHMEM_DISABLE_P2P"] = "0" if allow_nvlink_for_low_latency_mode else "1"
# os.environ["NVSHMEM_IB_ENABLE_IBGDA"] = "1"
os.environ["NVSHMEM_IB_ENABLE_IBGDA"] = "0" # force_use_ibrc
os.environ["DUSHMEM_DISABLE_P2P"] = "0" if allow_nvlink_for_low_latency_mode else "1"
# os.environ["DUSHMEM_IB_ENABLE_IBGDA"] = "1"
os.environ["DUSHMEM_IB_ENABLE_IBGDA"] = "0" # force_use_ibrc
os.environ["NVSHMEM_IBGDA_NIC_HANDLER"] = "gpu"
os.environ["NVSHMEM_IB_DISABLE_DMABUF"] = "1"
os.environ["NVSHMEM_ENABLE_NIC_PE_MAPPING"] = "1"
os.environ["DUSHMEM_IBGDA_NIC_HANDLER"] = "gpu"
os.environ["DUSHMEM_IB_DISABLE_DMABUF"] = "1"
os.environ["DUSHMEM_ENABLE_NIC_PE_MAPPING"] = "1"
os.environ["NVSHMEM_IBGDA_NUM_RC_PER_PE"] = f"{num_qps_per_rank}"
os.environ["DUSHMEM_IBGDA_NUM_RC_PER_PE"] = f"{num_qps_per_rank}"
# Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check
os.environ["NVSHMEM_QP_DEPTH"] = os.environ.get("NVSHMEM_QP_DEPTH", "1024")
os.environ["DUSHMEM_QP_DEPTH"] = os.environ.get("DUSHMEM_QP_DEPTH", "1024")
# Reduce gpu memory usage
# 6 default teams + 1 extra team
os.environ["NVSHMEM_MAX_TEAMS"] = "7"
os.environ["DUSHMEM_MAX_TEAMS"] = "7"
# Disable NVLink SHArP
os.environ["NVSHMEM_DISABLE_NVLS"] = "1"
# NOTES: NVSHMEM initialization requires at least 256 MiB
os.environ["NVSHMEM_CUMEM_GRANULARITY"] = f"{2 ** 29}"
os.environ["DUSHMEM_DISABLE_NVLS"] = "1"
# NOTES: DUSHMEM initialization requires at least 256 MiB
os.environ["DUSHMEM_CUMEM_GRANULARITY"] = f"{2 ** 29}"
if not allow_mnnvl:
# Disable multi-node NVLink detection
os.environ["NVSHMEM_DISABLE_MNNVL"] = "1"
os.environ["DUSHMEM_DISABLE_MNNVL"] = "1"
# Synchronize using the root ID
nvshmem_unique_ids = [
dushmem_unique_ids = [
None,
] * self.group_size
if (low_latency_mode and self.rank == 0) or (
not low_latency_mode and self.runtime.get_rdma_rank() == 0
):
root_unique_id = self.runtime.get_local_nvshmem_unique_id()
dist.all_gather_object(nvshmem_unique_ids, root_unique_id, group)
root_unique_id = nvshmem_unique_ids[
root_unique_id = self.runtime.get_local_dushmem_unique_id()
dist.all_gather_object(dushmem_unique_ids, root_unique_id, group)
root_unique_id = dushmem_unique_ids[
0 if low_latency_mode else self.runtime.get_root_rdma_rank(True)
]
......@@ -169,9 +169,9 @@ class Buffer:
# assert visible_devices[current_device].isdigit(), f"DEEP_EP_DEVICE_TO_HCA_MAPPING requires CUDA_VISIBLE_DEVICES to contain integer indices"
# current_device = int(visible_devices[current_device])
assert current_device in device_mapping, f"Current CUDA device {current_device} not found in DEEP_EP_DEVICE_TO_HCA_MAPPING"
os.environ['NVSHMEM_ENABLE_PE_MAPPING'] = '1'
os.environ['NVSHMEM_HCA_LIST'] = device_mapping[current_device]
assert current_device in device_mapping, f"Current HIP device {current_device} not found in DEEP_EP_DEVICE_TO_HCA_MAPPING"
os.environ['DUSHMEM_ENABLE_PE_MAPPING'] = '1'
os.environ['DUSHMEM_HCA_LIST'] = device_mapping[current_device]
def destroy(self):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment