Commit d1bf10d3 authored by lishen
Browse files

基于rocm的DeepEP,低延迟优化

parent ee3551ab
...@@ -98,11 +98,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -98,11 +98,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group, int num_warp_groups, int num_warps_per_group,
bool round_scale, int phases) { bool round_scale, int phases) {
#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
const auto sm_id = static_cast<int>(blockIdx.x); const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x); const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id(); const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
...@@ -131,17 +126,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -131,17 +126,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4); const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0); EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
// 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize;
constexpr int num_sync_large_iteration = kMaxNumWarps ;
__shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];
#pragma unroll
for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
sync_large_warp_counters[i] = 0;
}
__syncthreads();
// Expert counts // Expert counts
constexpr int kNumMaxWarpGroups = 1024 / kWarpSize; constexpr int kNumMaxWarpGroups = 1024 / kWarpSize;
__shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups]; __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
...@@ -150,6 +134,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -150,6 +134,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
if ((phases & LOW_LATENCY_SEND_PHASE) == 0) if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_DISPATCH_RECV; goto LOW_LATENCY_DISPATCH_RECV;
#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
// There are 2 kinds of warps in this part: // There are 2 kinds of warps in this part:
// 1. The first-kind warps for FP8 cast and sending top-k tokens // 1. The first-kind warps for FP8 cast and sending top-k tokens
// 2. The last warp for reading `topk_idx` and count for per-expert information // 2. The last warp for reading `topk_idx` and count for per-expert information
...@@ -220,9 +209,18 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -220,9 +209,18 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
slot_idx * num_bytes_per_msg; slot_idx * num_bytes_per_msg;
if (dst_rank != rank) { if (dst_rank != rank) {
internode::shmemx_int8_put_nbi_warp(reinterpret_cast<signed char*>(dst_ptr), #if !defined(ROCM_DISABLE_CTX)
reinterpret_cast<signed char*>(src_ptr), num_bytes_per_msg, dst_rank); internode::shmem_ctx_schar_put_nbi_warp(ctx,
internode::shmem_fence(); #else
internode::shmemx_int8_put_nbi_warp(
#endif
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
num_bytes_per_msg, dst_rank);
// #if !defined(ROCM_DISABLE_CTX)
// internode::shmem_ctx_quiet(ctx);
// #else
// internode::shmem_fence();
// #endif
} else { } else {
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr); const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
...@@ -274,8 +272,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -274,8 +272,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
} }
} }
} }
//revert sync_large_warp_counters to 0 for next sync
__syncthreads(); __syncthreads();
// Issue count sends // Issue count sends
...@@ -287,7 +283,12 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -287,7 +283,12 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Wait local sends issued and send expert counts // Wait local sends issued and send expert counts
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2); while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
if (dst_rank != rank) { if (dst_rank != rank) {
internode::shmem_long_atomic_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); #if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_long_atomic_add(ctx,
#else
internode::shmem_long_atomic_add(
#endif
rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
} else { } else {
st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1); st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
} }
...@@ -302,6 +303,10 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -302,6 +303,10 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
} }
syncwarp(); syncwarp();
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
// Receiving phase // Receiving phase
LOW_LATENCY_DISPATCH_RECV: LOW_LATENCY_DISPATCH_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0) if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
...@@ -312,6 +317,17 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -312,6 +317,17 @@ LOW_LATENCY_DISPATCH_RECV:
grid_barrier(global_atomic_counter, num_sms); grid_barrier(global_atomic_counter, num_sms);
} }
// 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize;
constexpr int num_sync_large_iteration = kMaxNumWarps ;
__shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];
#pragma unroll
for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
sync_large_warp_counters[i] = 0;
}
__syncthreads();
// Receiving and packing // Receiving and packing
if (responsible_expert_idx < num_experts) { if (responsible_expert_idx < num_experts) {
const auto src_rank = responsible_expert_idx / num_local_experts; const auto src_rank = responsible_expert_idx / num_local_experts;
...@@ -393,10 +409,6 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -393,10 +409,6 @@ LOW_LATENCY_DISPATCH_RECV:
} }
} }
} }
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
} }
void dispatch(void* packed_recv_x, void* packed_recv_x_scales, void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...@@ -464,11 +476,6 @@ combine(void* combined_x, ...@@ -464,11 +476,6 @@ combine(void* combined_x,
int num_experts, int rank, int num_ranks, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group, int num_warp_groups, int num_warps_per_group,
int phases, bool zero_copy) { int phases, bool zero_copy) {
#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
const auto sm_id = static_cast<int>(blockIdx.x); const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_sms = static_cast<int>(gridDim.x); const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x); const auto thread_id = static_cast<int>(threadIdx.x);
...@@ -503,6 +510,11 @@ combine(void* combined_x, ...@@ -503,6 +510,11 @@ combine(void* combined_x,
if ((phases & LOW_LATENCY_SEND_PHASE) == 0) if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_COMBINE_RECV; goto LOW_LATENCY_COMBINE_RECV;
#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
// Clean up next buffer // Clean up next buffer
if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) { if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
#pragma unroll #pragma unroll
...@@ -550,18 +562,13 @@ combine(void* combined_x, ...@@ -550,18 +562,13 @@ combine(void* combined_x,
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global); UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
//nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset); //nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
#if defined(ROCM_DISABLE_CTX) #if !defined(ROCM_DISABLE_CTX)
internode::shmemx_int8_put_nbi_warp(
#else
internode::shmem_ctx_schar_put_nbi_warp(ctx, internode::shmem_ctx_schar_put_nbi_warp(ctx,
#endif
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), hidden * sizeof(hip_bfloat16), dst_rank);
#if defined(ROCM_DISABLE_CTX)
internode::shmem_fence();
#else #else
internode::shmem_ctx_quiet(ctx); internode::shmemx_int8_put_nbi_warp(
#endif #endif
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr),
hidden * sizeof(hip_bfloat16), dst_rank);
} }
} }
...@@ -573,27 +580,49 @@ combine(void* combined_x, ...@@ -573,27 +580,49 @@ combine(void* combined_x,
} }
syncwarp(); syncwarp();
while (sync_large_warp_counters[warp_group_id] < num_warps_per_group); while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
if (sub_warp_id == 1 and lane_id == 0) { if (sub_warp_id == 1 and lane_id == 0) {
while (ld_acquire_global(atomic_clean_flag) == 0); while (ld_acquire_global(atomic_clean_flag) == 0);
if (dst_rank != rank) { if (dst_rank != rank) {
#if defined(ROCM_DISABLE_CTX) #if !defined(ROCM_DISABLE_CTX)
internode::shmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank); internode::shmem_ctx_long_atomic_add(ctx,
#else #else
internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_flag + global_expert_idx, 1, dst_rank); internode::shmem_long_atomic_add(
#endif #endif
rdma_recv_flag + global_expert_idx, 1, dst_rank);
} else { } else {
st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1); st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
} }
atomic_add_release_global(atomic_clean_flag, -1); atomic_add_release_global(atomic_clean_flag, -1);
} }
syncwarp(); syncwarp();
if (num_ranks > 8){
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_quiet(ctx);
#else
internode::shmem_fence();
#endif
}
} }
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
// Receiving phase // Receiving phase
LOW_LATENCY_COMBINE_RECV: LOW_LATENCY_COMBINE_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0) if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return; return;
// if (num_ranks > 8){
// #if !defined(ROCM_DISABLE_CTX)
// internode::shmem_ctx_quiet(ctx);
// #else
// internode::shmem_fence();
// #endif
// }
// Wait all ranks to arrive and notify PCIe usage // Wait all ranks to arrive and notify PCIe usage
if (responsible_expert_idx < num_experts) { if (responsible_expert_idx < num_experts) {
EP_DEVICE_ASSERT(num_warps_per_group > 1); EP_DEVICE_ASSERT(num_warps_per_group > 1);
...@@ -641,9 +670,6 @@ combine(void* combined_x, ...@@ -641,9 +670,6 @@ combine(void* combined_x,
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4; (reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
} }
} }
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
} }
void combine(void* combined_x, void combine(void* combined_x,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment