Commit d0fcf024 authored by lishen's avatar lishen
Browse files

Merge branch 'quant_master' into 'main'

And quant.

See merge request dcutoolkit/deeplearing/DeepEP!19
parents 81e56124 ace6e18e
......@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) {
const int num_scales = hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL;
const int num_scales = hidden / QUANTIZATION_GROUPSIZE;
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
......
......@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8,
int quant_type, int quant_group_size, bool fp8_round_scale,
bool async, bool return_recv_hook) {
EP_HOST_ASSERT(low_latency_mode);
......@@ -1327,8 +1327,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
stream_wait(launch_stream, compute_stream);
// Allocate packed tensors
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
x.options().dtype(use_int8 ? torch::kInt8 : use_fp8 ? torch::kFloat8_e4m3fnuz: torch::kBFloat16));
auto packed_recv_x_dtype = torch::kBFloat16;
switch (quant_type) {
case 1: packed_recv_x_dtype = torch::kInt8; break;
case 2: packed_recv_x_dtype = torch::kFloat8_e4m3fn; break;
case 3: packed_recv_x_dtype = torch::kFloat8_e4m3fn; break;
case 4: packed_recv_x_dtype = torch::kFloat8_e5m2; break;
}
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(packed_recv_x_dtype));
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
......@@ -1336,21 +1343,28 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate column-majored scales
auto packed_recv_x_scales = std::optional<torch::Tensor>();
void* packed_recv_x_scales_ptr = nullptr;
if (use_fp8) {
if (quant_type > 0) {
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
// TODO: support unaligned cases
EP_HOST_ASSERT(hidden % (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4) == 0);
EP_HOST_ASSERT(!(use_ue8m0 && use_int8));
EP_HOST_ASSERT(hidden % (QUANTIZATION_GROUPSIZE * 4) == 0);
// 计算scale_col的大小
int scales_col_size = 1; // 默认为per-channel
if (quant_group_size > 0) {
if (quant_type == 3) { // FP8_UE8M0比较特殊
scales_col_size = hidden / (QUANTIZATION_GROUPSIZE * 4);
} else {
scales_col_size = hidden / QUANTIZATION_GROUPSIZE;
}
}
if (use_ue8m0) {
EP_HOST_ASSERT(round_scale);
packed_recv_x_scales = torch::empty({num_local_experts, hidden / (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4), num_ranks * num_max_dispatch_tokens_per_rank},
// 设置packed_recv_x_scales
if (quant_type == 3) { // FP8_UE8M0比较特殊,需要单独处理
EP_HOST_ASSERT(fp8_round_scale && quant_group_size == 128);
packed_recv_x_scales = torch::empty({num_local_experts, scales_col_size, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kInt).device(torch::kCUDA));
} else if (use_int8) {
packed_recv_x_scales = torch::empty({num_local_experts, 1, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
} else {
packed_recv_x_scales = torch::empty({num_local_experts, hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL, num_ranks * num_max_dispatch_tokens_per_rank},
packed_recv_x_scales = torch::empty({num_local_experts, scales_col_size, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
}
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
......@@ -1370,7 +1384,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
use_fp8, round_scale, use_ue8m0, use_int8,
quant_type, quant_group_size, fp8_round_scale,
workspace, num_device_sms, launch_stream, phases);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
......
......@@ -177,7 +177,7 @@ public:
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8,
int quant_type, int quant_group_size, bool fp8_round_scale,
bool async, bool return_recv_hook);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
......
......@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8,
int quant_type, int group_size, bool fp8_round_scale,
void* workspace, int num_device_sms, hipStream_t stream, int phases);
void combine(void* combined_x,
......
......@@ -23,7 +23,7 @@
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define FP8_QUANTIZATION_NUM_PER_CHANNEL 128
#define QUANTIZATION_GROUPSIZE 128
#define DEFAULT_NUM_CU 20
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
......
......@@ -116,20 +116,74 @@ internode_ll_long_atomic_add(long* dest, const long &value,
#endif // defined(FORCE_DUSHMEM_API)
}
template <bool kUseFP8, bool kUseUE8M0, bool kUseInt8, int kHidden>
/**
* @brief 将 K 个浮点数(BF16/FP32)量化并打包成 INT2(64位)存储
*
* @tparam kQuantType 量化类型 (1: Int8, 2/3: FP8_E4M3/UE8M0, 4: FP8_E5M2)
* @tparam kNumElemsPerRead 每次读取的元素数量 (通常为 2, 4, 8)
* @tparam SrcT 源数据类型 (float 或 __hip_bfloat16)
* @tparam DstT 目标数据类型 (int2 或 int4)
* @param src_values 源数据数组 (长度 >= kNumElemsPerRead)
* @param scale 缩放因子 (将 FP32 值映射到量化范围)
* @param[out] dst_vec 输出的 64 位向量 (int2 或 int4)
*/
template <int kQuantType, int kNumElemsPerRead, typename SrcT, typename DstT>
__forceinline__ __device__ void pack_quantized_values(
const SrcT* src_values, float scale, DstT& dst_vec) {
if constexpr (kQuantType == 1) {
// INT8 量化
auto int8_ptr = reinterpret_cast<int8_t*>(&dst_vec);
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; ++j) {
// 如果源是 bfloat16,先提升为 float
float fp32_value_scaled = static_cast<float>(src_values[j]) * scale;
// 使用 nearbyintf 进行四舍五入
int8_ptr[j] = static_cast<int8_t>(nearbyintf(fp32_value_scaled));
}
} else {
// FP8 量化 (E4M3, UE8M0, E5M2)
// 假设 dst_vec 能容纳 kNumElemsPerRead/2 个 fp8x2 元素
auto fp8x2_ptr = reinterpret_cast<__hip_fp8x2_storage_t*>(&dst_vec);
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; j += 2) {
// 处理两个元素
float2 fp32x2 = {static_cast<float>(src_values[j]) * scale, static_cast<float>(src_values[j + 1]) * scale};
if constexpr (kQuantType == 4) {
// FP8 E5M2
fp8x2_ptr[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E5M2);
} else {
// FP8 E4M3 或 UE8M0
fp8x2_ptr[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3);
}
}
}
}
template <int kHidden, int kQuantType=0, int kQuantGroupSize=0, int kMaxNumWarps=16>
__global__ __launch_bounds__(16 * kWarpSize, 1) void
dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* global_atomic_counter,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group,
bool round_scale, int phases) {
dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* global_atomic_counter,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group,
bool fp8_round_scale, int phases) {
// 定义量化类型的枚举
enum class QuantType {
None = 0, // 不进行量化
Int8 = 1, // 采用 Int8 量化
FP8_E4M3 = 2, // 采用 FP8 量化 __HIP_E4M3
FP8_UE8M0 = 3, // 采用 FP8 量化 DeepseekV3.1的 UE8M0
FP8_E5M2 = 4 // 采用 FP8 量化 __HIP_E5M2
};
const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
......@@ -141,20 +195,22 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
// May extract UE8M0 from the scales
constexpr bool kUseQuant8Bit = kQuantType > 0;
constexpr bool kUseUE8M0 = kQuantType == 3; // QuantType::FP8_UE8M0
using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
// FP8 staffs
constexpr int kNumPerChannels = FP8_QUANTIZATION_NUM_PER_CHANNEL;
constexpr int kNumPerChannels = QUANTIZATION_GROUPSIZE;
constexpr int kNumScales = kHidden / kNumPerChannels;
const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16));
const size_t hidden_bytes = kHidden * (kUseQuant8Bit ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16));
const size_t hidden_int4 = hidden_bytes / sizeof(int4);
// Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use
using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
constexpr size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + kNumScales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
using vec_t = typename std::conditional<kUseQuant8Bit, int2, int4>::type;
constexpr size_t num_bytes_per_msg = sizeof(int4) + (kUseQuant8Bit ? (kHidden + kNumScales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
EP_STATIC_ASSERT(num_bytes_per_msg % sizeof(int4) == 0, "Invalid message size");
constexpr size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
......@@ -171,6 +227,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// 2. The last warp for reading `topk_idx` and count for per-expert information
if (warp_id < num_warps) {
constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(hip_bfloat16);
constexpr int kNumThreadPerGroup = QUANTIZATION_GROUPSIZE / kNumElemsPerRead;
// EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization");
const auto num_threads = (num_warps - 1) * kWarpSize;
......@@ -186,10 +243,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
__shared__ float int8_amaxf[kNumScales];
if constexpr(kUseInt8) {
// 用于记录per-channel量化的amax
__shared__ float channel_amaxf[kNumScales];
if constexpr(kUseQuant8Bit && kQuantGroupSize == 0) {
if (thread_id < kNumScales) {
int8_amaxf[thread_id] = kFP8Margin;
channel_amaxf[thread_id] = 0.0;
}
__syncthreads();
}
......@@ -200,11 +258,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Read
auto int4_value = __ldg(x_int4 + i);
if constexpr(kUseFP8) {
if constexpr(kUseQuant8Bit) {
// Calculate local amax
auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
float fp32_values[kNumElemsPerRead];
float amax = kFP8Margin, scale, scale_inv;
float amax = 0.0, scale, scale_inv;
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; ++ j) {
fp32_values[j] = static_cast<float>(bf16_values[j]);
......@@ -212,25 +270,20 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
// Reduce amax and scale
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization");
amax = warp_reduce_max<16>(amax);
const int scale_offset = i * kNumElemsPerRead / FP8_QUANTIZATION_NUM_PER_CHANNEL;
amax = warp_reduce_max<kNumThreadPerGroup>(amax);
const int scale_offset = i * kNumElemsPerRead / QUANTIZATION_GROUPSIZE;
if constexpr(kUseInt8) {
if constexpr(kQuantGroupSize == 0) {
// 记录每128个数的最大值
int8_amaxf[scale_offset] = fmaxf(amax, int8_amaxf[scale_offset]);
channel_amaxf[scale_offset] = fmaxf(amax, channel_amaxf[scale_offset]);
} else {
calculate_fp8_scales(amax, scale, scale_inv, round_scale);
if (lane_id % 16 == 0)
calculate_quant8bit_scales<kQuantType>(amax, scale, scale_inv, fp8_round_scale);
if (lane_id % kNumThreadPerGroup == 0)
rdma_x_scales[scale_offset] = scale_inv;
// Cast into send buffer
vec_t int2_value;
auto fp8x2_values = reinterpret_cast<__hip_fp8x2_storage_t*>(&int2_value);
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; j += 2) {
float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
}
pack_quantized_values<kQuantType, kNumElemsPerRead>(fp32_values, scale, int2_value);
rdma_x_vec[i] = int2_value;
}
} else {
......@@ -240,24 +293,24 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
__syncthreads();
if constexpr(kUseInt8) {
float amax_per_token = kFP8Margin;
if constexpr(kUseQuant8Bit && kQuantGroupSize == 0) {
float amax_per_token = 0.0;
// 并行规约,计算每个token的amax
for (int s = 0; s < kNumScales; s+=kWarpSize) {
int src_idx = s + lane_id;
float tmp_amaxf = 0;
if(src_idx < kNumScales) {
tmp_amaxf = int8_amaxf[src_idx];
tmp_amaxf = channel_amaxf[src_idx];
}
tmp_amaxf = warp_reduce_max<kWarpSize>(tmp_amaxf);
int8_amaxf[0] = fmaxf(tmp_amaxf, int8_amaxf[0]);
channel_amaxf[0] = fmaxf(tmp_amaxf, channel_amaxf[0]);
__syncthreads();
}
amax_per_token = int8_amaxf[0];
amax_per_token = channel_amaxf[0];
// 根据最大值计算scale
float scale, scale_inv;
calculate_int8_scales(amax_per_token, scale, scale_inv);
calculate_quant8bit_scales<kQuantType>(amax_per_token, scale, scale_inv, fp8_round_scale);
if (thread_id == 0) {
rdma_x_scales[0] = scale_inv;
}
......@@ -269,13 +322,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Cast into send buffer
vec_t int2_value;
auto int8_values = reinterpret_cast<int8_t*>(&int2_value);
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; ++ j) {
auto fp32_value = static_cast<float>(bf16_values[j]);
auto fp32_value_scaled = fp32_value * scale;
int8_values[j] = static_cast<int8_t>(nearbyintf(fp32_value_scaled));
}
pack_quantized_values<kQuantType, kNumElemsPerRead>(bf16_values, scale, int2_value);
rdma_x_vec[i] = int2_value;
}
__syncthreads();
......@@ -297,8 +344,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
if (p2p_ptr == 0) { // RDMA
internode_ll_putmem_nbi((void*)dst_ptr, (void*)src_ptr,
num_ranks, dst_rank, dst_expert_local_idx,
num_bytes_per_msg);
num_ranks, dst_rank, dst_expert_local_idx,
num_bytes_per_msg);
} else { // 本地 GPU 和 同一计算节点的 其他 GPU 地址
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
......@@ -392,11 +439,10 @@ LOW_LATENCY_DISPATCH_RECV:
}
// 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize;
constexpr int num_sync_large_iteration = kMaxNumWarps ;
__shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];
#pragma unroll
#pragma unroll
for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
sync_large_warp_counters[i] = 0;
}
......@@ -416,7 +462,7 @@ LOW_LATENCY_DISPATCH_RECV:
const auto num_aligned_scales = ALIGN<int>(kNumScales, sizeof(float) / sizeof(scale_t));
const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank *
(kUseInt8 ? 1 : num_aligned_scales);
(kQuantGroupSize == 0 ? 1 : num_aligned_scales);
// Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
......@@ -461,14 +507,14 @@ LOW_LATENCY_DISPATCH_RECV:
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
// Copy scales
if constexpr(kUseFP8) {
if constexpr(kUseQuant8Bit) {
const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
const auto token_idx = recv_token_begin_idx + i;
const auto token_stride = num_elems_per_pack;
const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
if constexpr(kUseInt8) {
if constexpr(kQuantGroupSize == 0) {
if (lane_id == 0) {
recv_x_scales[token_idx] = ld_nc_global(src_scales);
}
......@@ -500,12 +546,13 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8,
int quant_type, int quant_group_size, bool fp8_round_scale,
void* workspace, int num_device_sms,
hipStream_t stream, int phases) {
constexpr int kMaxNumWarps = 16;
constexpr int kNumMaxTopK = 11;
const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const int num_warps_per_group = 16 / num_warp_groups;
const int num_warps_per_group = kMaxNumWarps / num_warp_groups;
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);
......@@ -518,33 +565,54 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
#define DISPATCH_LAUNCH_CASE(hidden) { \
auto dispatch_func = dispatch<false, false, false, hidden>; \
if (use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, false, hidden>; \
if (use_fp8 and use_ue8m0) \
dispatch_func = dispatch<true, true, false, hidden>; \
if (use_int8) \
dispatch_func = dispatch<true, false, true, hidden>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_count, \
global_atomic_counter, \
rdma_recv_x, rdma_recv_count, rdma_x, \
x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, round_scale, phases); } break
// 限制groupsize的大小
EP_HOST_ASSERT(quant_group_size == 0 || quant_group_size == 128);
/*量化类型枚举
0 -> None 不量化,保持原始精度
1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2)
*/
#define DISPATCH_LAUNCH_CASE(hidden) \
{ \
auto dispatch_func = dispatch<hidden, 0, 0, kMaxNumWarps>; \
if (quant_group_size == 0) { \
switch (quant_type) { \
case 1: dispatch_func = dispatch<hidden, 1, 0, kMaxNumWarps>; break; \
case 2: dispatch_func = dispatch<hidden, 2, 0, kMaxNumWarps>; break; \
case 3: dispatch_func = dispatch<hidden, 3, 0, kMaxNumWarps>; break; \
case 4: dispatch_func = dispatch<hidden, 4, 0, kMaxNumWarps>; break; \
} \
} else { \
switch (quant_type) { \
case 1: dispatch_func = dispatch<hidden, 1, 128, kMaxNumWarps>; break; \
case 2: dispatch_func = dispatch<hidden, 2, 128, kMaxNumWarps>; break; \
case 3: dispatch_func = dispatch<hidden, 3, 128, kMaxNumWarps>; break; \
case 4: dispatch_func = dispatch<hidden, 4, 128, kMaxNumWarps>; break; \
} \
} \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, packed_recv_count, \
global_atomic_counter, \
rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, fp8_round_scale, phases); \
} \
break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}
template <int kHidden, int kNumMaxTopk>
template <int kHidden, int kNumMaxTopk, int kMaxNumWarps=16>
__global__ __launch_bounds__(16 * kWarpSize, 1) void
combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
......@@ -574,12 +642,11 @@ combine(void* combined_x,
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// Message package
EP_STATIC_ASSERT(kHidden % FP8_QUANTIZATION_NUM_PER_CHANNEL == 0, "Invalid hidden");
EP_STATIC_ASSERT(kHidden % QUANTIZATION_GROUPSIZE == 0, "Invalid hidden");
constexpr size_t num_bytes_per_slot = kHidden * sizeof(hip_bfloat16);
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize;
// 初始化用于细粒度warp间同步的计数器数组
__shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
if (threadIdx.x==0){
#pragma unroll
......@@ -755,9 +822,10 @@ void combine(void* combined_x,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy) {
constexpr int kMaxNumWarps = 16;
constexpr int kNumMaxTopk = 11;
const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const int num_warps_per_group = 16 / num_warp_groups; // num_warps_per_group>1, "Requires more than one warp per group"
const int num_warps_per_group = kMaxNumWarps / num_warp_groups; // num_warps_per_group>1, "Requires more than one warp per group"
const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);
......@@ -770,20 +838,20 @@ void combine(void* combined_x,
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
#define COMBINE_LAUNCH_CASE(hidden) { \
auto combine_func = combine<hidden, kNumMaxTopk>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
global_atomic_counter, \
combine_wait_recv_cost_stats, \
next_clean, num_next_clean_int, \
atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, phases, zero_copy); } break
#define COMBINE_LAUNCH_CASE(hidden) \
{ \
auto combine_func = combine<hidden, kNumMaxTopk, kMaxNumWarps>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
global_atomic_counter, combine_wait_recv_cost_stats, \
next_clean, num_next_clean_int, \
atomic_clean_flag, num_combined_tokens, hidden, \
num_topk, num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, phases, zero_copy); \
} \
break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
......
......@@ -62,7 +62,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case 8: \
case_macro(8); \
default: \
EP_HOST_ASSERT(false and "Unsupported ranks"); \
EP_HOST_ASSERT(false and "Unsupported ranks"); \
} \
while (false)
......@@ -83,7 +83,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case 20: \
case_macro(20); \
default: \
EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
} \
while (false)
......@@ -96,7 +96,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case 8: \
case_macro(dtype, 8); \
default: \
EP_HOST_ASSERT(false and "Unsupported ranks"); \
EP_HOST_ASSERT(false and "Unsupported ranks"); \
} \
while (false)
......@@ -107,7 +107,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case HIP_R_32F: \
case_macro(float); \
default: \
EP_HOST_ASSERT(false and "Unsupported type"); \
EP_HOST_ASSERT(false and "Unsupported type"); \
} \
while (false)
......@@ -121,7 +121,9 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case_macro(4096); \
case 7168: \
case_macro(7168); \
case 8192: \
case_macro(8192); \
default: \
EP_HOST_ASSERT(false and "Unsupported hidden"); \
EP_HOST_ASSERT(false and "Unsupported hidden"); \
} \
while (false)
......@@ -341,10 +341,13 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return *reinterpret_cast<dtype_t *>(recv_int_values);
}
constexpr float kFP8Margin = 1e-4;
constexpr float kFinfoAmaxE4M3 = 240.0f;
// 设置不同的量化方式的最大值与相反数
constexpr float kFinfoAmaxE4M3 = 448.0f;
constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
constexpr float kInt8Amax = 127.0f;
constexpr float kFinfoAmaxE5M2 = 57344.0f;
constexpr float kFinfoAmaxInvE5M2 = 1.0f / kFinfoAmaxE5M2;
constexpr float kFinfoAmaxInt8 = 127.0f;
constexpr float kFinfoAmaxInvInt8 = 1.0f / 127.0f;
__forceinline__ __device__ float fast_pow2(int x) {
// We can ensure `-126 <= x and x <= 127`
......@@ -359,22 +362,33 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
return exp_x - 127 + (man_bits != 0);
}
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) {
if (round_scale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * kFinfoAmaxInvE4M3;
scale = kFinfoAmaxE4M3 / amax;
template <int kQuantType>
__forceinline__ __device__ void calculate_quant8bit_scales(float amax, float& scale, float& scale_inv, bool round_scale=0) {
amax = fmaxf(amax, 1e-6f);
if constexpr(kQuantType == 1) { // 使用 INT8 对称量化
scale_inv = kFinfoAmaxInvInt8 * amax;
scale = kFinfoAmaxInt8 / amax;
} else if constexpr(kQuantType == 2 || kQuantType == 3) { // 使用 FP8_E4M3 或 FP8_UE8M0 非对称量化
if (round_scale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * kFinfoAmaxInvE4M3;
scale = kFinfoAmaxE4M3 / amax;
}
} else if constexpr(kQuantType == 4) { // 使用 FP8_E5M2 对称量化
if (round_scale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE5M2);
scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * kFinfoAmaxInvE5M2;
scale = kFinfoAmaxE5M2 / amax;
}
}
}
__forceinline__ __device__ void calculate_int8_scales(float amax, float& scale, float& scale_inv) {
scale = kInt8Amax / amax;
scale_inv = amax / kInt8Amax;
}
template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>>
__forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) {
if constexpr (kIsUE8M0) {
......
......@@ -841,7 +841,7 @@ class Buffer:
# noinspection PyTypeChecker
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int, num_experts: int,
use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False, use_int8: bool = False,
quant_type: int = 1, quant_group_size: int = 0, fp8_round_scale: bool = False,
async_finish: bool = False, return_recv_hook: bool = False) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
"""
......@@ -858,10 +858,20 @@ class Buffer:
only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
use_int8: whether to enable INT8 casting.
量化配置
quant_type: int 量化类型枚举
0 -> None 不量化,保持原始精度
1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3_FNUZ)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2_FNUZ)
quant_group_size: int 量化分组大小
0 -> 逐token量化 (per-channel)
128-> 每 128 元素一组 (per-group) 量化
fp8_round_scale: bool 是否将 FP8 缩放因子取整为 2 的幂
true -> 缩放因子 = 2^k,硬件零开销
false -> 缩放因子 = 任意浮点,精度更高
异步配置
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
......@@ -869,15 +879,25 @@ class Buffer:
Returns:
recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`,
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
With `use_fp8=False`, the result would be a tensor shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
- packed_recv_x:
存储接收到的 Token 数据,形状为
`[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`。
数据类型取决于 quant_type:
quant_type == 1 -> torch.int8
quant_type == 2 -> torch.float8_e4m3fnuz
quant_type == 3 -> torch.float8_e4m3fnuz (UE8M0 使用 E4M3 格式存储)
quant_type == 4 -> torch.float8_e5m2fnuz
其他 (非量化) -> torch.bfloat16
- packed_recv_x_scales (可选):
仅在 quant_type > 0 时存在,存储量化的 Scale 值。
形状为 `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, scales_col_size]`。
- 当 quant_type == 3 (UE8M0) 时:
scales_col_size = hidden // 512
数据类型为 torch.int (内部打包存储 4-bit scale)。
*注意:此模式强制要求 fp8_round_scale=True 且 group_size=128。
- 当 quant_type == 1, 2, 4 时:
scales_col_size = hidden // 128 (若使用 group_size) 或 1 (per-channel)。
数据类型为 torch.float32。
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
......@@ -889,14 +909,15 @@ class Buffer:
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
self.runtime.low_latency_dispatch(x, topk_idx,
num_max_dispatch_tokens_per_rank, num_experts,
use_fp8, round_scale, use_ue8m0, use_int8,
quant_type, quant_group_size, fp8_round_scale,
async_finish, return_recv_hook)
handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
tensors_to_record = (x, topk_idx,
packed_recv_x, packed_recv_x_scales, packed_recv_count,
packed_recv_src_info, packed_recv_layout_range)
return (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, packed_recv_count, handle, \
EventOverlap(event, tensors_to_record if async_finish else None), hook
recv_x = (packed_recv_x, packed_recv_x_scales) if (quant_type > 0) else packed_recv_x
return recv_x, packed_recv_count, handle, EventOverlap(event, tensors_to_record if async_finish else None), hook
# noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
......
......@@ -6,7 +6,7 @@ from functools import partial
from typing import Literal, Set
import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_pg_back, per_token_cast_pc_back
def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "clean"], expected_masked_ranks: Set[int]):
......@@ -58,7 +58,8 @@ def test_main(num_tokens: int,
x_list = [x]
# # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
# x_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
# x_list = [x_rand]
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
......@@ -80,22 +81,35 @@ def test_main(num_tokens: int,
hash_value, num_times = 0, 0
for current_x in x_list:
for return_recv_hook in (False, True):
for dispatch_use_fp8 in (False, True):
for round_scale in (False, True) if dispatch_use_fp8 else (False,):
for use_ue8m0 in (False, True) if round_scale else (False,):
for quant_type in (0, 1, 2, 3, ): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant = quant_type > 0
for fp8_round_scale in (False, True) if quant_type != 3 else (True, ):
for quant_group_size in (0, 128,) if quant_type >= 2 else (0, ):
if quant_type == 3 and (fp8_round_scale == False or quant_group_size == 0):
continue
num_times += 1
for _ in range((num_times % 2) + 1):
packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0,
quant_type=quant_type, fp8_round_scale=fp8_round_scale, quant_group_size=quant_group_size,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
if dispatch_use_fp8 else packed_recv_x.clone()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_quant else packed_recv_x
if not dispatch_use_quant:
simulated_gemm_x = packed_recv_x.clone()
elif quant_group_size == 0:
simulated_gemm_x = per_token_cast_pc_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].reshape(-1)).view(packed_recv_x[0].shape)
elif quant_group_size == 128:
simulated_gemm_x = per_token_cast_pg_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)
for i in range(num_local_experts if do_check else 0):
expert_id = rank * num_local_experts + i
recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
if not dispatch_use_quant:
recv_x = packed_recv_x[i]
elif quant_group_size == 0:
recv_x = per_token_cast_pc_back(packed_recv_x[0][i], packed_recv_x[1][i])
elif quant_group_size == 128:
recv_x = per_token_cast_pg_back(packed_recv_x[0][i], packed_recv_x[1][i])
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
# Check expert indices
......@@ -113,18 +127,25 @@ def test_main(num_tokens: int,
if current_x is x:
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
if round_scale:
assert torch.equal(recv_x_amin, recv_x_amax)
if dispatch_use_quant:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
for j in range(num_ranks):
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
if not round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_fp8:
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
if quant_group_size != 0:
if fp8_round_scale:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
for j in range(num_ranks):
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
if not fp8_round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_quant:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
else:
......@@ -147,8 +168,8 @@ def test_main(num_tokens: int,
if do_check:
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0
# if not round_scale:
assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: diff={diff}, dispatch_use_fp8={dispatch_use_fp8}, zero_copy={zero_copy}'
# if not fp8_round_scale:
assert diff < (9e-4 if dispatch_use_quant else 1e-5), f'Error: diff={diff}, dispatch_use_quant={dispatch_use_quant}, zero_copy={zero_copy}'
hash_value ^= hash_tensor(combined_x)
# noinspection PyShadowingNames
......@@ -162,7 +183,8 @@ def test_main(num_tokens: int,
def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)
quant_type=2, quant_group_size=0,
async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
......
import argparse
import random
import os
import torch
import torch.distributed as dist
from functools import partial
from typing import Literal, Set
import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back_int8
def test_main(num_tokens: int,
hidden: int,
num_experts: int,
num_topk: int,
rank: int,
num_ranks: int,
group: dist.ProcessGroup,
buffer: deep_ep.Buffer,
seed: int = 0):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks
# NOTES: the integers greater than 256 exceed the BF16 precision limit
rank_offset = 128
assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
x_list = [x]
# # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
# Randomly mask some positions
for _ in range(10):
topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1
all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
# For failure simulation and shrink testing
mask_status = torch.zeros((num_ranks,), dtype=torch.int, device='cuda')
# Check dispatch correctness
do_check = True
hash_value, num_times = 0, 0
for current_x in x_list:
for return_recv_hook in (False, ):
for dispatch_use_fp8 in (True, ):
for round_scale in (False, ):
for use_ue8m0 in (False, ):
num_times += 1
use_int8 = True
for _ in range(1):
packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0, use_int8=use_int8,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
simulated_gemm_x = per_token_cast_back_int8(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, 1)).view(packed_recv_x[0].shape)
for i in range(num_local_experts if do_check else 0):
expert_id = rank * num_local_experts + i
recv_x = per_token_cast_back_int8(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
# Check expert indices
int_mask = (2 ** 32) - 1
num_valid_tokens = recv_count.item()
assert num_valid_tokens == (
recv_layout_range
& int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
assert num_valid_tokens == (all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item(
), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item()}'
if num_valid_tokens == 0:
continue
# Check received data
if current_x is x:
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x_amax)
if round_scale:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
elif use_int8:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.01
else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
# for j in range(num_ranks):
# if (not round_scale):
# check_tmp1 = (recv_x_amin == j - rank_offset).sum().item()
# check_tmp2 = (all_topk_idx[j] == expert_id).sum().item()
# print(f'rank: {rank}, j: {j}, check_tmp1: {check_tmp1}, check_tmp2: {check_tmp2}, diff: {abs(check_tmp1 - check_tmp2)}')
# assert abs(check_tmp1 - check_tmp2) < 3
# assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_fp8:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
else:
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
print("dispatch int 8 pass")
# noinspection PyShadowingNames
def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float)
mat_1 = torch.randn((8192, 8192), dtype=torch.float)
mat_0 @ mat_1
hook()
# noinspection PyShadowingNames
def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=True, round_scale=False, use_ue8m0=False, use_int8=True,
async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
# Calculate bandwidth
scale_size = 1 # hidden / 128
num_fp8_bytes, num_bf16_bytes = (hidden + scale_size * 4 + 16), hidden * 2
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
for i in range(num_tokens):
num_selections = (topk_idx[i] != -1).sum().item()
num_dispatch_comm_bytes += num_fp8_bytes * num_selections
num_combine_comm_bytes += num_bf16_bytes * num_selections
# Separate profiling
for return_recv_hook in (True, False):
group.barrier()
dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
kernel_names=('dispatch', 'combine'),
barrier_comm_profiling=True,
suppress_kineto_output=True,
num_kernels_per_period=2 if return_recv_hook else 1)
if not return_recv_hook:
print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us',
flush=True)
else:
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us',
flush=True)
return hash_value
# noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
num_tokens, hidden = args.num_tokens, args.hidden
num_topk, num_experts = args.num_topk, args.num_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
if local_rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
explicitly_destroy=True,
allow_mnnvl=args.allow_mnnvl)
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1)
# Destroy the buffer runtime and communication group
buffer.destroy()
dist.barrier()
dist.destroy_process_group()
if __name__ == '__main__':
# TODO: you may modify NUMA binding for less CPU overhead
# TODO: buggy with `num_tokens=512`
parser = argparse.ArgumentParser(description='Test low-latency EP kernels')
parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')
parser.add_argument('--num-tokens', type=int, default=128, help='Number of tokens (default: 128)')
parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)')
parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)')
parser.add_argument('--num-experts', type=int, default=256, help='Number of experts (default: 288)')
parser.add_argument('--allow-mnnvl', action="store_true", help='Allow MNNVL for communication')
parser.add_argument('--disable-nvlink', action='store_true', help='Whether to disable NVLink for testing')
parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test')
parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode')
parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine')
args = parser.parse_args()
num_processes = args.num_processes
torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
......@@ -57,28 +57,22 @@ def per_token_cast_to_fp8(x: torch.Tensor):
return (x_padded_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, aligned_n)[:, :n].contiguous(), (x_amax / 448.0).view(m, -1)
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
if x_fp8.numel() == 0:
return x_fp8.to(torch.bfloat16)
def per_token_cast_pg_back(x: torch.Tensor, x_scales: torch.Tensor):
if x.numel() == 0:
return x.to(torch.bfloat16)
assert x_fp8.dim() == 2
m, n = x_fp8.shape
assert x.dim() == 2
m, n = x.shape
aligned_n = align_up(n, 128)
x_fp8_padded = torch.nn.functional.pad(x_fp8, (0, aligned_n - n), mode='constant', value=0)
x_padded = torch.nn.functional.pad(x, (0, aligned_n - n), mode='constant', value=0)
if x_scales.dtype == torch.int:
x_scales = x_scales.view(dtype=torch.uint8).to(torch.int) << 23
x_scales = x_scales.view(dtype=torch.float)
x_fp32_padded = x_fp8_padded.to(torch.float32).view(x_fp8.size(0), -1, 128)
x_scales = x_scales.view(x_fp8.size(0), -1, 1)
return (x_fp32_padded * x_scales).view(x_fp8_padded.shape).to(torch.bfloat16)[:,:n].contiguous()
x_fp32_padded = x_padded.to(torch.float32).view(x.size(0), -1, 128)
x_scales = x_scales.view(x.size(0), -1, 1)
return (x_fp32_padded * x_scales).view(x_padded.shape).to(torch.bfloat16)[:,:n].contiguous()
def per_token_cast_back_int8(x_int8: torch.Tensor, x_scales: torch.Tensor):
"""
x_int8: [m, n] int8 tensor
x_scales: [m, n] 或 [m, 1] 或 [m, n/128] 量化 scale float
return: [m, n] bf16 tensor
"""
def per_token_cast_pc_back(x_int8: torch.Tensor, x_scales: torch.Tensor):
if x_int8.numel() == 0:
return x_int8.to(torch.bfloat16)
......@@ -86,12 +80,9 @@ def per_token_cast_back_int8(x_int8: torch.Tensor, x_scales: torch.Tensor):
m, n = x_int8.shape
aligned_n = align_up(n, 128)
x_int8_padded = torch.nn.functional.pad(
x_int8, (0, aligned_n - n), mode='constant', value=0
)
x_int8_padded = torch.nn.functional.pad(x_int8, (0, aligned_n - n), mode='constant', value=0)
x_fp32_padded = x_int8_padded.to(torch.float32).view(m, -1, 1)
x_scales = x_scales.view(m, -1, 1).to(torch.float32)
# print(f'x_int8.shape: {x_int8.shape}, x_fp32_padded: {x_fp32_padded.shape}, x_scales: {x_scales.shape}')
x_deq = (x_fp32_padded * x_scales).view(m, aligned_n)
return x_deq[:, :n].to(torch.bfloat16).contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment