Commit cf4514dc authored by lishen's avatar lishen
Browse files

fp8量化细节调整

parent 44ec8bed
...@@ -1330,9 +1330,9 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1330,9 +1330,9 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto packed_recv_x_dtype = torch::kBFloat16; auto packed_recv_x_dtype = torch::kBFloat16;
switch (quant_type) { switch (quant_type) {
case 1: packed_recv_x_dtype = torch::kInt8; break; case 1: packed_recv_x_dtype = torch::kInt8; break;
case 2: packed_recv_x_dtype = torch::kFloat8_e4m3fnuz; break; case 2: packed_recv_x_dtype = torch::kFloat8_e4m3fn; break;
case 3: packed_recv_x_dtype = torch::kFloat8_e4m3fnuz; break; case 3: packed_recv_x_dtype = torch::kFloat8_e4m3fn; break;
case 4: packed_recv_x_dtype = torch::kFloat8_e5m2fnuz; break; case 4: packed_recv_x_dtype = torch::kFloat8_e5m2; break;
} }
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(packed_recv_x_dtype)); auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(packed_recv_x_dtype));
......
...@@ -152,10 +152,10 @@ __forceinline__ __device__ void pack_quantized_values( ...@@ -152,10 +152,10 @@ __forceinline__ __device__ void pack_quantized_values(
if constexpr (kQuantType == 4) { if constexpr (kQuantType == 4) {
// FP8 E5M2 // FP8 E5M2
fp8x2_ptr[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E5M2_FNUZ); fp8x2_ptr[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E5M2);
} else { } else {
// FP8 E4M3 或 UE8M0 // FP8 E4M3 或 UE8M0
fp8x2_ptr[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3_FNUZ); fp8x2_ptr[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3);
} }
} }
} }
...@@ -179,9 +179,9 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void ...@@ -179,9 +179,9 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
enum class QuantType { enum class QuantType {
None = 0, // 不进行量化 None = 0, // 不进行量化
Int8 = 1, // 采用 Int8 量化 Int8 = 1, // 采用 Int8 量化
FP8_E4M3 = 2, // 采用 FP8 量化 __HIP_E4M3_FNUZ FP8_E4M3 = 2, // 采用 FP8 量化 __HIP_E4M3
FP8_UE8M0 = 3, // 采用 FP8 量化 DeepseekV3.1的 UE8M0 FP8_UE8M0 = 3, // 采用 FP8 量化 DeepseekV3.1的 UE8M0
FP8_E5M2 = 4 // 采用 FP8 量化 __HIP_E5M2_FNUZ FP8_E5M2 = 4 // 采用 FP8 量化 __HIP_E5M2
}; };
const auto sm_id = static_cast<int>(blockIdx.x); const auto sm_id = static_cast<int>(blockIdx.x);
...@@ -247,7 +247,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void ...@@ -247,7 +247,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
__shared__ float channel_amaxf[kNumScales]; __shared__ float channel_amaxf[kNumScales];
if constexpr(kUseQuant8Bit && kQuantGroupSize == 0) { if constexpr(kUseQuant8Bit && kQuantGroupSize == 0) {
if (thread_id < kNumScales) { if (thread_id < kNumScales) {
channel_amaxf[thread_id] = kFP8Margin; channel_amaxf[thread_id] = 0.0;
} }
__syncthreads(); __syncthreads();
} }
...@@ -262,7 +262,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void ...@@ -262,7 +262,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
// Calculate local amax // Calculate local amax
auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value); auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
float fp32_values[kNumElemsPerRead]; float fp32_values[kNumElemsPerRead];
float amax = kFP8Margin, scale, scale_inv; float amax = 0.0, scale, scale_inv;
#pragma unroll #pragma unroll
for (int j = 0; j < kNumElemsPerRead; ++ j) { for (int j = 0; j < kNumElemsPerRead; ++ j) {
fp32_values[j] = static_cast<float>(bf16_values[j]); fp32_values[j] = static_cast<float>(bf16_values[j]);
...@@ -294,7 +294,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void ...@@ -294,7 +294,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
__syncthreads(); __syncthreads();
if constexpr(kUseQuant8Bit && kQuantGroupSize == 0) { if constexpr(kUseQuant8Bit && kQuantGroupSize == 0) {
float amax_per_token = kFP8Margin; float amax_per_token = 0.0;
// 并行规约,计算每个token的amax // 并行规约,计算每个token的amax
for (int s = 0; s < kNumScales; s+=kWarpSize) { for (int s = 0; s < kNumScales; s+=kWarpSize) {
int src_idx = s + lane_id; int src_idx = s + lane_id;
...@@ -310,7 +310,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void ...@@ -310,7 +310,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
// 根据最大值计算scale // 根据最大值计算scale
float scale, scale_inv; float scale, scale_inv;
calculate_quant8bit_scales<kQuantType>(amax_per_token, scale, scale_inv); calculate_quant8bit_scales<kQuantType>(amax_per_token, scale, scale_inv, fp8_round_scale);
if (thread_id == 0) { if (thread_id == 0) {
rdma_x_scales[0] = scale_inv; rdma_x_scales[0] = scale_inv;
} }
...@@ -344,8 +344,8 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void ...@@ -344,8 +344,8 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank); uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
if (p2p_ptr == 0) { // RDMA if (p2p_ptr == 0) { // RDMA
internode_ll_putmem_nbi((void*)dst_ptr, (void*)src_ptr, internode_ll_putmem_nbi((void*)dst_ptr, (void*)src_ptr,
num_ranks, dst_rank, dst_expert_local_idx, num_ranks, dst_rank, dst_expert_local_idx,
num_bytes_per_msg); num_bytes_per_msg);
} else { // 本地 GPU 和 同一计算节点的 其他 GPU 地址 } else { // 本地 GPU 和 同一计算节点的 其他 GPU 地址
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr); const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
...@@ -571,9 +571,9 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -571,9 +571,9 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
/*量化类型枚举 /*量化类型枚举
0 -> None 不量化,保持原始精度 0 -> None 不量化,保持原始精度
1 -> Int8 使用 INT8 对称量化 1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3_FNUZ) 2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True) 3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2_FNUZ) 4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2)
*/ */
#define DISPATCH_LAUNCH_CASE(hidden) \ #define DISPATCH_LAUNCH_CASE(hidden) \
......
...@@ -342,8 +342,7 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) { ...@@ -342,8 +342,7 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
} }
// 设置不同的量化方式的最大值与相反数 // 设置不同的量化方式的最大值与相反数
constexpr float kFP8Margin = 0.0; constexpr float kFinfoAmaxE4M3 = 448.0f;
constexpr float kFinfoAmaxE4M3 = 240.0f;
constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3; constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
constexpr float kFinfoAmaxE5M2 = 57344.0f; constexpr float kFinfoAmaxE5M2 = 57344.0f;
constexpr float kFinfoAmaxInvE5M2 = 1.0f / kFinfoAmaxE5M2; constexpr float kFinfoAmaxInvE5M2 = 1.0f / kFinfoAmaxE5M2;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment