Commit d1bf10d3 authored by lishen
Browse files

基于rocm的DeepEP,低延迟优化

parent ee3551ab
...@@ -36,7 +36,7 @@ __device__ void grid_barrier(int* global_counter, int num_blocks) { ...@@ -36,7 +36,7 @@ __device__ void grid_barrier(int* global_counter, int num_blocks) {
} }
__syncthreads(); __syncthreads();
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
while (__hip_atomic_load(global_counter, __ATOMIC_RELAXED,__HIP_MEMORY_SCOPE_AGENT) != num_blocks); while (__hip_atomic_load(global_counter, __ATOMIC_RELAXED,__HIP_MEMORY_SCOPE_AGENT) != num_blocks);
} }
__syncthreads(); __syncthreads();
} }
...@@ -69,7 +69,7 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, ...@@ -69,7 +69,7 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
for (int i = thread_id; i < num_clean_int_1; i += kNumThreads) for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
clean_1[i] = 0; clean_1[i] = 0;
// Barrier after cleaning (make sure low-latency mode works) // Barrier after cleaning (make sure low-latency mode works)
if (threadIdx.x == 0) if (threadIdx.x == 0)
internode::shmem_device_barrier_all(); internode::shmem_device_barrier_all();
} }
...@@ -96,13 +96,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -96,13 +96,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank, int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group, int num_warp_groups, int num_warps_per_group,
bool round_scale, int phases) { bool round_scale, int phases) {
#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
const auto sm_id = static_cast<int>(blockIdx.x); const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x); const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id(); const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
...@@ -131,17 +126,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -131,17 +126,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4); const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0); EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
// 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize;
constexpr int num_sync_large_iteration = kMaxNumWarps ;
__shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];
#pragma unroll
for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
sync_large_warp_counters[i] = 0;
}
__syncthreads();
// Expert counts // Expert counts
constexpr int kNumMaxWarpGroups = 1024 / kWarpSize; constexpr int kNumMaxWarpGroups = 1024 / kWarpSize;
__shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups]; __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
...@@ -150,6 +134,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -150,6 +134,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
if ((phases & LOW_LATENCY_SEND_PHASE) == 0) if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_DISPATCH_RECV; goto LOW_LATENCY_DISPATCH_RECV;
#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
// There are 2 kinds of warps in this part: // There are 2 kinds of warps in this part:
// 1. The first-kind warps for FP8 cast and sending top-k tokens // 1. The first-kind warps for FP8 cast and sending top-k tokens
// 2. The last warp for reading `topk_idx` and count for per-expert information // 2. The last warp for reading `topk_idx` and count for per-expert information
...@@ -220,9 +209,18 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -220,9 +209,18 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
slot_idx * num_bytes_per_msg; slot_idx * num_bytes_per_msg;
if (dst_rank != rank) { if (dst_rank != rank) {
internode::shmemx_int8_put_nbi_warp(reinterpret_cast<signed char*>(dst_ptr), #if !defined(ROCM_DISABLE_CTX)
reinterpret_cast<signed char*>(src_ptr), num_bytes_per_msg, dst_rank); internode::shmem_ctx_schar_put_nbi_warp(ctx,
internode::shmem_fence(); #else
internode::shmemx_int8_put_nbi_warp(
#endif
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
num_bytes_per_msg, dst_rank);
// #if !defined(ROCM_DISABLE_CTX)
// internode::shmem_ctx_quiet(ctx);
// #else
// internode::shmem_fence();
// #endif
} else { } else {
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr); const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
...@@ -274,8 +272,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -274,8 +272,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
} }
} }
} }
//revert sync_large_warp_counters to 0 for next sync
__syncthreads(); __syncthreads();
// Issue count sends // Issue count sends
...@@ -287,7 +283,12 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -287,7 +283,12 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Wait local sends issued and send expert counts // Wait local sends issued and send expert counts
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2); while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
if (dst_rank != rank) { if (dst_rank != rank) {
internode::shmem_long_atomic_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); #if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_long_atomic_add(ctx,
#else
internode::shmem_long_atomic_add(
#endif
rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
} else { } else {
st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1); st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
} }
...@@ -302,6 +303,10 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -302,6 +303,10 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
} }
syncwarp(); syncwarp();
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
// Receiving phase // Receiving phase
LOW_LATENCY_DISPATCH_RECV: LOW_LATENCY_DISPATCH_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0) if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
...@@ -312,20 +317,31 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -312,20 +317,31 @@ LOW_LATENCY_DISPATCH_RECV:
grid_barrier(global_atomic_counter, num_sms); grid_barrier(global_atomic_counter, num_sms);
} }
// 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize;
constexpr int num_sync_large_iteration = kMaxNumWarps ;
__shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];
#pragma unroll
for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
sync_large_warp_counters[i] = 0;
}
__syncthreads();
// Receiving and packing // Receiving and packing
if (responsible_expert_idx < num_experts) { if (responsible_expert_idx < num_experts) {
const auto src_rank = responsible_expert_idx / num_local_experts; const auto src_rank = responsible_expert_idx / num_local_experts;
const auto local_expert_idx = responsible_expert_idx % num_local_experts; const auto local_expert_idx = responsible_expert_idx % num_local_experts;
const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) + const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg; src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) + const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4; local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks; const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
const auto num_aligned_scales = ALIGN<int>(num_scales, sizeof(float) / sizeof(scale_t)); const auto num_aligned_scales = ALIGN<int>(num_scales, sizeof(float) / sizeof(scale_t));
const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) + const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales; local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
// Shared between sub-warps in warp groups // Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups]; __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
...@@ -393,10 +409,6 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -393,10 +409,6 @@ LOW_LATENCY_DISPATCH_RECV:
} }
} }
} }
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
} }
void dispatch(void* packed_recv_x, void* packed_recv_x_scales, void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...@@ -407,9 +419,9 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -407,9 +419,9 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const void* x, const int64_t* topk_idx, const void* x, const int64_t* topk_idx,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int num_device_sms, void* workspace, int num_device_sms,
hipStream_t stream, int phases) { hipStream_t stream, int phases) {
constexpr int kNumMaxTopK = 11; constexpr int kNumMaxTopK = 11;
const int num_warp_groups = ceil_div(num_experts, num_device_sms); const int num_warp_groups = ceil_div(num_experts, num_device_sms);
...@@ -464,11 +476,6 @@ combine(void* combined_x, ...@@ -464,11 +476,6 @@ combine(void* combined_x,
int num_experts, int rank, int num_ranks, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group, int num_warp_groups, int num_warps_per_group,
int phases, bool zero_copy) { int phases, bool zero_copy) {
#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
const auto sm_id = static_cast<int>(blockIdx.x); const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_sms = static_cast<int>(gridDim.x); const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x); const auto thread_id = static_cast<int>(threadIdx.x);
...@@ -488,7 +495,7 @@ combine(void* combined_x, ...@@ -488,7 +495,7 @@ combine(void* combined_x,
constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(hip_bfloat16); constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(hip_bfloat16);
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization"); EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// 16 is the max possible number of warps in AMD GPUs // 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize; constexpr int kMaxNumWarps = 1024 / kWarpSize;
__shared__ volatile int sync_large_warp_counters[kMaxNumWarps]; __shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
if (threadIdx.x==0){ if (threadIdx.x==0){
...@@ -503,6 +510,11 @@ combine(void* combined_x, ...@@ -503,6 +510,11 @@ combine(void* combined_x,
if ((phases & LOW_LATENCY_SEND_PHASE) == 0) if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_COMBINE_RECV; goto LOW_LATENCY_COMBINE_RECV;
#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
// Clean up next buffer // Clean up next buffer
if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) { if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
#pragma unroll #pragma unroll
...@@ -522,10 +534,10 @@ combine(void* combined_x, ...@@ -522,10 +534,10 @@ combine(void* combined_x,
const auto global_expert_idx = rank * num_local_experts + local_expert_idx; const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank); const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
const auto local_x = reinterpret_cast<const int4*>(x) + const auto local_x = reinterpret_cast<const int4*>(x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4; local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) + const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot; local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// Unpack layout // Unpack layout
int offset, num_tokens_to_send; int offset, num_tokens_to_send;
...@@ -548,21 +560,16 @@ combine(void* combined_x, ...@@ -548,21 +560,16 @@ combine(void* combined_x,
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr); const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy) if (not zero_copy)
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global); UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
//nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
#if defined(ROCM_DISABLE_CTX)
internode::shmemx_int8_put_nbi_warp(
#else
internode::shmem_ctx_schar_put_nbi_warp(ctx,
#endif
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), hidden * sizeof(hip_bfloat16), dst_rank);
#if defined(ROCM_DISABLE_CTX) //nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
internode::shmem_fence(); #if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_schar_put_nbi_warp(ctx,
#else #else
internode::shmem_ctx_quiet(ctx); internode::shmemx_int8_put_nbi_warp(
#endif #endif
} reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr),
hidden * sizeof(hip_bfloat16), dst_rank);
}
} }
// Put finishing flag // Put finishing flag
...@@ -573,27 +580,49 @@ combine(void* combined_x, ...@@ -573,27 +580,49 @@ combine(void* combined_x,
} }
syncwarp(); syncwarp();
while (sync_large_warp_counters[warp_group_id] < num_warps_per_group); while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
if (sub_warp_id == 1 and lane_id == 0) { if (sub_warp_id == 1 and lane_id == 0) {
while (ld_acquire_global(atomic_clean_flag) == 0); while (ld_acquire_global(atomic_clean_flag) == 0);
if (dst_rank != rank) { if (dst_rank != rank) {
#if defined(ROCM_DISABLE_CTX) #if !defined(ROCM_DISABLE_CTX)
internode::shmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank); internode::shmem_ctx_long_atomic_add(ctx,
#else #else
internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_flag + global_expert_idx, 1, dst_rank); internode::shmem_long_atomic_add(
#endif #endif
rdma_recv_flag + global_expert_idx, 1, dst_rank);
} else { } else {
st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1); st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
} }
atomic_add_release_global(atomic_clean_flag, -1); atomic_add_release_global(atomic_clean_flag, -1);
} }
syncwarp(); syncwarp();
if (num_ranks > 8){
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_quiet(ctx);
#else
internode::shmem_fence();
#endif
}
} }
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
// Receiving phase // Receiving phase
LOW_LATENCY_COMBINE_RECV: LOW_LATENCY_COMBINE_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0) if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return; return;
// if (num_ranks > 8){
// #if !defined(ROCM_DISABLE_CTX)
// internode::shmem_ctx_quiet(ctx);
// #else
// internode::shmem_fence();
// #endif
// }
// Wait all ranks to arrive and notify PCIe usage // Wait all ranks to arrive and notify PCIe usage
if (responsible_expert_idx < num_experts) { if (responsible_expert_idx < num_experts) {
EP_DEVICE_ASSERT(num_warps_per_group > 1); EP_DEVICE_ASSERT(num_warps_per_group > 1);
...@@ -641,9 +670,6 @@ combine(void* combined_x, ...@@ -641,9 +670,6 @@ combine(void* combined_x,
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4; (reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
} }
} }
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
} }
void combine(void* combined_x, void combine(void* combined_x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment