Commit ce671dd4 authored by lishen
Browse files

低延迟接口支持int8类型通信

parent da13c63a
...@@ -14,4 +14,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ...@@ -14,4 +14,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240 # export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
export ROCSHMEM_HEAP_SIZE=2880100992 export ROCSHMEM_HEAP_SIZE=2880100992
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py
...@@ -14,4 +14,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ...@@ -14,4 +14,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240 # export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
export ROCSHMEM_HEAP_SIZE=2880100992 export ROCSHMEM_HEAP_SIZE=2880100992
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py
...@@ -1293,7 +1293,8 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int ...@@ -1293,7 +1293,8 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool async, bool return_recv_hook) { bool use_fp8, bool round_scale, bool use_ue8m0,
bool async, bool return_recv_hook) {
EP_HOST_ASSERT(low_latency_mode); EP_HOST_ASSERT(low_latency_mode);
// Tensor checks // Tensor checks
...@@ -1306,8 +1307,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1306,8 +1307,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
EP_HOST_ASSERT(num_experts % num_ranks == 0); EP_HOST_ASSERT(num_experts % num_ranks == 0);
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1)); auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
auto num_scales = hidden / 128, num_topk = static_cast<int>(topk_idx.size(1)); auto num_topk = static_cast<int>(topk_idx.size(1));
int num_local_experts = num_experts / num_ranks; auto num_local_experts = num_experts / num_ranks;
// Buffer control // Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
...@@ -1339,12 +1340,21 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1339,12 +1340,21 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate column-majored scales // Allocate column-majored scales
auto packed_recv_x_scales = std::optional<torch::Tensor>(); auto packed_recv_x_scales = std::optional<torch::Tensor>();
float* packed_recv_x_scales_ptr = nullptr; void* packed_recv_x_scales_ptr = nullptr;
if (use_fp8) { if (use_fp8) {
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4"); EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); // TODO: support unaligned cases
EP_HOST_ASSERT(hidden % (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4) == 0);
if (not use_ue8m0) {
packed_recv_x_scales = torch::empty({num_local_experts, hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
} else {
EP_HOST_ASSERT(round_scale);
packed_recv_x_scales = torch::empty({num_local_experts, hidden / (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4), num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kInt).device(torch::kCUDA));
}
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2); packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr<float>(); packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
} }
// Kernel launch // Kernel launch
...@@ -1359,8 +1369,9 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1359,8 +1369,9 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
x.data_ptr(), topk_idx.data_ptr<int64_t>(), x.data_ptr(), topk_idx.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second, next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank, num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks, use_fp8, num_topk, num_experts, rank, num_ranks,
workspace, launch_stream, phases); use_fp8, round_scale, use_ue8m0,
workspace, num_device_sms, launch_stream, phases);
}; };
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
...@@ -1454,7 +1465,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1454,7 +1465,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
next_clean_meta.first, next_clean_meta.second, next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks, num_topk, num_experts, rank, num_ranks,
workspace, launch_stream, workspace, num_device_sms, launch_stream,
phases, zero_copy); phases, zero_copy);
}; };
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
......
...@@ -177,7 +177,8 @@ public: ...@@ -177,7 +177,8 @@ public:
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool async, bool return_recv_hook); bool use_fp8, bool round_scale, bool use_ue8m0,
bool async, bool return_recv_hook);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
......
...@@ -138,7 +138,7 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, ...@@ -138,7 +138,7 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t* clean_1, int num_clean_int_1, int64_t* clean_1, int num_clean_int_1,
hipStream_t stream); hipStream_t stream);
void dispatch(void* packed_recv_x, float* packed_recv_x_scales, void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count, int* packed_recv_count,
int* global_atomic_counter, int* global_atomic_counter,
...@@ -146,8 +146,9 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, ...@@ -146,8 +146,9 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const void* x, const int64_t* topk_idx, const void* x, const int64_t* topk_idx,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8, int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, hipStream_t stream, int phases); bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int num_device_sms, hipStream_t stream, int phases);
void combine(void* combined_x, void combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
...@@ -157,7 +158,7 @@ void combine(void* combined_x, ...@@ -157,7 +158,7 @@ void combine(void* combined_x,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, hipStream_t stream, void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy); int phases, bool zero_copy);
} // namespace internode_ll } // namespace internode_ll
......
...@@ -440,9 +440,6 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv ...@@ -440,9 +440,6 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
auto warp_role = role_meta.first; auto warp_role = role_meta.first;
auto target_rank = role_meta.second; // Not applicable for RDMA senders auto target_rank = role_meta.second; // Not applicable for RDMA senders
// if(lane_id==0){
// printf("tid=%d, bid=%d, warp_role=%d\n", threadIdx.x, blockIdx.x, warp_role);
// }
// RDMA symmetric layout // RDMA symmetric layout
auto hidden_bytes = hidden_int4 * sizeof(int4); auto hidden_bytes = hidden_int4 * sizeof(int4);
...@@ -1610,8 +1607,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1610,8 +1607,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
int lds_dst_rdma_rank = dst_rdma_rank + (iter % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters; int lds_dst_rdma_rank = dst_rdma_rank + (iter % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters;
//reset index in the LDS to avoid race condition due to warp scheduling //reset index in the LDS to avoid race condition due to warp scheduling
int reset_idx = dst_rdma_rank + ((iter + num_sync_large_iteration/2) % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters; int reset_idx = dst_rdma_rank + ((iter + num_sync_large_iteration/2) % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters;
// // if (lane_id==0)
// // printf("rank %d dst_rdma_rank %d iter %d warp_id %d val %d\n", rank, dst_rdma_rank, iter, warp_id, sync_large_warp_counters[lds_dst_rdma_rank]);
auto start_time = wall_clock64(); auto start_time = wall_clock64();
if (lane_id == 0){ if (lane_id == 0){
volatile int ret = atomicAdd((int*)&sync_large_warp_counters[lds_dst_rdma_rank], 1); volatile int ret = atomicAdd((int*)&sync_large_warp_counters[lds_dst_rdma_rank], 1);
......
...@@ -85,9 +85,9 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, ...@@ -85,9 +85,9 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
clean_0, num_clean_int_0, clean_1, num_clean_int_1); clean_0, num_clean_int_0, clean_1, num_clean_int_1);
} }
template <bool kUseFP8, int kHidden> template <bool kUseFP8, bool kUseUE8M0, int kHidden>
__global__ __launch_bounds__(16 * kWarpSize, 1) void __global__ __launch_bounds__(16 * kWarpSize, 1) void
dispatch(void* packed_recv_x, float* packed_recv_x_scales, dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count, int* packed_recv_count,
int* global_atomic_counter, int* global_atomic_counter,
...@@ -97,7 +97,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, ...@@ -97,7 +97,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank, int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group, int phases) { int num_warp_groups, int num_warps_per_group,
bool round_scale, int phases) {
#if !defined(ROCM_DISABLE_CTX) #if !defined(ROCM_DISABLE_CTX)
__shared__ rocshmem::rocshmem_ctx_t ctx; __shared__ rocshmem::rocshmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx); rocshmem::rocshmem_wg_ctx_create(0, &ctx);
...@@ -113,6 +114,11 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, ...@@ -113,6 +114,11 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const auto sub_warp_id = warp_id % num_warps_per_group; const auto sub_warp_id = warp_id % num_warps_per_group;
const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id; const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
// May extract UE8M0 from the scales
using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
// FP8 staffs // FP8 staffs
constexpr int kNumPerChannels = FP8_QUANTIZATION_NUM_PER_CHANNEL; constexpr int kNumPerChannels = FP8_QUANTIZATION_NUM_PER_CHANNEL;
const int num_scales = kHidden / kNumPerChannels; const int num_scales = kHidden / kNumPerChannels;
...@@ -184,9 +190,9 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, ...@@ -184,9 +190,9 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Reduce amax and scale // Reduce amax and scale
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization"); EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization");
amax = warp_reduce_max<16>(amax); amax = warp_reduce_max<16>(amax);
calculate_fp8_scales</*round_scale*/false>(amax, scale, scale_inv); calculate_fp8_scales(amax, scale, scale_inv, round_scale);
if (lane_id % 16 == 0) if (lane_id % 16 == 0)
rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv; rdma_x_scales[i * kNumElemsPerRead / FP8_QUANTIZATION_NUM_PER_CHANNEL] = scale_inv;
// Cast into send buffer // Cast into send buffer
vec_t int2_value; vec_t int2_value;
...@@ -316,9 +322,11 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -316,9 +322,11 @@ LOW_LATENCY_DISPATCH_RECV:
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg; src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) + const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4; local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
const auto recv_x_scales = packed_recv_x_scales + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_scales;
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks; const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
const auto num_aligned_scales = ALIGN<int>(num_scales, sizeof(float) / sizeof(scale_t));
const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
// Shared between sub-warps in warp groups // Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups]; __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
...@@ -366,12 +374,23 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -366,12 +374,23 @@ LOW_LATENCY_DISPATCH_RECV:
// Copy scales // Copy scales
if (kUseFP8) { if (kUseFP8) {
const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes); const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
const auto dst_scales = reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i); const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank; const auto token_idx = recv_token_begin_idx + i;
auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0; const auto token_stride = num_elems_per_pack;
auto scale_1 = (lane_id + kWarpSize) < num_scales ? ld_nc_global(src_scales + lane_id + kWarpSize) : 0; const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f;
(lane_id + kWarpSize) < num_scales ? dst_scales[(lane_id + kWarpSize) * scale_stride] = scale_1 : 0.0f; if (lane_id < num_scales) {
const auto pack_idx = lane_id / num_elems_per_pack;
const auto elem_idx = lane_id % num_elems_per_pack;
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
}
if (lane_id + kWarpSize < num_scales) {
const auto pack_idx = (lane_id + kWarpSize) / num_elems_per_pack;
const auto elem_idx = (lane_id + kWarpSize) % num_elems_per_pack;
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + kWarpSize));
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
}
} }
} }
} }
...@@ -381,7 +400,7 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -381,7 +400,7 @@ LOW_LATENCY_DISPATCH_RECV:
#endif #endif
} }
void dispatch(void* packed_recv_x, float* packed_recv_x_scales, void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count, int* packed_recv_count,
int* global_atomic_counter, int* global_atomic_counter,
...@@ -389,10 +408,12 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, ...@@ -389,10 +408,12 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const void* x, const int64_t* topk_idx, const void* x, const int64_t* topk_idx,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8, int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, hipStream_t stream, int phases) { bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int num_device_sms,
hipStream_t stream, int phases) {
constexpr int kNumMaxTopK = 11; constexpr int kNumMaxTopK = 11;
const int num_warp_groups = ceil_div(num_experts, /*num_device_sms*/80); const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const int num_warps_per_group = 16 / num_warp_groups; const int num_warps_per_group = 16 / num_warp_groups;
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0); EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group); EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);
...@@ -407,8 +428,11 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, ...@@ -407,8 +428,11 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES); EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
#define DISPATCH_LAUNCH_CASE(hidden) { \ #define DISPATCH_LAUNCH_CASE(hidden) { \
auto dispatch_func = use_fp8 ? dispatch<true, hidden> : \ auto dispatch_func = dispatch<false, false, hidden>; \
dispatch<false, hidden>; \ if (use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, hidden>; \
if (use_fp8 and use_ue8m0) \
dispatch_func = dispatch<true, true, hidden>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \ packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \ packed_recv_src_info, packed_recv_layout_range, \
...@@ -420,15 +444,15 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \ ...@@ -420,15 +444,15 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
next_clean, num_next_clean_int, \ next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \ num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \ num_topk, num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, phases); } break num_warp_groups, num_warps_per_group, round_scale, phases); } break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream); SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE); SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE #undef DISPATCH_LAUNCH_CASE
} }
template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk> template <int kHidden, int kNumMaxTopk>
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * kWarpSize, 1) void __global__ __launch_bounds__(16 * kWarpSize, 1) void
combine(void* combined_x, combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights, const void* x, const int64_t* topk_idx, const float* topk_weights,
...@@ -439,6 +463,7 @@ combine(void* combined_x, ...@@ -439,6 +463,7 @@ combine(void* combined_x,
int num_combined_tokens, int hidden, int num_topk, int num_combined_tokens, int hidden, int num_topk,
int num_max_dispatch_tokens_per_rank, int num_max_dispatch_tokens_per_rank,
int num_experts, int rank, int num_ranks, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group,
int phases, bool zero_copy) { int phases, bool zero_copy) {
#if !defined(ROCM_DISABLE_CTX) #if !defined(ROCM_DISABLE_CTX)
...@@ -451,19 +476,21 @@ combine(void* combined_x, ...@@ -451,19 +476,21 @@ combine(void* combined_x,
const auto num_threads = static_cast<int>(blockDim.x); const auto num_threads = static_cast<int>(blockDim.x);
const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id(); const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
const auto num_local_experts = num_experts / num_ranks; const auto num_local_experts = num_experts / num_ranks;
const auto warp_group_id = warp_id / kNumWarpsPerGroup; const auto warp_group_id = warp_id / num_warps_per_group;
const auto sub_warp_id = warp_id % kNumWarpsPerGroup; const auto sub_warp_id = warp_id % num_warps_per_group;
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id; const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
// Data type staffs // Data type staffs
constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(hip_bfloat16); constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(hip_bfloat16);
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4; const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// Message package // Message package
// BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot) EP_STATIC_ASSERT(kHidden % FP8_QUANTIZATION_NUM_PER_CHANNEL == 0, "Invalid hidden");
constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(hip_bfloat16); constexpr int kNumDivisions = kHidden / FP8_QUANTIZATION_NUM_PER_CHANNEL;
constexpr int kNumMetaBytes = kNumDivisions * sizeof(float);
constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(hip_bfloat16) + kNumMetaBytes;
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization"); EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
__syncthreads();
// 16 is the max possible number of warps in AMD GPUs // 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize; constexpr int kMaxNumWarps = 1024 / kWarpSize;
__shared__ volatile int sync_large_warp_counters[kMaxNumWarps]; __shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
...@@ -508,13 +535,13 @@ combine(void* combined_x, ...@@ -508,13 +535,13 @@ combine(void* combined_x,
unpack2(layout, num_tokens_to_send, offset); unpack2(layout, num_tokens_to_send, offset);
// Issue IBGDA send // Issue IBGDA send
for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) { for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) {
const auto x_int4 = local_x + token_idx * hidden_bf16_int4; const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot); const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4); const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);
// Copy directly to local rank, or copy to buffer and issue RDMA // Copy directly to local rank, or copy to buffer and issue RDMA
auto src_idx = __ldg(local_src_info + token_idx); const auto src_idx = shfl_sync(__ldg(local_src_info + token_idx), 0);
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row); const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4); const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
if (dst_rank == rank) { if (dst_rank == rank) {
...@@ -542,13 +569,13 @@ combine(void* combined_x, ...@@ -542,13 +569,13 @@ combine(void* combined_x,
} }
// Put finishing flag // Put finishing flag
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group"); EP_DEVICE_ASSERT(num_warps_per_group > 1);
if (lane_id == 0){ if (lane_id == 0){
// volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1,__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP); // volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1,__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
volatile int ret = atomicAdd((int*)&sync_large_warp_counters[warp_group_id], 1); volatile int ret = atomicAdd((int*)&sync_large_warp_counters[warp_group_id], 1);
} }
syncwarp(); syncwarp();
while (sync_large_warp_counters[warp_group_id] < (kNumWarpsPerGroup)); while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
if (sub_warp_id == 1 and lane_id == 0) { if (sub_warp_id == 1 and lane_id == 0) {
while (ld_acquire_global(atomic_clean_flag) == 0); while (ld_acquire_global(atomic_clean_flag) == 0);
if (dst_rank != rank) { if (dst_rank != rank) {
...@@ -572,7 +599,7 @@ combine(void* combined_x, ...@@ -572,7 +599,7 @@ combine(void* combined_x,
// Wait all ranks to arrive and notify PCIe usage // Wait all ranks to arrive and notify PCIe usage
if (responsible_expert_idx < num_experts) { if (responsible_expert_idx < num_experts) {
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group"); EP_DEVICE_ASSERT(num_warps_per_group > 1);
if (sub_warp_id == 0 and lane_id == 0){ if (sub_warp_id == 0 and lane_id == 0){
while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0); while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0);
} }
...@@ -630,14 +657,17 @@ void combine(void* combined_x, ...@@ -630,14 +657,17 @@ void combine(void* combined_x,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, hipStream_t stream, void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy) { int phases, bool zero_copy) {
constexpr int kNumWarpsPerGroup = 4; constexpr int kNumMaxTopk = 11;
constexpr int kNumWarpGroups = 4; const int num_warp_groups = ceil_div(num_experts, num_device_sms);
constexpr int kNumMaxTopk = 9; const int num_warps_per_group = 16 / num_warp_groups; // num_warps_per_group>1, "Requires more than one warp per group"
const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; const auto num_warps = num_warp_groups * num_warps_per_group;
const auto num_sms = ceil_div(num_experts, kNumWarpGroups); const auto num_sms =
max(ceil_div(num_experts, num_warp_groups), num_recv_per_sm == 0 ? 1 : ceil_div(num_combined_tokens, num_recv_per_sm));
// Check workspace // Check workspace
auto atomic_clean_flag = reinterpret_cast<int*>(workspace); auto atomic_clean_flag = reinterpret_cast<int*>(workspace);
...@@ -645,7 +675,7 @@ void combine(void* combined_x, ...@@ -645,7 +675,7 @@ void combine(void* combined_x,
EP_HOST_ASSERT(num_topk <= kNumMaxTopk); EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
#define COMBINE_LAUNCH_CASE(hidden) { \ #define COMBINE_LAUNCH_CASE(hidden) { \
auto combine_func = combine<kNumWarpGroups, kNumWarpsPerGroup, hidden, kNumMaxTopk>; \ auto combine_func = combine<hidden, kNumMaxTopk>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
combined_x, \ combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \ rdma_recv_x, rdma_recv_flag, rdma_send_x, \
...@@ -656,7 +686,7 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \ ...@@ -656,7 +686,7 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
num_combined_tokens, hidden, num_topk, \ num_combined_tokens, hidden, num_topk, \
num_max_dispatch_tokens_per_rank, \ num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \ num_experts, rank, num_ranks, \
phases, zero_copy); } break num_warp_groups, num_warps_per_group, phases, zero_copy); } break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream); SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
SWITCH_HIDDEN(COMBINE_LAUNCH_CASE); SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
......
...@@ -365,9 +365,8 @@ __forceinline__ __device__ int fast_log2_ceil(float x) { ...@@ -365,9 +365,8 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
return exp_x - 127 + (man_bits != 0); return exp_x - 127 + (man_bits != 0);
} }
template <bool kRoundScale> __forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) {
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv) { if (round_scale) {
if constexpr(kRoundScale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3); auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
scale = fast_pow2(-exp_scale_inv); scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv); scale_inv = fast_pow2(exp_scale_inv);
......
...@@ -804,42 +804,46 @@ class Buffer: ...@@ -804,42 +804,46 @@ class Buffer:
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int, num_experts: int, num_max_dispatch_tokens_per_rank: int, num_experts: int,
use_fp8: bool = True, async_finish: bool = False, return_recv_hook: bool = False) -> \ use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False,
async_finish: bool = False, return_recv_hook: bool = False) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
""" """
A low-latency implementation for dispatching with IBGDA. A low-latency implementation for dispatching with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled). (specifically, IBGDA must be enabled).
Even for ranks in the same node, NVLink are fully disabled for simplicity. Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2 low-latency kernels' result tensors at a single moment.
low-latency kernels' result tensor at a single moment.
Arguments: Arguments:
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
supported. The number of tokens to be dispatched must be less than `num_max_dispatch_tokens_per_rank`. supported. The number of tokens to be dispatched must be less than `num_max_dispatch_tokens_per_rank`.
topk_idx: `torch.Tensor` with `torch.int64`, shaped as `[num_tokens, num_topk]`, only several top-k shapes topk_idx: `torch.Tensor` with `deep_ep.topk_idx_t` (typically `torch.int64`), shaped as `[num_tokens, num_topk]`,
are supported. `-1` indices (not selecting any expert) are supported. only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value. num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts. num_experts: the number of all experts.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors. use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether to round the scaling factors to powers of 2.
use_ue8m0: whether to use UE8M0 as the scaling factor format (available only with `round_scale=True`).
async_finish: the current stream will not wait for the communication kernels to be finished if set. async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival. If you do not set this flag, the kernel will ensure the data's arrival.
Returns: Returns:
recv_x: a tensor or tuple with received tokens for each expert. recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`. `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`. `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`,
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility. Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
With `use_fp8=False`, the result would be a tensor shaped as With `use_fp8=False`, the result would be a tensor shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`. `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are, Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also incompatible with CUDA graph if synced). as we do not synchronize CPU received count with GPU (also incompatible with CUDA graph if synced).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
expert receive. As mentioned before, all not tokens are valid in `recv_x`. expert receives. As mentioned before, not all tokens are valid in `recv_x`.
handle: the communication handle to be used in the `low_latency_combine` function. handle: the communication handle to be used in the `low_latency_combine` function.
event: the event after executing the kernel (valid only if `async_finish` is set). event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set). hook: the receiving hook function (valid only if `return_recv_hook` is set).
...@@ -847,7 +851,8 @@ class Buffer: ...@@ -847,7 +851,8 @@ class Buffer:
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \ packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
self.runtime.low_latency_dispatch(x, topk_idx, self.runtime.low_latency_dispatch(x, topk_idx,
num_max_dispatch_tokens_per_rank, num_experts, num_max_dispatch_tokens_per_rank, num_experts,
use_fp8, async_finish, return_recv_hook) use_fp8, round_scale, use_ue8m0,
async_finish, return_recv_hook)
handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts) handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
tensors_to_record = (x, topk_idx, tensors_to_record = (x, topk_idx,
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_x, packed_recv_x_scales, packed_recv_count,
......
import argparse
import random
import torch
import torch.distributed as dist
from functools import partial
from typing import Literal, Set
import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back
def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "clean"], expected_masked_ranks: Set[int]):
    """Simulate a rank dying the first time it calls a given communication API.

    Arguments:
        rank: the rank of the calling process.
        api: the communication API about to be called.
        expected_masked_ranks: set of ranks expected to be masked so far; the rank
            designated to fail on `api` is added to it in place.

    Returns:
        `True` if the caller must skip the API call (it has already failed, or it
        fails right now), `False` otherwise.
    """
    # A rank that is already recorded as failed keeps skipping
    if rank in expected_masked_ranks:
        return True

    # API -> rank to fail (rank fails when it first calls the corresponding communication API)
    doomed_rank_by_api = {'dispatch': 1, 'combine': 3, 'clean': 5}
    doomed_rank = doomed_rank_by_api.get(api)
    if doomed_rank is None:
        # No failure is scripted for this API
        return False

    # Every caller records the failure, not only the failing rank itself
    expected_masked_ranks.add(doomed_rank)
    if doomed_rank != rank:
        return False

    print(f"Rank {rank} failed when first calling {api} communication API, exit...", flush=True)
    return True
def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], buffer: deep_ep.Buffer, mask_status: torch.Tensor,
                                expected_masked_ranks: Set[int]):
    """Query the buffer's mask state and assert it matches the expected failed ranks.

    Arguments:
        api: the API that was just exercised (not used inside this function).
        buffer: the buffer whose `low_latency_query_mask_buffer` is called.
        mask_status: tensor passed to the query; its non-zero positions are read
            back as the masked ranks.
        expected_masked_ranks: the set of ranks that must be reported as masked.
    """
    buffer.low_latency_query_mask_buffer(mask_status)
    # Indices of the non-zero entries are the ranks currently reported as masked
    masked_indices = mask_status.nonzero().squeeze(-1).tolist()
    reported_masked_ranks = set(masked_indices)
    assert reported_masked_ranks == expected_masked_ranks
def test_main(num_tokens: int,
              hidden: int,
              num_experts: int,
              num_topk: int,
              rank: int,
              num_ranks: int,
              group: dist.ProcessGroup,
              buffer: deep_ep.Buffer,
              seed: int = 0):
    """
    Run the low-latency dispatch/combine correctness checks, then benchmark both kernels.

    Arguments:
        num_tokens: number of tokens generated on this rank (also passed as the maximum dispatch token count).
        hidden: hidden size of each token.
        num_experts: total number of experts; must be divisible by `num_ranks`.
        num_topk: number of experts selected per token.
        rank: rank of the calling process.
        num_ranks: total number of ranks.
        group: process group used for the all-gather and barriers.
        buffer: the `deep_ep.Buffer` under test.
        seed: base RNG seed; `seed + rank` seeds both `torch` and `random`.

    Returns:
        hash_value: XOR of `hash_tensor` over the valid received slices and the combined outputs,
        used by the caller to compare repeated runs.
    """
    torch.manual_seed(seed + rank)
    random.seed(seed + rank)
    assert num_experts % num_ranks == 0
    num_local_experts = num_experts // num_ranks

    # NOTES: the integers greater than 256 exceed the BF16 precision limit
    rank_offset = 128
    assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'

    # Every element carries `rank - rank_offset`, except the last 128 columns which carry the token index
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
    x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
    x_list = [x]
    # # NOTES: the last one is for performance testing
    # # Most of the values in the perf case is lower than the threshold, casting most channels
    # x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)

    # Random routing: top-k expert indices from random scores, plus random non-negative weights
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()

    # Randomly mask some positions
    for _ in range(10):
        topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1

    # Gather every rank's routing table so the expected per-expert counts can be computed locally
    all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
    dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)

    # For failure simulation and shrink testing
    # NOTE(review): within this function `mask_status` stays all-zero and `expected_masked_ranks` is never read
    mask_status = torch.zeros((num_ranks,), dtype=torch.int, device='cuda')
    expected_masked_ranks = set()

    # Check dispatch correctness
    do_check = True
    hash_value, num_times = 0, 0
    for current_x in x_list:
        for return_recv_hook in (False, True):
            for dispatch_use_fp8 in (False, True):
                # `round_scale` is only varied with FP8 enabled, `use_ue8m0` only with `round_scale`
                for round_scale in (False, True) if dispatch_use_fp8 else (False,):
                    for use_ue8m0 in (False, True) if round_scale else (False,):
                        num_times += 1
                        # Dispatch once or twice depending on the parity of `num_times`; only the last result is checked
                        for _ in range((num_times % 2) + 1):
                            packed_recv_x, packed_recv_count, handle, event, hook = \
                                buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
                                                            use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0,
                                                            async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
                            hook() if return_recv_hook else event.current_stream_wait()
                        packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
                        # Input for the combine step below
                        # NOTE(review): `per_token_cast_back` comes from `utils`; presumably it dequantizes FP8 data with its scales - confirm
                        simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
                            if dispatch_use_fp8 else packed_recv_x.clone()
                        for i in range(num_local_experts if do_check else 0):
                            expert_id = rank * num_local_experts + i
                            recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
                            recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]

                            # Check expert indices
                            # The low 32 bits of each layout-range entry are summed as token counts
                            int_mask = (2 ** 32) - 1
                            num_valid_tokens = recv_count.item()
                            assert num_valid_tokens == (
                                recv_layout_range
                                & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
                            assert num_valid_tokens == (all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item(
                            ), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item()}'
                            if num_valid_tokens == 0:
                                continue

                            # Check received data
                            if current_x is x:
                                recv_x = recv_x[:num_valid_tokens]
                                recv_x_amin = recv_x[:, :-128].amin(dim=-1)
                                recv_src_info = recv_src_info[:num_valid_tokens]
                                # All columns but the last 128 must hold one single value per token
                                assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
                                # The trailing columns must match the source info: approximately with rounded scales, exactly otherwise
                                if round_scale:
                                    assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
                                else:
                                    assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
                                for j in range(num_ranks):
                                    # High 32 bits are unpacked as the begin index, low 32 bits as the count for source rank `j`
                                    begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
                                    if not round_scale:
                                        assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
                                        assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
                            # Fold the valid part of the received tensors into the running hash
                            if dispatch_use_fp8:
                                hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
                                hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
                            else:
                                hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])

                        # Check combine correctness
                        for zero_copy in (False, True):
                            if zero_copy:
                                buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
                            out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
                            # NOTE(review): `zero_copy` is not forwarded to the combine call (the argument is commented out),
                            # so both iterations issue the same combine call
                            combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
                                                                                 topk_idx,
                                                                                 topk_weights,
                                                                                 handle,
                                                                                 async_finish=not return_recv_hook,
                                                                                 # zero_copy=zero_copy,
                                                                                 return_recv_hook=return_recv_hook,
                                                                                 out=out)
                            hook() if return_recv_hook else event.current_stream_wait()
                            if do_check:
                                # Reference: each input token scaled by the sum of its non-masked top-k weights
                                diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
                                assert torch.isnan(combined_x).sum().item() == 0
                                # if not round_scale:
                                assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: diff={diff}, dispatch_use_fp8={dispatch_use_fp8}, zero_copy={zero_copy}'
                                hash_value ^= hash_tensor(combined_x)

    # noinspection PyShadowingNames
    def large_gemm_with_hook(hook):
        # Run a large matmul before invoking the receive hook
        mat_0 = torch.randn((8192, 8192), dtype=torch.float)
        mat_1 = torch.randn((8192, 8192), dtype=torch.float)
        mat_0 @ mat_1
        hook()

    # noinspection PyShadowingNames
    def test_func(return_recv_hook: bool):
        # One dispatch + combine round trip for benchmarking.
        # `current_x` and `simulated_gemm_x` are captured from the enclosing scope,
        # i.e. the values left by the last iteration of the correctness loops above.
        recv_x, recv_count, handle, event, hook = \
            buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
                                        use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)
        large_gemm_with_hook(hook) if return_recv_hook else None
        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
                                                             topk_idx,
                                                             topk_weights,
                                                             handle,
                                                             return_recv_hook=return_recv_hook)
        large_gemm_with_hook(hook) if return_recv_hook else None

    # Calculate bandwidth
    # NOTE(review): the FP8 size looks like payload + one 4-byte scale per 128 channels + 16 extra bytes - confirm
    num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
    # NOTE(review): `num_logfmt10_bytes` is not referenced again in this function
    num_logfmt10_bytes = hidden * 10 / 8 + hidden / 128 * 4
    num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
    for i in range(num_tokens):
        # Masked (`-1`) selections are not counted as traffic
        num_selections = (topk_idx[i] != -1).sum().item()
        num_dispatch_comm_bytes += num_fp8_bytes * num_selections
        num_combine_comm_bytes += num_bf16_bytes * num_selections

    # Dispatch + combine testing
    avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
    print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
          f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us',
          flush=True)

    # Separate profiling
    for return_recv_hook in (False, True):
        group.barrier()
        dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
                                             kernel_names=('dispatch', 'combine'),
                                             barrier_comm_profiling=True,
                                             suppress_kineto_output=True,
                                             num_kernels_per_period=2 if return_recv_hook else 1)
        if not return_recv_hook:
            print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
                  f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us',
                  flush=True)
        else:
            # With the receive hook, each timing is indexed as a (send, recv) pair
            print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
                  f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us',
                  flush=True)
    return hash_value
# noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    """Per-process entry point: set up the buffer, run the tests, then tear down.

    Arguments:
        local_rank: index of this process among the spawned processes.
        num_local_ranks: number of spawned processes.
        args: parsed command-line options; reads `num_tokens`, `hidden`, `num_topk`,
            `num_experts`, `disable_nvlink`, `allow_mnnvl` and `pressure_test`.
    """
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    num_tokens, hidden = args.num_tokens, args.hidden
    num_topk, num_experts = args.num_topk, args.num_experts

    # Size the RDMA buffer from the library's own hint
    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
    if local_rank == 0:
        print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
    buffer_options = dict(num_rdma_bytes=num_rdma_bytes,
                          low_latency_mode=True,
                          num_qps_per_rank=num_experts // num_ranks,
                          allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
                          explicitly_destroy=True,
                          allow_mnnvl=args.allow_mnnvl)
    buffer = deep_ep.Buffer(group, **buffer_options)

    # Every invocation below differs only in its seed
    run_with_seed = partial(test_main, num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer)
    run_with_seed(seed=1)

    # Pressure test: for each seed, 20 repeated runs must reproduce the reference hash
    num_pressure_seeds = int(1e9) if args.pressure_test else 0
    for seed in range(num_pressure_seeds):
        if local_rank == 0:
            print(f'Testing with seed {seed} ...', flush=True)
        ref_hash = run_with_seed(seed=seed)
        for _ in range(20):
            assert run_with_seed(seed=seed) == ref_hash, f'Error: seed={seed}'

    # Destroy the buffer runtime and communication group
    buffer.destroy()
    dist.barrier()
    dist.destroy_process_group()
if __name__ == '__main__':
    # TODO: you may modify NUMA binding for less CPU overhead
    # TODO: buggy with `num_tokens=512`
    parser = argparse.ArgumentParser(description='Test low-latency EP kernels')

    # Integer options: (flag, default, help)
    int_options = (
        ('--num-processes', 8, 'Number of processes to spawn (default: 8)'),
        ('--num-tokens', 128, 'Number of tokens (default: 128)'),
        ('--hidden', 7168, 'Hidden dimension size (default: 7168)'),
        ('--num-topk', 8, 'Number of top-k experts (default: 8)'),
        ('--num-experts', 288, 'Number of experts (default: 288)'),
    )
    for flag, default_value, help_text in int_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)

    # Boolean switches: (flag, help)
    # NOTE(review): `--shrink-test` and `--use-logfmt` are parsed, but `test_loop` as visible here never reads them
    switch_options = (
        ('--allow-mnnvl', 'Allow MNNVL for communication'),
        ('--disable-nvlink', 'Whether to disable NVLink for testing'),
        ('--pressure-test', 'Whether to do pressure test'),
        ('--shrink-test', 'Whether to simulate failure and test shrink mode'),
        ('--use-logfmt', 'Whether to test LogFMT combine'),
    )
    for flag, help_text in switch_options:
        parser.add_argument(flag, action='store_true', help=help_text)

    args = parser.parse_args()

    # One worker process per requested rank; each runs `test_loop(local_rank, num_processes, args)`
    num_processes = args.num_processes
    torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment