Commit d1bf10d3 authored by lishen's avatar lishen
Browse files

基于rocm的DeepEP,低延迟优化

parent ee3551ab
......@@ -36,7 +36,7 @@ __device__ void grid_barrier(int* global_counter, int num_blocks) {
}
__syncthreads();
if (threadIdx.x == 0) {
while (__hip_atomic_load(global_counter, __ATOMIC_RELAXED,__HIP_MEMORY_SCOPE_AGENT) != num_blocks);
while (__hip_atomic_load(global_counter, __ATOMIC_RELAXED,__HIP_MEMORY_SCOPE_AGENT) != num_blocks);
}
__syncthreads();
}
......@@ -69,7 +69,7 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
clean_1[i] = 0;
// Barrier after cleaning (make sure low-latency mode work
// Barrier after cleaning (make sure low-latency mode work
if (threadIdx.x == 0)
internode::shmem_device_barrier_all();
}
......@@ -96,13 +96,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group,
int num_warp_groups, int num_warps_per_group,
bool round_scale, int phases) {
#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
......@@ -131,17 +126,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
// 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize;
constexpr int num_sync_large_iteration = kMaxNumWarps ;
__shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];
#pragma unroll
for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
sync_large_warp_counters[i] = 0;
}
__syncthreads();
// Expert counts
constexpr int kNumMaxWarpGroups = 1024 / kWarpSize;
__shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
......@@ -150,6 +134,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_DISPATCH_RECV;
#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
// There are 2 kinds of warps in this part:
// 1. The first-kind warps for FP8 cast and sending top-k tokens
// 2. The last warp for reading `topk_idx` and count for per-expert information
......@@ -220,9 +209,18 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
slot_idx * num_bytes_per_msg;
if (dst_rank != rank) {
internode::shmemx_int8_put_nbi_warp(reinterpret_cast<signed char*>(dst_ptr),
reinterpret_cast<signed char*>(src_ptr), num_bytes_per_msg, dst_rank);
internode::shmem_fence();
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_schar_put_nbi_warp(ctx,
#else
internode::shmemx_int8_put_nbi_warp(
#endif
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
num_bytes_per_msg, dst_rank);
// #if !defined(ROCM_DISABLE_CTX)
// internode::shmem_ctx_quiet(ctx);
// #else
// internode::shmem_fence();
// #endif
} else {
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
......@@ -274,8 +272,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
}
}
//revert sync_large_warp_counters to 0 for next sync
__syncthreads();
// Issue count sends
......@@ -287,7 +283,12 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Wait local sends issued and send expert counts
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
if (dst_rank != rank) {
internode::shmem_long_atomic_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_long_atomic_add(ctx,
#else
internode::shmem_long_atomic_add(
#endif
rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
} else {
st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
}
......@@ -302,6 +303,10 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
syncwarp();
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
// Receiving phase
LOW_LATENCY_DISPATCH_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
......@@ -312,20 +317,31 @@ LOW_LATENCY_DISPATCH_RECV:
grid_barrier(global_atomic_counter, num_sms);
}
// 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize;
constexpr int num_sync_large_iteration = kMaxNumWarps ;
__shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];
#pragma unroll
for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
sync_large_warp_counters[i] = 0;
}
__syncthreads();
// Receiving and packing
if (responsible_expert_idx < num_experts) {
const auto src_rank = responsible_expert_idx / num_local_experts;
const auto local_expert_idx = responsible_expert_idx % num_local_experts;
const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
const auto num_aligned_scales = ALIGN<int>(num_scales, sizeof(float) / sizeof(scale_t));
const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
// Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
......@@ -393,10 +409,6 @@ LOW_LATENCY_DISPATCH_RECV:
}
}
}
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
}
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
......@@ -407,9 +419,9 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const void* x, const int64_t* topk_idx,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int num_device_sms,
void* workspace, int num_device_sms,
hipStream_t stream, int phases) {
constexpr int kNumMaxTopK = 11;
const int num_warp_groups = ceil_div(num_experts, num_device_sms);
......@@ -464,11 +476,6 @@ combine(void* combined_x,
int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group,
int phases, bool zero_copy) {
#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x);
......@@ -488,7 +495,7 @@ combine(void* combined_x,
constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(hip_bfloat16);
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// 16 is the max possible number of warps in AMD GPUs
// 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize;
__shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
if (threadIdx.x==0){
......@@ -503,6 +510,11 @@ combine(void* combined_x,
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_COMBINE_RECV;
#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
#endif
// Clean up next buffer
if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
#pragma unroll
......@@ -522,10 +534,10 @@ combine(void* combined_x,
const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
const auto local_x = reinterpret_cast<const int4*>(x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// Unpack layout
int offset, num_tokens_to_send;
......@@ -548,21 +560,16 @@ combine(void* combined_x,
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy)
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
//nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
#if defined(ROCM_DISABLE_CTX)
internode::shmemx_int8_put_nbi_warp(
#else
internode::shmem_ctx_schar_put_nbi_warp(ctx,
#endif
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), hidden * sizeof(hip_bfloat16), dst_rank);
#if defined(ROCM_DISABLE_CTX)
internode::shmem_fence();
//nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_schar_put_nbi_warp(ctx,
#else
internode::shmem_ctx_quiet(ctx);
internode::shmemx_int8_put_nbi_warp(
#endif
}
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr),
hidden * sizeof(hip_bfloat16), dst_rank);
}
}
// Put finishing flag
......@@ -573,27 +580,49 @@ combine(void* combined_x,
}
syncwarp();
while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
if (sub_warp_id == 1 and lane_id == 0) {
while (ld_acquire_global(atomic_clean_flag) == 0);
if (dst_rank != rank) {
#if defined(ROCM_DISABLE_CTX)
internode::shmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank);
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_long_atomic_add(ctx,
#else
internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_flag + global_expert_idx, 1, dst_rank);
internode::shmem_long_atomic_add(
#endif
rdma_recv_flag + global_expert_idx, 1, dst_rank);
} else {
st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
}
atomic_add_release_global(atomic_clean_flag, -1);
}
syncwarp();
if (num_ranks > 8){
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_quiet(ctx);
#else
internode::shmem_fence();
#endif
}
}
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
// Receiving phase
LOW_LATENCY_COMBINE_RECV:
LOW_LATENCY_COMBINE_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return;
// if (num_ranks > 8){
// #if !defined(ROCM_DISABLE_CTX)
// internode::shmem_ctx_quiet(ctx);
// #else
// internode::shmem_fence();
// #endif
// }
// Wait all ranks to arrive and notify PCIe usage
if (responsible_expert_idx < num_experts) {
EP_DEVICE_ASSERT(num_warps_per_group > 1);
......@@ -641,9 +670,6 @@ combine(void* combined_x,
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
}
}
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
}
void combine(void* combined_x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment