Commit cf4514dc authored by lishen's avatar lishen
Browse files

fp8量化细节调整

parent 44ec8bed
......@@ -1330,9 +1330,9 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto packed_recv_x_dtype = torch::kBFloat16;
switch (quant_type) {
case 1: packed_recv_x_dtype = torch::kInt8; break;
case 2: packed_recv_x_dtype = torch::kFloat8_e4m3fnuz; break;
case 3: packed_recv_x_dtype = torch::kFloat8_e4m3fnuz; break;
case 4: packed_recv_x_dtype = torch::kFloat8_e5m2fnuz; break;
case 2: packed_recv_x_dtype = torch::kFloat8_e4m3fn; break;
case 3: packed_recv_x_dtype = torch::kFloat8_e4m3fn; break;
case 4: packed_recv_x_dtype = torch::kFloat8_e5m2; break;
}
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(packed_recv_x_dtype));
......
......@@ -152,10 +152,10 @@ __forceinline__ __device__ void pack_quantized_values(
if constexpr (kQuantType == 4) {
// FP8 E5M2
fp8x2_ptr[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E5M2_FNUZ);
fp8x2_ptr[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E5M2);
} else {
// FP8 E4M3 或 UE8M0
fp8x2_ptr[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
fp8x2_ptr[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3);
}
}
}
......@@ -179,9 +179,9 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
enum class QuantType {
None = 0, // 不进行量化
Int8 = 1, // 采用 Int8 量化
FP8_E4M3 = 2, // 采用 FP8 量化 __HIP_E4M3_FNUZ
FP8_E4M3 = 2, // 采用 FP8 量化 __HIP_E4M3
FP8_UE8M0 = 3, // 采用 FP8 量化 DeepseekV3.1的 UE8M0
FP8_E5M2 = 4 // 采用 FP8 量化 __HIP_E5M2_FNUZ
FP8_E5M2 = 4 // 采用 FP8 量化 __HIP_E5M2
};
const auto sm_id = static_cast<int>(blockIdx.x);
......@@ -247,7 +247,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
__shared__ float channel_amaxf[kNumScales];
if constexpr(kUseQuant8Bit && kQuantGroupSize == 0) {
if (thread_id < kNumScales) {
channel_amaxf[thread_id] = kFP8Margin;
channel_amaxf[thread_id] = 0.0;
}
__syncthreads();
}
......@@ -262,7 +262,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
// Calculate local amax
auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
float fp32_values[kNumElemsPerRead];
float amax = kFP8Margin, scale, scale_inv;
float amax = 0.0, scale, scale_inv;
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; ++ j) {
fp32_values[j] = static_cast<float>(bf16_values[j]);
......@@ -294,7 +294,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
__syncthreads();
if constexpr(kUseQuant8Bit && kQuantGroupSize == 0) {
float amax_per_token = kFP8Margin;
float amax_per_token = 0.0;
// 并行规约,计算每个token的amax
for (int s = 0; s < kNumScales; s+=kWarpSize) {
int src_idx = s + lane_id;
......@@ -310,7 +310,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
// 根据最大值计算scale
float scale, scale_inv;
calculate_quant8bit_scales<kQuantType>(amax_per_token, scale, scale_inv);
calculate_quant8bit_scales<kQuantType>(amax_per_token, scale, scale_inv, fp8_round_scale);
if (thread_id == 0) {
rdma_x_scales[0] = scale_inv;
}
......@@ -571,9 +571,9 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
/*量化类型枚举
0 -> None 不量化,保持原始精度
1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3_FNUZ)
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2_FNUZ)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2)
*/
#define DISPATCH_LAUNCH_CASE(hidden) \
......
......@@ -342,8 +342,7 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
}
// 设置不同的量化方式的最大值与相反数
constexpr float kFP8Margin = 0.0;
constexpr float kFinfoAmaxE4M3 = 240.0f;
constexpr float kFinfoAmaxE4M3 = 448.0f;
constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
constexpr float kFinfoAmaxE5M2 = 57344.0f;
constexpr float kFinfoAmaxInvE5M2 = 1.0f / kFinfoAmaxE5M2;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment