Commit 44ec8bed authored by lishen's avatar lishen
Browse files

支持更复杂的量化,包括fp8/int8/ue8m0,且支持per-group/per-channel

parent 81e56124
......@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) {
const int num_scales = hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL;
const int num_scales = hidden / QUANTIZATION_GROUPSIZE;
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
......
......@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8,
int quant_type, int quant_group_size, bool fp8_round_scale,
bool async, bool return_recv_hook) {
EP_HOST_ASSERT(low_latency_mode);
......@@ -1327,8 +1327,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
stream_wait(launch_stream, compute_stream);
// Allocate packed tensors
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
x.options().dtype(use_int8 ? torch::kInt8 : use_fp8 ? torch::kFloat8_e4m3fnuz: torch::kBFloat16));
auto packed_recv_x_dtype = torch::kBFloat16;
switch (quant_type) {
case 1: packed_recv_x_dtype = torch::kInt8; break;
case 2: packed_recv_x_dtype = torch::kFloat8_e4m3fnuz; break;
case 3: packed_recv_x_dtype = torch::kFloat8_e4m3fnuz; break;
case 4: packed_recv_x_dtype = torch::kFloat8_e5m2fnuz; break;
}
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(packed_recv_x_dtype));
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
......@@ -1336,21 +1343,28 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate column-majored scales
auto packed_recv_x_scales = std::optional<torch::Tensor>();
void* packed_recv_x_scales_ptr = nullptr;
if (use_fp8) {
if (quant_type > 0) {
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
// TODO: support unaligned cases
EP_HOST_ASSERT(hidden % (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4) == 0);
EP_HOST_ASSERT(!(use_ue8m0 && use_int8));
EP_HOST_ASSERT(hidden % (QUANTIZATION_GROUPSIZE * 4) == 0);
// 计算scale_col的大小
int scales_col_size = 1; // 默认为per-channel
if (quant_group_size > 0) {
if (quant_type == 3) { // FP8_UE8M0比较特殊
scales_col_size = hidden / (QUANTIZATION_GROUPSIZE * 4);
} else {
scales_col_size = hidden / QUANTIZATION_GROUPSIZE;
}
}
if (use_ue8m0) {
EP_HOST_ASSERT(round_scale);
packed_recv_x_scales = torch::empty({num_local_experts, hidden / (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4), num_ranks * num_max_dispatch_tokens_per_rank},
// 设置packed_recv_x_scales
if (quant_type == 3) { // FP8_UE8M0比较特殊,需要单独处理
EP_HOST_ASSERT(fp8_round_scale && quant_group_size == 128);
packed_recv_x_scales = torch::empty({num_local_experts, scales_col_size, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kInt).device(torch::kCUDA));
} else if (use_int8) {
packed_recv_x_scales = torch::empty({num_local_experts, 1, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
} else {
packed_recv_x_scales = torch::empty({num_local_experts, hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL, num_ranks * num_max_dispatch_tokens_per_rank},
packed_recv_x_scales = torch::empty({num_local_experts, scales_col_size, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
}
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
......@@ -1370,7 +1384,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
use_fp8, round_scale, use_ue8m0, use_int8,
quant_type, quant_group_size, fp8_round_scale,
workspace, num_device_sms, launch_stream, phases);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
......
......@@ -177,7 +177,7 @@ public:
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8,
int quant_type, int quant_group_size, bool fp8_round_scale,
bool async, bool return_recv_hook);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
......
......@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8,
int quant_type, int group_size, bool fp8_round_scale,
void* workspace, int num_device_sms, hipStream_t stream, int phases);
void combine(void* combined_x,
......
......@@ -23,7 +23,7 @@
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define FP8_QUANTIZATION_NUM_PER_CHANNEL 128
#define QUANTIZATION_GROUPSIZE 128
#define DEFAULT_NUM_CU 20
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
......
......@@ -116,20 +116,74 @@ internode_ll_long_atomic_add(long* dest, const long &value,
#endif // defined(FORCE_DUSHMEM_API)
}
template <bool kUseFP8, bool kUseUE8M0, bool kUseInt8, int kHidden>
/**
 * @brief Quantize kNumElemsPerRead source values to 8 bits each and pack them into dst_vec.
 *
 * Every source element is converted to float, multiplied by `scale`, narrowed to a
 * single byte (INT8 or FP8) and stored contiguously starting at the first byte of
 * `dst_vec`. Bytes of `dst_vec` beyond kNumElemsPerRead are left untouched.
 *
 * @tparam kQuantType        Quantization type: 1 selects INT8, 4 selects FP8 E5M2 (FNUZ),
 *                           every other value selects FP8 E4M3 (FNUZ) (used for types 2 and 3).
 * @tparam kNumElemsPerRead  Number of source elements to convert; must be even on the FP8 path
 *                           because elements are converted two at a time.
 * @tparam SrcT              Source element type; must be convertible to float with static_cast.
 * @tparam DstT              Destination vector type; must provide at least kNumElemsPerRead bytes.
 * @param src_values   Pointer to at least kNumElemsPerRead source elements.
 * @param scale        Multiplier applied to every element before narrowing.
 * @param[out] dst_vec Receives the packed 8-bit values.
 */
template <int kQuantType, int kNumElemsPerRead, typename SrcT, typename DstT>
__forceinline__ __device__ void pack_quantized_values(
    const SrcT* src_values, float scale, DstT& dst_vec) {
    static_assert(kNumElemsPerRead > 0, "kNumElemsPerRead must be positive");
    static_assert(static_cast<int>(sizeof(DstT)) >= kNumElemsPerRead,
                  "DstT is too small to hold kNumElemsPerRead 8-bit values");

    if constexpr (kQuantType == 1) {
        // INT8 quantization: one byte per element.
        auto int8_ptr = reinterpret_cast<int8_t*>(&dst_vec);
        #pragma unroll
        for (int j = 0; j < kNumElemsPerRead; ++j) {
            // Promote to float (no-op when SrcT is already float) and apply the scale.
            const float scaled = static_cast<float>(src_values[j]) * scale;
            // Round to an integral value using the current rounding mode.
            const float rounded = nearbyintf(scaled);
            // Saturate before narrowing: converting an out-of-range float to int8_t is
            // undefined behaviour, whereas the FP8 path below already saturates.
            const float saturated = fminf(fmaxf(rounded, -128.0f), 127.0f);
            int8_ptr[j] = static_cast<int8_t>(saturated);
        }
    } else {
        // FP8 quantization: elements are converted in pairs into fp8x2 storage.
        static_assert(kNumElemsPerRead % 2 == 0,
                      "FP8 packing converts two elements at a time");
        auto fp8x2_ptr = reinterpret_cast<__hip_fp8x2_storage_t*>(&dst_vec);
        #pragma unroll
        for (int j = 0; j < kNumElemsPerRead; j += 2) {
            float2 fp32x2 = {static_cast<float>(src_values[j]) * scale,
                             static_cast<float>(src_values[j + 1]) * scale};
            if constexpr (kQuantType == 4) {
                // FP8 E5M2 (FNUZ)
                fp8x2_ptr[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E5M2_FNUZ);
            } else {
                // FP8 E4M3 (FNUZ); also used when the scales are stored as UE8M0
                fp8x2_ptr[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
            }
        }
    }
}
template <int kHidden, int kQuantType=0, int kQuantGroupSize=0, int kMaxNumWarps=16>
__global__ __launch_bounds__(16 * kWarpSize, 1) void
dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* global_atomic_counter,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group,
bool round_scale, int phases) {
dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* global_atomic_counter,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group,
bool fp8_round_scale, int phases) {
// Enumeration of the quantization types; the numeric values mirror the integer
// `kQuantType` template argument of this kernel.
enum class QuantType {
None = 0, // No quantization
Int8 = 1, // INT8 quantization
FP8_E4M3 = 2, // FP8 quantization using __HIP_E4M3_FNUZ
FP8_UE8M0 = 3, // FP8 quantization with the UE8M0 scale format (DeepSeek V3.1)
FP8_E5M2 = 4 // FP8 quantization using __HIP_E5M2_FNUZ
};
const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
......@@ -141,20 +195,22 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
// May extract UE8M0 from the scales
constexpr bool kUseQuant8Bit = kQuantType > 0;
constexpr bool kUseUE8M0 = kQuantType == 3; // QuantType::FP8_UE8M0
using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
// FP8 stuff
constexpr int kNumPerChannels = FP8_QUANTIZATION_NUM_PER_CHANNEL;
constexpr int kNumPerChannels = QUANTIZATION_GROUPSIZE;
constexpr int kNumScales = kHidden / kNumPerChannels;
const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16));
const size_t hidden_bytes = kHidden * (kUseQuant8Bit ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16));
const size_t hidden_int4 = hidden_bytes / sizeof(int4);
// Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use
using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
constexpr size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + kNumScales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
using vec_t = typename std::conditional<kUseQuant8Bit, int2, int4>::type;
constexpr size_t num_bytes_per_msg = sizeof(int4) + (kUseQuant8Bit ? (kHidden + kNumScales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
EP_STATIC_ASSERT(num_bytes_per_msg % sizeof(int4) == 0, "Invalid message size");
constexpr size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
......@@ -171,6 +227,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// 2. The last warp for reading `topk_idx` and count for per-expert information
if (warp_id < num_warps) {
constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(hip_bfloat16);
constexpr int kNumThreadPerGroup = QUANTIZATION_GROUPSIZE / kNumElemsPerRead;
// EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization");
const auto num_threads = (num_warps - 1) * kWarpSize;
......@@ -186,10 +243,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
__shared__ float int8_amaxf[kNumScales];
if constexpr(kUseInt8) {
// 用于记录per-channel量化的amax
__shared__ float channel_amaxf[kNumScales];
if constexpr(kUseQuant8Bit && kQuantGroupSize == 0) {
if (thread_id < kNumScales) {
int8_amaxf[thread_id] = kFP8Margin;
channel_amaxf[thread_id] = kFP8Margin;
}
__syncthreads();
}
......@@ -200,7 +258,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Read
auto int4_value = __ldg(x_int4 + i);
if constexpr(kUseFP8) {
if constexpr(kUseQuant8Bit) {
// Calculate local amax
auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
float fp32_values[kNumElemsPerRead];
......@@ -212,25 +270,20 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
// Reduce amax and scale
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization");
amax = warp_reduce_max<16>(amax);
const int scale_offset = i * kNumElemsPerRead / FP8_QUANTIZATION_NUM_PER_CHANNEL;
amax = warp_reduce_max<kNumThreadPerGroup>(amax);
const int scale_offset = i * kNumElemsPerRead / QUANTIZATION_GROUPSIZE;
if constexpr(kUseInt8) {
if constexpr(kQuantGroupSize == 0) {
// 记录每128个数的最大值
int8_amaxf[scale_offset] = fmaxf(amax, int8_amaxf[scale_offset]);
channel_amaxf[scale_offset] = fmaxf(amax, channel_amaxf[scale_offset]);
} else {
calculate_fp8_scales(amax, scale, scale_inv, round_scale);
if (lane_id % 16 == 0)
calculate_quant8bit_scales<kQuantType>(amax, scale, scale_inv, fp8_round_scale);
if (lane_id % kNumThreadPerGroup == 0)
rdma_x_scales[scale_offset] = scale_inv;
// Cast into send buffer
vec_t int2_value;
auto fp8x2_values = reinterpret_cast<__hip_fp8x2_storage_t*>(&int2_value);
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; j += 2) {
float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
}
pack_quantized_values<kQuantType, kNumElemsPerRead>(fp32_values, scale, int2_value);
rdma_x_vec[i] = int2_value;
}
} else {
......@@ -240,24 +293,24 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
__syncthreads();
if constexpr(kUseInt8) {
if constexpr(kUseQuant8Bit && kQuantGroupSize == 0) {
float amax_per_token = kFP8Margin;
// 并行规约,计算每个token的amax
for (int s = 0; s < kNumScales; s+=kWarpSize) {
int src_idx = s + lane_id;
float tmp_amaxf = 0;
if(src_idx < kNumScales) {
tmp_amaxf = int8_amaxf[src_idx];
tmp_amaxf = channel_amaxf[src_idx];
}
tmp_amaxf = warp_reduce_max<kWarpSize>(tmp_amaxf);
int8_amaxf[0] = fmaxf(tmp_amaxf, int8_amaxf[0]);
channel_amaxf[0] = fmaxf(tmp_amaxf, channel_amaxf[0]);
__syncthreads();
}
amax_per_token = int8_amaxf[0];
amax_per_token = channel_amaxf[0];
// 根据最大值计算scale
float scale, scale_inv;
calculate_int8_scales(amax_per_token, scale, scale_inv);
calculate_quant8bit_scales<kQuantType>(amax_per_token, scale, scale_inv);
if (thread_id == 0) {
rdma_x_scales[0] = scale_inv;
}
......@@ -269,13 +322,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Cast into send buffer
vec_t int2_value;
auto int8_values = reinterpret_cast<int8_t*>(&int2_value);
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; ++ j) {
auto fp32_value = static_cast<float>(bf16_values[j]);
auto fp32_value_scaled = fp32_value * scale;
int8_values[j] = static_cast<int8_t>(nearbyintf(fp32_value_scaled));
}
pack_quantized_values<kQuantType, kNumElemsPerRead>(bf16_values, scale, int2_value);
rdma_x_vec[i] = int2_value;
}
__syncthreads();
......@@ -392,11 +439,10 @@ LOW_LATENCY_DISPATCH_RECV:
}
// 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize;
constexpr int num_sync_large_iteration = kMaxNumWarps ;
__shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];
#pragma unroll
#pragma unroll
for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
sync_large_warp_counters[i] = 0;
}
......@@ -416,7 +462,7 @@ LOW_LATENCY_DISPATCH_RECV:
const auto num_aligned_scales = ALIGN<int>(kNumScales, sizeof(float) / sizeof(scale_t));
const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank *
(kUseInt8 ? 1 : num_aligned_scales);
(kQuantType == 1 ? 1 : num_aligned_scales);
// Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
......@@ -461,14 +507,14 @@ LOW_LATENCY_DISPATCH_RECV:
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
// Copy scales
if constexpr(kUseFP8) {
if constexpr(kUseQuant8Bit) {
const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
const auto token_idx = recv_token_begin_idx + i;
const auto token_stride = num_elems_per_pack;
const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
if constexpr(kUseInt8) {
if constexpr(kQuantType == 1) {
if (lane_id == 0) {
recv_x_scales[token_idx] = ld_nc_global(src_scales);
}
......@@ -500,12 +546,13 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8,
int quant_type, int quant_group_size, bool fp8_round_scale,
void* workspace, int num_device_sms,
hipStream_t stream, int phases) {
constexpr int kMaxNumWarps = 16;
constexpr int kNumMaxTopK = 11;
const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const int num_warps_per_group = 16 / num_warp_groups;
const int num_warps_per_group = kMaxNumWarps / num_warp_groups;
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);
......@@ -518,33 +565,54 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
#define DISPATCH_LAUNCH_CASE(hidden) { \
auto dispatch_func = dispatch<false, false, false, hidden>; \
if (use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, false, hidden>; \
if (use_fp8 and use_ue8m0) \
dispatch_func = dispatch<true, true, false, hidden>; \
if (use_int8) \
dispatch_func = dispatch<true, false, true, hidden>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_count, \
global_atomic_counter, \
rdma_recv_x, rdma_recv_count, rdma_x, \
x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, round_scale, phases); } break
// 限制groupsize的大小
EP_HOST_ASSERT(quant_group_size == 0 || quant_group_size == 128);
/*量化类型枚举
0 -> None 不量化,保持原始精度
1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3_FNUZ)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2_FNUZ)
*/
#define DISPATCH_LAUNCH_CASE(hidden) \
{ \
auto dispatch_func = dispatch<hidden, 0, 0, kMaxNumWarps>; \
if (quant_group_size == 0) { \
switch (quant_type) { \
case 1: dispatch_func = dispatch<hidden, 1, 0, kMaxNumWarps>; break; \
case 2: dispatch_func = dispatch<hidden, 2, 0, kMaxNumWarps>; break; \
case 3: dispatch_func = dispatch<hidden, 3, 0, kMaxNumWarps>; break; \
case 4: dispatch_func = dispatch<hidden, 4, 0, kMaxNumWarps>; break; \
} \
} else { \
switch (quant_type) { \
case 1: dispatch_func = dispatch<hidden, 1, 128, kMaxNumWarps>; break; \
case 2: dispatch_func = dispatch<hidden, 2, 128, kMaxNumWarps>; break; \
case 3: dispatch_func = dispatch<hidden, 3, 128, kMaxNumWarps>; break; \
case 4: dispatch_func = dispatch<hidden, 4, 128, kMaxNumWarps>; break; \
} \
} \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, packed_recv_count, \
global_atomic_counter, \
rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, fp8_round_scale, phases); \
} \
break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}
template <int kHidden, int kNumMaxTopk>
template <int kHidden, int kNumMaxTopk, int kMaxNumWarps=16>
__global__ __launch_bounds__(16 * kWarpSize, 1) void
combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
......@@ -574,12 +642,11 @@ combine(void* combined_x,
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// Message package
EP_STATIC_ASSERT(kHidden % FP8_QUANTIZATION_NUM_PER_CHANNEL == 0, "Invalid hidden");
EP_STATIC_ASSERT(kHidden % QUANTIZATION_GROUPSIZE == 0, "Invalid hidden");
constexpr size_t num_bytes_per_slot = kHidden * sizeof(hip_bfloat16);
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize;
// 初始化用于细粒度warp间同步的计数器数组
__shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
if (threadIdx.x==0){
#pragma unroll
......@@ -755,9 +822,10 @@ void combine(void* combined_x,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy) {
constexpr int kMaxNumWarps = 16;
constexpr int kNumMaxTopk = 11;
const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const int num_warps_per_group = 16 / num_warp_groups; // num_warps_per_group>1, "Requires more than one warp per group"
const int num_warps_per_group = kMaxNumWarps / num_warp_groups; // num_warps_per_group>1, "Requires more than one warp per group"
const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);
......@@ -770,20 +838,20 @@ void combine(void* combined_x,
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
#define COMBINE_LAUNCH_CASE(hidden) { \
auto combine_func = combine<hidden, kNumMaxTopk>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
global_atomic_counter, \
combine_wait_recv_cost_stats, \
next_clean, num_next_clean_int, \
atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, phases, zero_copy); } break
#define COMBINE_LAUNCH_CASE(hidden) \
{ \
auto combine_func = combine<hidden, kNumMaxTopk, kMaxNumWarps>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
global_atomic_counter, combine_wait_recv_cost_stats, \
next_clean, num_next_clean_int, \
atomic_clean_flag, num_combined_tokens, hidden, \
num_topk, num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, phases, zero_copy); \
} \
break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
......
......@@ -341,10 +341,14 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return *reinterpret_cast<dtype_t *>(recv_int_values);
}
constexpr float kFP8Margin = 1e-4;
// Largest representable magnitude, and its reciprocal, for each supported quantization format
constexpr float kFP8Margin = 0.0;
constexpr float kFinfoAmaxE4M3 = 240.0f;
constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
constexpr float kInt8Amax = 127.0f;
constexpr float kFinfoAmaxE5M2 = 57344.0f;
constexpr float kFinfoAmaxInvE5M2 = 1.0f / kFinfoAmaxE5M2;
constexpr float kFinfoAmaxInt8 = 127.0f;
constexpr float kFinfoAmaxInvInt8 = 1.0f / 127.0f;
__forceinline__ __device__ float fast_pow2(int x) {
// We can ensure `-126 <= x and x <= 127`
......@@ -359,22 +363,33 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
return exp_x - 127 + (man_bits != 0);
}
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) {
if (round_scale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * kFinfoAmaxInvE4M3;
scale = kFinfoAmaxE4M3 / amax;
/**
 * @brief Derive the quantization scale and its inverse from an absolute maximum.
 *
 * `scale` is the multiplier that maps values of magnitude `amax` onto the largest
 * magnitude of the target 8-bit format; `scale_inv` is the matching de-quantization
 * factor. `amax` is floored at 1e-6 so that an all-zero input cannot divide by zero.
 *
 * @tparam kQuantType   1: INT8, 2: FP8 E4M3, 3: FP8 E4M3 with UE8M0 scales, 4: FP8 E5M2.
 *                      Any other value is rejected at compile time.
 * @param amax          Absolute maximum of the values to be quantized.
 * @param[out] scale    Quantization multiplier.
 * @param[out] scale_inv De-quantization multiplier.
 * @param round_scale   When true (FP8 types only), both factors are rounded to a power
 *                      of two; the flag is ignored for INT8.
 */
template <int kQuantType>
__forceinline__ __device__ void calculate_quant8bit_scales(float amax, float& scale, float& scale_inv, bool round_scale = false) {
    // Without this check an unsupported type would silently leave both outputs unwritten.
    static_assert(kQuantType >= 1 && kQuantType <= 4, "Unsupported quantization type");

    constexpr bool kIsInt8 = (kQuantType == 1);
    constexpr bool kIsE5M2 = (kQuantType == 4);
    // Pick the format limits once; all formats share the same symmetric, scale-only math.
    constexpr float kFormatAmax =
        kIsInt8 ? kFinfoAmaxInt8 : (kIsE5M2 ? kFinfoAmaxE5M2 : kFinfoAmaxE4M3);
    constexpr float kFormatAmaxInv =
        kIsInt8 ? kFinfoAmaxInvInt8 : (kIsE5M2 ? kFinfoAmaxInvE5M2 : kFinfoAmaxInvE4M3);

    // Guard against a zero amax.
    amax = fmaxf(amax, 1e-6f);

    if (!kIsInt8 && round_scale) {
        // Power-of-two scaling: round the inverse scale's exponent up.
        const int exp_scale_inv = fast_log2_ceil(amax * kFormatAmaxInv);
        scale = fast_pow2(-exp_scale_inv);
        scale_inv = fast_pow2(exp_scale_inv);
    } else {
        scale_inv = amax * kFormatAmaxInv;
        scale = kFormatAmax / amax;
    }
}
__forceinline__ __device__ void calculate_int8_scales(float amax, float& scale, float& scale_inv) {
scale = kInt8Amax / amax;
scale_inv = amax / kInt8Amax;
}
template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>>
__forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) {
if constexpr (kIsUE8M0) {
......
......@@ -841,7 +841,7 @@ class Buffer:
# noinspection PyTypeChecker
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int, num_experts: int,
use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False, use_int8: bool = False,
quant_type: int = 1, quant_group_size: int = 0, fp8_round_scale: bool = False,
async_finish: bool = False, return_recv_hook: bool = False) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
"""
......@@ -858,10 +858,20 @@ class Buffer:
only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
use_int8: whether to enable INT8 casting.
量化配置
quant_type: int 量化类型枚举
0 -> None 不量化,保持原始精度
1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3_FNUZ)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2_FNUZ)
quant_group_size: int 量化分组大小
0 -> per-token quantization: a single scale per token (called "per-channel" elsewhere in this code)
128-> 每 128 元素一组 (per-group) 量化
fp8_round_scale: bool 是否将 FP8 缩放因子取整为 2 的幂
true -> 缩放因子 = 2^k,硬件零开销
false -> 缩放因子 = 任意浮点,精度更高
异步配置
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
......@@ -869,15 +879,25 @@ class Buffer:
Returns:
recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`,
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
With `use_fp8=False`, the result would be a tensor shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
- packed_recv_x:
存储接收到的 Token 数据,形状为
`[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`。
数据类型取决于 quant_type:
quant_type == 1 -> torch.int8
quant_type == 2 -> torch.float8_e4m3fnuz
quant_type == 3 -> torch.float8_e4m3fnuz (UE8M0 使用 E4M3 格式存储)
quant_type == 4 -> torch.float8_e5m2fnuz
其他 (非量化) -> torch.bfloat16
- packed_recv_x_scales (可选):
仅在 quant_type > 0 时存在,存储量化的 Scale 值。
形状为 `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, scales_col_size]`。
- 当 quant_type == 3 (UE8M0) 时:
scales_col_size = hidden // 512
dtype is torch.int; each 32-bit element packs four 8-bit UE8M0 scales.
*注意:此模式强制要求 fp8_round_scale=True 且 group_size=128。
- 当 quant_type == 1, 2, 4 时:
scales_col_size = hidden // 128 (若使用 group_size) 或 1 (per-channel)。
数据类型为 torch.float32。
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
......@@ -889,14 +909,15 @@ class Buffer:
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
self.runtime.low_latency_dispatch(x, topk_idx,
num_max_dispatch_tokens_per_rank, num_experts,
use_fp8, round_scale, use_ue8m0, use_int8,
quant_type, quant_group_size, fp8_round_scale,
async_finish, return_recv_hook)
handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
tensors_to_record = (x, topk_idx,
packed_recv_x, packed_recv_x_scales, packed_recv_count,
packed_recv_src_info, packed_recv_layout_range)
return (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, packed_recv_count, handle, \
EventOverlap(event, tensors_to_record if async_finish else None), hook
recv_x = (packed_recv_x, packed_recv_x_scales) if (quant_type > 0) else packed_recv_x
return recv_x, packed_recv_count, handle, EventOverlap(event, tensors_to_record if async_finish else None), hook
# noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
......
......@@ -58,7 +58,8 @@ def test_main(num_tokens: int,
x_list = [x]
# # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
# x_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
# x_list = [x_rand]
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
......@@ -80,14 +81,19 @@ def test_main(num_tokens: int,
hash_value, num_times = 0, 0
for current_x in x_list:
for return_recv_hook in (False, True):
for dispatch_use_fp8 in (False, True):
for round_scale in (False, True) if dispatch_use_fp8 else (False,):
for use_ue8m0 in (False, True) if round_scale else (False,):
for quant_type in (0, 2, 3, ): # 0: 不量化, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True)
dispatch_use_fp8 = quant_type > 0
for fp8_round_scale in (False, True) if dispatch_use_fp8 else (False, ):
for quant_group_size in (128, ):
# 跳过不支持的情况
if quant_type == 3 and fp8_round_scale == False:
continue
num_times += 1
for _ in range((num_times % 2) + 1):
packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0,
quant_type=quant_type, fp8_round_scale=fp8_round_scale, quant_group_size=quant_group_size,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
......@@ -115,13 +121,13 @@ def test_main(num_tokens: int,
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
if round_scale:
if fp8_round_scale:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
for j in range(num_ranks):
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
if not round_scale:
if not fp8_round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_fp8:
......@@ -147,7 +153,7 @@ def test_main(num_tokens: int,
if do_check:
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0
# if not round_scale:
# if not fp8_round_scale:
assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: diff={diff}, dispatch_use_fp8={dispatch_use_fp8}, zero_copy={zero_copy}'
hash_value ^= hash_tensor(combined_x)
......@@ -162,7 +168,8 @@ def test_main(num_tokens: int,
def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)
quant_type=2, quant_group_size=128,
async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
......
......@@ -54,16 +54,16 @@ def test_main(num_tokens: int,
do_check = True
hash_value, num_times = 0, 0
for current_x in x_list:
for return_recv_hook in (False, ):
for dispatch_use_fp8 in (True, ):
for round_scale in (False, ):
for use_ue8m0 in (False, ):
for return_recv_hook in (False, True):
for quant_type in (1, ):
for fp8_round_scale in (False, ):
for quant_group_size in (0, ):
dispatch_use_fp8 = quant_type > 0
num_times += 1
use_int8 = True
for _ in range(1):
packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0, use_int8=use_int8,
quant_type=quant_type, quant_group_size=quant_group_size,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait()
......@@ -97,9 +97,7 @@ def test_main(num_tokens: int,
assert torch.equal(recv_x_amin, recv_x_amax)
if round_scale:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
elif use_int8:
if quant_type == 1:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.01
else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
......@@ -131,7 +129,7 @@ def test_main(num_tokens: int,
def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=True, round_scale=False, use_ue8m0=False, use_int8=True,
quant_type=1, quant_group_size=0,
async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment