Commit 44ec8bed authored by lishen's avatar lishen
Browse files

支持更复杂的量化,包括fp8/int8/ue8m0,且支持per-group/per-channel

parent 81e56124
...@@ -136,7 +136,7 @@ struct LowLatencyLayout { ...@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) { int num_ranks, int num_experts) {
const int num_scales = hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL; const int num_scales = hidden / QUANTIZATION_GROUPSIZE;
// Dispatch and combine layout: // Dispatch and combine layout:
// - 2 symmetric odd/even send buffer // - 2 symmetric odd/even send buffer
......
...@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int ...@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8, int quant_type, int quant_group_size, bool fp8_round_scale,
bool async, bool return_recv_hook) { bool async, bool return_recv_hook) {
EP_HOST_ASSERT(low_latency_mode); EP_HOST_ASSERT(low_latency_mode);
...@@ -1327,8 +1327,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1327,8 +1327,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
stream_wait(launch_stream, compute_stream); stream_wait(launch_stream, compute_stream);
// Allocate packed tensors // Allocate packed tensors
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, auto packed_recv_x_dtype = torch::kBFloat16;
x.options().dtype(use_int8 ? torch::kInt8 : use_fp8 ? torch::kFloat8_e4m3fnuz: torch::kBFloat16)); switch (quant_type) {
case 1: packed_recv_x_dtype = torch::kInt8; break;
case 2: packed_recv_x_dtype = torch::kFloat8_e4m3fnuz; break;
case 3: packed_recv_x_dtype = torch::kFloat8_e4m3fnuz; break;
case 4: packed_recv_x_dtype = torch::kFloat8_e5m2fnuz; break;
}
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(packed_recv_x_dtype));
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
...@@ -1336,21 +1343,28 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1336,21 +1343,28 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate column-majored scales // Allocate column-majored scales
auto packed_recv_x_scales = std::optional<torch::Tensor>(); auto packed_recv_x_scales = std::optional<torch::Tensor>();
void* packed_recv_x_scales_ptr = nullptr; void* packed_recv_x_scales_ptr = nullptr;
if (use_fp8) { if (quant_type > 0) {
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4"); EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
// TODO: support unaligned cases // TODO: support unaligned cases
EP_HOST_ASSERT(hidden % (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4) == 0); EP_HOST_ASSERT(hidden % (QUANTIZATION_GROUPSIZE * 4) == 0);
EP_HOST_ASSERT(!(use_ue8m0 && use_int8));
if (use_ue8m0) { // 计算scale_col的大小
EP_HOST_ASSERT(round_scale); int scales_col_size = 1; // 默认为per-channel
packed_recv_x_scales = torch::empty({num_local_experts, hidden / (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4), num_ranks * num_max_dispatch_tokens_per_rank}, if (quant_group_size > 0) {
if (quant_type == 3) { // FP8_UE8M0比较特殊
scales_col_size = hidden / (QUANTIZATION_GROUPSIZE * 4);
} else {
scales_col_size = hidden / QUANTIZATION_GROUPSIZE;
}
}
// 设置packed_recv_x_scales
if (quant_type == 3) { // FP8_UE8M0比较特殊,需要单独处理
EP_HOST_ASSERT(fp8_round_scale && quant_group_size == 128);
packed_recv_x_scales = torch::empty({num_local_experts, scales_col_size, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kInt).device(torch::kCUDA)); torch::dtype(torch::kInt).device(torch::kCUDA));
} else if (use_int8) {
packed_recv_x_scales = torch::empty({num_local_experts, 1, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
} else { } else {
packed_recv_x_scales = torch::empty({num_local_experts, hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL, num_ranks * num_max_dispatch_tokens_per_rank}, packed_recv_x_scales = torch::empty({num_local_experts, scales_col_size, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA)); torch::dtype(torch::kFloat32).device(torch::kCUDA));
} }
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2); packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
...@@ -1370,7 +1384,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1370,7 +1384,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
next_clean_meta.first, next_clean_meta.second, next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank, num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks, num_topk, num_experts, rank, num_ranks,
use_fp8, round_scale, use_ue8m0, use_int8, quant_type, quant_group_size, fp8_round_scale,
workspace, num_device_sms, launch_stream, phases); workspace, num_device_sms, launch_stream, phases);
}; };
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
......
...@@ -177,7 +177,7 @@ public: ...@@ -177,7 +177,7 @@ public:
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8, int quant_type, int quant_group_size, bool fp8_round_scale,
bool async, bool return_recv_hook); bool async, bool return_recv_hook);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
......
...@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8, int quant_type, int group_size, bool fp8_round_scale,
void* workspace, int num_device_sms, hipStream_t stream, int phases); void* workspace, int num_device_sms, hipStream_t stream, int phases);
void combine(void* combined_x, void combine(void* combined_x,
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#define LOW_LATENCY_SEND_PHASE 1 #define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2 #define LOW_LATENCY_RECV_PHASE 2
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3 #define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define FP8_QUANTIZATION_NUM_PER_CHANNEL 128 #define QUANTIZATION_GROUPSIZE 128
#define DEFAULT_NUM_CU 20 #define DEFAULT_NUM_CU 20
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6 #define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
......
...@@ -116,9 +116,54 @@ internode_ll_long_atomic_add(long* dest, const long &value, ...@@ -116,9 +116,54 @@ internode_ll_long_atomic_add(long* dest, const long &value,
#endif // defined(FORCE_DUSHMEM_API) #endif // defined(FORCE_DUSHMEM_API)
} }
template <bool kUseFP8, bool kUseUE8M0, bool kUseInt8, int kHidden> /**
* @brief 将 K 个浮点数(BF16/FP32)量化并打包成 INT2(64位)存储
*
* @tparam kQuantType 量化类型 (1: Int8, 2/3: FP8_E4M3/UE8M0, 4: FP8_E5M2)
* @tparam kNumElemsPerRead 每次读取的元素数量 (通常为 2, 4, 8)
* @tparam SrcT 源数据类型 (float 或 __hip_bfloat16)
* @tparam DstT 目标数据类型 (int2 或 int4)
* @param src_values 源数据数组 (长度 >= kNumElemsPerRead)
* @param scale 缩放因子 (将 FP32 值映射到量化范围)
* @param[out] dst_vec 输出的 64 位向量 (int2 或 int4)
*/
// Scales kNumElemsPerRead source values, converts each to an 8-bit quantized
// representation and stores them, one byte per element, in the leading bytes
// of dst_vec.
//
// kQuantType selects the encoding:
//   1      -> INT8: round with nearbyintf, then saturate to the int8_t range.
//   4      -> FP8 E5M2 (FNUZ), saturating to finite values.
//   others -> FP8 E4M3 (FNUZ), saturating to finite values; this branch is
//             also taken for the UE8M0 mode (3).
//
// src_values: pointer to at least kNumElemsPerRead elements convertible to float.
// scale:      factor every element is multiplied by before conversion.
// dst_vec:    output vector; only its first kNumElemsPerRead bytes are written.
template <int kQuantType, int kNumElemsPerRead, typename SrcT, typename DstT>
__forceinline__ __device__ void pack_quantized_values(
    const SrcT* src_values, float scale, DstT& dst_vec) {
    static_assert(kNumElemsPerRead > 0, "At least one element must be packed");
    // Each quantized element occupies exactly one byte of the destination.
    static_assert(sizeof(DstT) >= static_cast<size_t>(kNumElemsPerRead),
                  "Destination vector is too small for the packed values");

    if constexpr (kQuantType == 1) {
        // INT8 quantization.
        auto* int8_out = reinterpret_cast<int8_t*>(&dst_vec);
        #pragma unroll
        for (int j = 0; j < kNumElemsPerRead; ++j) {
            // Promote to float (a no-op when SrcT is already float), scale,
            // and round to an integral value in the current rounding mode.
            const float rounded = nearbyintf(static_cast<float>(src_values[j]) * scale);
            // Converting an out-of-range float to int8_t is undefined
            // behavior, so saturate first. In-range values are unaffected.
            const float saturated = fminf(fmaxf(rounded, -128.0f), 127.0f);
            // NaN (e.g. 0 * inf if the caller supplies an infinite scale) has
            // no integer encoding; `rounded == rounded` is false only for NaN,
            // which is stored as 0.
            int8_out[j] = (rounded == rounded) ? static_cast<int8_t>(saturated)
                                               : static_cast<int8_t>(0);
        }
    } else {
        // FP8 quantization: values are converted two at a time.
        static_assert(kNumElemsPerRead % 2 == 0,
                      "FP8 packing converts elements in pairs");
        auto* fp8x2_out = reinterpret_cast<__hip_fp8x2_storage_t*>(&dst_vec);
        #pragma unroll
        for (int j = 0; j < kNumElemsPerRead; j += 2) {
            const float2 scaled_pair = {static_cast<float>(src_values[j]) * scale,
                                        static_cast<float>(src_values[j + 1]) * scale};
            if constexpr (kQuantType == 4) {
                // FP8 E5M2
                fp8x2_out[j / 2] = __hip_cvt_float2_to_fp8x2(scaled_pair, __HIP_SATFINITE, __HIP_E5M2_FNUZ);
            } else {
                // FP8 E4M3 (also used for the UE8M0 mode)
                fp8x2_out[j / 2] = __hip_cvt_float2_to_fp8x2(scaled_pair, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
            }
        }
    }
}
template <int kHidden, int kQuantType=0, int kQuantGroupSize=0, int kMaxNumWarps=16>
__global__ __launch_bounds__(16 * kWarpSize, 1) void __global__ __launch_bounds__(16 * kWarpSize, 1) void
dispatch(void* packed_recv_x, void* packed_recv_x_scales, dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count, int* packed_recv_count,
int* global_atomic_counter, int* global_atomic_counter,
...@@ -129,7 +174,16 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -129,7 +174,16 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int num_tokens, int num_max_dispatch_tokens_per_rank, int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group, int num_warp_groups, int num_warps_per_group,
bool round_scale, int phases) { bool fp8_round_scale, int phases) {
// Enumeration of the supported quantization types.
// NOTE(review): the visible kernel code compares kQuantType against the
// integer literals directly (e.g. `kQuantType == 3`), so these enumerators
// serve as documentation of the numeric codes — confirm whether they are
// referenced elsewhere.
enum class QuantType {
None = 0, // no quantization
Int8 = 1, // Int8 quantization
FP8_E4M3 = 2, // FP8 quantization using __HIP_E4M3_FNUZ
FP8_UE8M0 = 3, // FP8 quantization using the UE8M0 scheme of DeepSeek V3.1
FP8_E5M2 = 4 // FP8 quantization using __HIP_E5M2_FNUZ
};
const auto sm_id = static_cast<int>(blockIdx.x); const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x); const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id(); const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
...@@ -141,20 +195,22 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -141,20 +195,22 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id; const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
// May extract UE8M0 from the scales // May extract UE8M0 from the scales
constexpr bool kUseQuant8Bit = kQuantType > 0;
constexpr bool kUseUE8M0 = kQuantType == 3; // QuantType::FP8_UE8M0
using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>; using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>; using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length"); EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
// FP8 staffs // FP8 staffs
constexpr int kNumPerChannels = FP8_QUANTIZATION_NUM_PER_CHANNEL; constexpr int kNumPerChannels = QUANTIZATION_GROUPSIZE;
constexpr int kNumScales = kHidden / kNumPerChannels; constexpr int kNumScales = kHidden / kNumPerChannels;
const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16)); const size_t hidden_bytes = kHidden * (kUseQuant8Bit ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16));
const size_t hidden_int4 = hidden_bytes / sizeof(int4); const size_t hidden_int4 = hidden_bytes / sizeof(int4);
// Message package: hidden data, FP8 scales, index at source // Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use // NOTES: currently we have 3 reserved int fields for future use
using vec_t = typename std::conditional<kUseFP8, int2, int4>::type; using vec_t = typename std::conditional<kUseQuant8Bit, int2, int4>::type;
constexpr size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + kNumScales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16))); constexpr size_t num_bytes_per_msg = sizeof(int4) + (kUseQuant8Bit ? (kHidden + kNumScales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
EP_STATIC_ASSERT(num_bytes_per_msg % sizeof(int4) == 0, "Invalid message size"); EP_STATIC_ASSERT(num_bytes_per_msg % sizeof(int4) == 0, "Invalid message size");
constexpr size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4); constexpr size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
...@@ -171,6 +227,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -171,6 +227,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// 2. The last warp for reading `topk_idx` and count for per-expert information // 2. The last warp for reading `topk_idx` and count for per-expert information
if (warp_id < num_warps) { if (warp_id < num_warps) {
constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(hip_bfloat16); constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(hip_bfloat16);
constexpr int kNumThreadPerGroup = QUANTIZATION_GROUPSIZE / kNumElemsPerRead;
// EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0); // EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization"); EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization");
const auto num_threads = (num_warps - 1) * kWarpSize; const auto num_threads = (num_warps - 1) * kWarpSize;
...@@ -186,10 +243,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -186,10 +243,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1; auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0; thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
__shared__ float int8_amaxf[kNumScales]; // 用于记录per-channel量化的amax
if constexpr(kUseInt8) { __shared__ float channel_amaxf[kNumScales];
if constexpr(kUseQuant8Bit && kQuantGroupSize == 0) {
if (thread_id < kNumScales) { if (thread_id < kNumScales) {
int8_amaxf[thread_id] = kFP8Margin; channel_amaxf[thread_id] = kFP8Margin;
} }
__syncthreads(); __syncthreads();
} }
...@@ -200,7 +258,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -200,7 +258,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Read // Read
auto int4_value = __ldg(x_int4 + i); auto int4_value = __ldg(x_int4 + i);
if constexpr(kUseFP8) { if constexpr(kUseQuant8Bit) {
// Calculate local amax // Calculate local amax
auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value); auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
float fp32_values[kNumElemsPerRead]; float fp32_values[kNumElemsPerRead];
...@@ -212,25 +270,20 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -212,25 +270,20 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
} }
// Reduce amax and scale // Reduce amax and scale
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization"); EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization");
amax = warp_reduce_max<16>(amax); amax = warp_reduce_max<kNumThreadPerGroup>(amax);
const int scale_offset = i * kNumElemsPerRead / FP8_QUANTIZATION_NUM_PER_CHANNEL; const int scale_offset = i * kNumElemsPerRead / QUANTIZATION_GROUPSIZE;
if constexpr(kUseInt8) { if constexpr(kQuantGroupSize == 0) {
// 记录每128个数的最大值 // 记录每128个数的最大值
int8_amaxf[scale_offset] = fmaxf(amax, int8_amaxf[scale_offset]); channel_amaxf[scale_offset] = fmaxf(amax, channel_amaxf[scale_offset]);
} else { } else {
calculate_fp8_scales(amax, scale, scale_inv, round_scale); calculate_quant8bit_scales<kQuantType>(amax, scale, scale_inv, fp8_round_scale);
if (lane_id % 16 == 0) if (lane_id % kNumThreadPerGroup == 0)
rdma_x_scales[scale_offset] = scale_inv; rdma_x_scales[scale_offset] = scale_inv;
// Cast into send buffer // Cast into send buffer
vec_t int2_value; vec_t int2_value;
auto fp8x2_values = reinterpret_cast<__hip_fp8x2_storage_t*>(&int2_value); pack_quantized_values<kQuantType, kNumElemsPerRead>(fp32_values, scale, int2_value);
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; j += 2) {
float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
}
rdma_x_vec[i] = int2_value; rdma_x_vec[i] = int2_value;
} }
} else { } else {
...@@ -240,24 +293,24 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -240,24 +293,24 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
} }
__syncthreads(); __syncthreads();
if constexpr(kUseInt8) { if constexpr(kUseQuant8Bit && kQuantGroupSize == 0) {
float amax_per_token = kFP8Margin; float amax_per_token = kFP8Margin;
// 并行规约,计算每个token的amax // 并行规约,计算每个token的amax
for (int s = 0; s < kNumScales; s+=kWarpSize) { for (int s = 0; s < kNumScales; s+=kWarpSize) {
int src_idx = s + lane_id; int src_idx = s + lane_id;
float tmp_amaxf = 0; float tmp_amaxf = 0;
if(src_idx < kNumScales) { if(src_idx < kNumScales) {
tmp_amaxf = int8_amaxf[src_idx]; tmp_amaxf = channel_amaxf[src_idx];
} }
tmp_amaxf = warp_reduce_max<kWarpSize>(tmp_amaxf); tmp_amaxf = warp_reduce_max<kWarpSize>(tmp_amaxf);
int8_amaxf[0] = fmaxf(tmp_amaxf, int8_amaxf[0]); channel_amaxf[0] = fmaxf(tmp_amaxf, channel_amaxf[0]);
__syncthreads(); __syncthreads();
} }
amax_per_token = int8_amaxf[0]; amax_per_token = channel_amaxf[0];
// 根据最大值计算scale // 根据最大值计算scale
float scale, scale_inv; float scale, scale_inv;
calculate_int8_scales(amax_per_token, scale, scale_inv); calculate_quant8bit_scales<kQuantType>(amax_per_token, scale, scale_inv);
if (thread_id == 0) { if (thread_id == 0) {
rdma_x_scales[0] = scale_inv; rdma_x_scales[0] = scale_inv;
} }
...@@ -269,13 +322,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -269,13 +322,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Cast into send buffer // Cast into send buffer
vec_t int2_value; vec_t int2_value;
auto int8_values = reinterpret_cast<int8_t*>(&int2_value); pack_quantized_values<kQuantType, kNumElemsPerRead>(bf16_values, scale, int2_value);
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; ++ j) {
auto fp32_value = static_cast<float>(bf16_values[j]);
auto fp32_value_scaled = fp32_value * scale;
int8_values[j] = static_cast<int8_t>(nearbyintf(fp32_value_scaled));
}
rdma_x_vec[i] = int2_value; rdma_x_vec[i] = int2_value;
} }
__syncthreads(); __syncthreads();
...@@ -392,11 +439,10 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -392,11 +439,10 @@ LOW_LATENCY_DISPATCH_RECV:
} }
// 16 is the max possible number of warps in AMD GPUs // 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize;
constexpr int num_sync_large_iteration = kMaxNumWarps ; constexpr int num_sync_large_iteration = kMaxNumWarps ;
__shared__ volatile int sync_large_warp_counters[num_sync_large_iteration]; __shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];
#pragma unroll #pragma unroll
for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) { for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
sync_large_warp_counters[i] = 0; sync_large_warp_counters[i] = 0;
} }
...@@ -416,7 +462,7 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -416,7 +462,7 @@ LOW_LATENCY_DISPATCH_RECV:
const auto num_aligned_scales = ALIGN<int>(kNumScales, sizeof(float) / sizeof(scale_t)); const auto num_aligned_scales = ALIGN<int>(kNumScales, sizeof(float) / sizeof(scale_t));
const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) + const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank *
(kUseInt8 ? 1 : num_aligned_scales); (kQuantType == 1 ? 1 : num_aligned_scales);
// Shared between sub-warps in warp groups // Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups]; __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
...@@ -461,14 +507,14 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -461,14 +507,14 @@ LOW_LATENCY_DISPATCH_RECV:
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global); UNROLLED_WARP_COPY_LL(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
// Copy scales // Copy scales
if constexpr(kUseFP8) { if constexpr(kUseQuant8Bit) {
const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes); const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t)); const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
const auto token_idx = recv_token_begin_idx + i; const auto token_idx = recv_token_begin_idx + i;
const auto token_stride = num_elems_per_pack; const auto token_stride = num_elems_per_pack;
const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack; const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
if constexpr(kUseInt8) { if constexpr(kQuantType == 1) {
if (lane_id == 0) { if (lane_id == 0) {
recv_x_scales[token_idx] = ld_nc_global(src_scales); recv_x_scales[token_idx] = ld_nc_global(src_scales);
} }
...@@ -500,12 +546,13 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -500,12 +546,13 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8, int quant_type, int quant_group_size, bool fp8_round_scale,
void* workspace, int num_device_sms, void* workspace, int num_device_sms,
hipStream_t stream, int phases) { hipStream_t stream, int phases) {
constexpr int kMaxNumWarps = 16;
constexpr int kNumMaxTopK = 11; constexpr int kNumMaxTopK = 11;
const int num_warp_groups = ceil_div(num_experts, num_device_sms); const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const int num_warps_per_group = 16 / num_warp_groups; const int num_warps_per_group = kMaxNumWarps / num_warp_groups;
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0); EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group); EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);
...@@ -518,33 +565,54 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -518,33 +565,54 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts; auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES); EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
#define DISPATCH_LAUNCH_CASE(hidden) { \ // 限制groupsize的大小
auto dispatch_func = dispatch<false, false, false, hidden>; \ EP_HOST_ASSERT(quant_group_size == 0 || quant_group_size == 128);
if (use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, false, hidden>; \ /*量化类型枚举
if (use_fp8 and use_ue8m0) \ 0 -> None 不量化,保持原始精度
dispatch_func = dispatch<true, true, false, hidden>; \ 1 -> Int8 使用 INT8 对称量化
if (use_int8) \ 2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3_FNUZ)
dispatch_func = dispatch<true, false, true, hidden>; \ 3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \ 4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2_FNUZ)
*/
#define DISPATCH_LAUNCH_CASE(hidden) \
{ \
auto dispatch_func = dispatch<hidden, 0, 0, kMaxNumWarps>; \
if (quant_group_size == 0) { \
switch (quant_type) { \
case 1: dispatch_func = dispatch<hidden, 1, 0, kMaxNumWarps>; break; \
case 2: dispatch_func = dispatch<hidden, 2, 0, kMaxNumWarps>; break; \
case 3: dispatch_func = dispatch<hidden, 3, 0, kMaxNumWarps>; break; \
case 4: dispatch_func = dispatch<hidden, 4, 0, kMaxNumWarps>; break; \
} \
} else { \
switch (quant_type) { \
case 1: dispatch_func = dispatch<hidden, 1, 128, kMaxNumWarps>; break; \
case 2: dispatch_func = dispatch<hidden, 2, 128, kMaxNumWarps>; break; \
case 3: dispatch_func = dispatch<hidden, 3, 128, kMaxNumWarps>; break; \
case 4: dispatch_func = dispatch<hidden, 4, 128, kMaxNumWarps>; break; \
} \
} \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \ packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \ packed_recv_src_info, packed_recv_layout_range, packed_recv_count, \
packed_recv_count, \
global_atomic_counter, \ global_atomic_counter, \
rdma_recv_x, rdma_recv_count, rdma_x, \ rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \ atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \ next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \ num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \ num_topk, num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, round_scale, phases); } break num_warp_groups, num_warps_per_group, fp8_round_scale, phases); \
} \
break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream); SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE); SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE #undef DISPATCH_LAUNCH_CASE
} }
template <int kHidden, int kNumMaxTopk> template <int kHidden, int kNumMaxTopk, int kMaxNumWarps=16>
__global__ __launch_bounds__(16 * kWarpSize, 1) void __global__ __launch_bounds__(16 * kWarpSize, 1) void
combine(void* combined_x, combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
...@@ -574,12 +642,11 @@ combine(void* combined_x, ...@@ -574,12 +642,11 @@ combine(void* combined_x,
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4; const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// Message package // Message package
EP_STATIC_ASSERT(kHidden % FP8_QUANTIZATION_NUM_PER_CHANNEL == 0, "Invalid hidden"); EP_STATIC_ASSERT(kHidden % QUANTIZATION_GROUPSIZE == 0, "Invalid hidden");
constexpr size_t num_bytes_per_slot = kHidden * sizeof(hip_bfloat16); constexpr size_t num_bytes_per_slot = kHidden * sizeof(hip_bfloat16);
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization"); EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// 16 is the max possible number of warps in AMD GPUs // 初始化用于细粒度warp间同步的计数器数组
constexpr int kMaxNumWarps = 1024 / kWarpSize;
__shared__ volatile int sync_large_warp_counters[kMaxNumWarps]; __shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
if (threadIdx.x==0){ if (threadIdx.x==0){
#pragma unroll #pragma unroll
...@@ -755,9 +822,10 @@ void combine(void* combined_x, ...@@ -755,9 +822,10 @@ void combine(void* combined_x,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, int num_device_sms, hipStream_t stream, void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy) { int phases, bool zero_copy) {
constexpr int kMaxNumWarps = 16;
constexpr int kNumMaxTopk = 11; constexpr int kNumMaxTopk = 11;
const int num_warp_groups = ceil_div(num_experts, num_device_sms); const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const int num_warps_per_group = 16 / num_warp_groups; // num_warps_per_group>1, "Requires more than one warp per group" const int num_warps_per_group = kMaxNumWarps / num_warp_groups; // num_warps_per_group>1, "Requires more than one warp per group"
const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms); const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0); EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);
...@@ -770,20 +838,20 @@ void combine(void* combined_x, ...@@ -770,20 +838,20 @@ void combine(void* combined_x,
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES); EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
EP_HOST_ASSERT(num_topk <= kNumMaxTopk); EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
#define COMBINE_LAUNCH_CASE(hidden) { \ #define COMBINE_LAUNCH_CASE(hidden) \
auto combine_func = combine<hidden, kNumMaxTopk>; \ { \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \ auto combine_func = combine<hidden, kNumMaxTopk, kMaxNumWarps>; \
combined_x, \ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \ combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \ x, topk_idx, topk_weights, src_info, layout_range, \
global_atomic_counter, \ global_atomic_counter, combine_wait_recv_cost_stats, \
combine_wait_recv_cost_stats, \
next_clean, num_next_clean_int, \ next_clean, num_next_clean_int, \
atomic_clean_flag, \ atomic_clean_flag, num_combined_tokens, hidden, \
num_combined_tokens, hidden, num_topk, \ num_topk, num_max_dispatch_tokens_per_rank, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \ num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, phases, zero_copy); } break num_warp_groups, num_warps_per_group, phases, zero_copy); \
} \
break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream); SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
SWITCH_HIDDEN(COMBINE_LAUNCH_CASE); SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
......
...@@ -341,10 +341,14 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) { ...@@ -341,10 +341,14 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return *reinterpret_cast<dtype_t *>(recv_int_values); return *reinterpret_cast<dtype_t *>(recv_int_values);
} }
constexpr float kFP8Margin = 1e-4; // 设置不同的量化方式的最大值与相反数
constexpr float kFP8Margin = 0.0;
constexpr float kFinfoAmaxE4M3 = 240.0f; constexpr float kFinfoAmaxE4M3 = 240.0f;
constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3; constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
constexpr float kInt8Amax = 127.0f; constexpr float kFinfoAmaxE5M2 = 57344.0f;
constexpr float kFinfoAmaxInvE5M2 = 1.0f / kFinfoAmaxE5M2;
constexpr float kFinfoAmaxInt8 = 127.0f;
constexpr float kFinfoAmaxInvInt8 = 1.0f / 127.0f;
__forceinline__ __device__ float fast_pow2(int x) { __forceinline__ __device__ float fast_pow2(int x) {
// We can ensure `-126 <= x and x <= 127` // We can ensure `-126 <= x and x <= 127`
...@@ -359,7 +363,13 @@ __forceinline__ __device__ int fast_log2_ceil(float x) { ...@@ -359,7 +363,13 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
return exp_x - 127 + (man_bits != 0); return exp_x - 127 + (man_bits != 0);
} }
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) { template <int kQuantType>
__forceinline__ __device__ void calculate_quant8bit_scales(float amax, float& scale, float& scale_inv, bool round_scale=0) {
amax = fmaxf(amax, 1e-6f);
if constexpr(kQuantType == 1) { // 使用 INT8 对称量化
scale_inv = kFinfoAmaxInvInt8 * amax;
scale = kFinfoAmaxInt8 / amax;
} else if constexpr(kQuantType == 2 || kQuantType == 3) { // 使用 FP8_E4M3 或 FP8_UE8M0 非对称量化
if (round_scale) { if (round_scale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3); auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
scale = fast_pow2(-exp_scale_inv); scale = fast_pow2(-exp_scale_inv);
...@@ -368,11 +378,16 @@ __forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, f ...@@ -368,11 +378,16 @@ __forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, f
scale_inv = amax * kFinfoAmaxInvE4M3; scale_inv = amax * kFinfoAmaxInvE4M3;
scale = kFinfoAmaxE4M3 / amax; scale = kFinfoAmaxE4M3 / amax;
} }
} } else if constexpr(kQuantType == 4) { // 使用 FP8_E5M2 对称量化
if (round_scale) {
__forceinline__ __device__ void calculate_int8_scales(float amax, float& scale, float& scale_inv) { auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE5M2);
scale = kInt8Amax / amax; scale = fast_pow2(-exp_scale_inv);
scale_inv = amax / kInt8Amax; scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * kFinfoAmaxInvE5M2;
scale = kFinfoAmaxE5M2 / amax;
}
}
} }
template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>> template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>>
......
...@@ -841,7 +841,7 @@ class Buffer: ...@@ -841,7 +841,7 @@ class Buffer:
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int, num_experts: int, num_max_dispatch_tokens_per_rank: int, num_experts: int,
use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False, use_int8: bool = False, quant_type: int = 1, quant_group_size: int = 0, fp8_round_scale: bool = False,
async_finish: bool = False, return_recv_hook: bool = False) -> \ async_finish: bool = False, return_recv_hook: bool = False) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
""" """
...@@ -858,10 +858,20 @@ class Buffer: ...@@ -858,10 +858,20 @@ class Buffer:
only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported. only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value. num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts. num_experts: the number of all experts.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors. 量化配置
round_scale: whether round the scaling factors into power of 2. quant_type: int 量化类型枚举
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`). 0 -> None 不量化,保持原始精度
use_int8: whether to enable INT8 casting. 1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3_FNUZ)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2_FNUZ)
quant_group_size: int 量化分组大小
0 -> 逐token量化 (per-channel)
128-> 每 128 元素一组 (per-group) 量化
fp8_round_scale: bool 是否将 FP8 缩放因子取整为 2 的幂
true -> 缩放因子 = 2^k,硬件零开销
false -> 缩放因子 = 任意浮点,精度更高
异步配置
async_finish: the current stream will not wait for the communication kernels to be finished if set. async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
...@@ -869,15 +879,25 @@ class Buffer: ...@@ -869,15 +879,25 @@ class Buffer:
Returns: Returns:
recv_x: a tensor or tuple with received tokens for each expert. recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as - packed_recv_x:
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`. 存储接收到的 Token 数据,形状为
The second tensor is the corresponding scales for the first element with shape `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`。
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`, 数据类型取决于 quant_type:
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as quant_type == 1 -> torch.int8
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`. quant_type == 2 -> torch.float8_e4m3fnuz
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility. quant_type == 3 -> torch.float8_e4m3fnuz (UE8M0 使用 E4M3 格式存储)
With `use_fp8=False`, the result would be a tensor shaped as quant_type == 4 -> torch.float8_e5m2fnuz
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`. 其他 (非量化) -> torch.bfloat16
- packed_recv_x_scales (可选):
仅在 quant_type > 0 时存在,存储量化的 Scale 值。
形状为 `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, scales_col_size]`。
- 当 quant_type == 3 (UE8M0) 时:
scales_col_size = hidden // 512
数据类型为 torch.int (每个 32 位整数内打包存储 4 个 8-bit UE8M0 scale)。
*注意:此模式强制要求 fp8_round_scale=True 且 group_size=128。
- 当 quant_type == 1, 2, 4 时:
scales_col_size = hidden // 128 (若使用 group_size) 或 1 (per-channel)。
数据类型为 torch.float32。
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are, Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced). as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
...@@ -889,14 +909,15 @@ class Buffer: ...@@ -889,14 +909,15 @@ class Buffer:
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \ packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
self.runtime.low_latency_dispatch(x, topk_idx, self.runtime.low_latency_dispatch(x, topk_idx,
num_max_dispatch_tokens_per_rank, num_experts, num_max_dispatch_tokens_per_rank, num_experts,
use_fp8, round_scale, use_ue8m0, use_int8, quant_type, quant_group_size, fp8_round_scale,
async_finish, return_recv_hook) async_finish, return_recv_hook)
handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts) handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
tensors_to_record = (x, topk_idx, tensors_to_record = (x, topk_idx,
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_x, packed_recv_x_scales, packed_recv_count,
packed_recv_src_info, packed_recv_layout_range) packed_recv_src_info, packed_recv_layout_range)
return (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, packed_recv_count, handle, \
EventOverlap(event, tensors_to_record if async_finish else None), hook recv_x = (packed_recv_x, packed_recv_x_scales) if (quant_type > 0) else packed_recv_x
return recv_x, packed_recv_count, handle, EventOverlap(event, tensors_to_record if async_finish else None), hook
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
......
...@@ -58,7 +58,8 @@ def test_main(num_tokens: int, ...@@ -58,7 +58,8 @@ def test_main(num_tokens: int,
x_list = [x] x_list = [x]
# # NOTES: the last one is for performance testing # # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels # # Most of the values in the perf case is lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1) # x_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
# x_list = [x_rand]
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1] topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
...@@ -80,14 +81,19 @@ def test_main(num_tokens: int, ...@@ -80,14 +81,19 @@ def test_main(num_tokens: int,
hash_value, num_times = 0, 0 hash_value, num_times = 0, 0
for current_x in x_list: for current_x in x_list:
for return_recv_hook in (False, True): for return_recv_hook in (False, True):
for dispatch_use_fp8 in (False, True): for quant_type in (0, 2, 3, ): # 0: 不量化, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True)
for round_scale in (False, True) if dispatch_use_fp8 else (False,): dispatch_use_fp8 = quant_type > 0
for use_ue8m0 in (False, True) if round_scale else (False,): for fp8_round_scale in (False, True) if dispatch_use_fp8 else (False, ):
for quant_group_size in (128, ):
# 跳过不支持的情况
if quant_type == 3 and fp8_round_scale == False:
continue
num_times += 1 num_times += 1
for _ in range((num_times % 2) + 1): for _ in range((num_times % 2) + 1):
packed_recv_x, packed_recv_count, handle, event, hook = \ packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0, quant_type=quant_type, fp8_round_scale=fp8_round_scale, quant_group_size=quant_group_size,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait() hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
...@@ -115,13 +121,13 @@ def test_main(num_tokens: int, ...@@ -115,13 +121,13 @@ def test_main(num_tokens: int,
recv_x_amin = recv_x[:, :-128].amin(dim=-1) recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens] recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1)) assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
if round_scale: if fp8_round_scale:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007 assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
else: else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0 assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
for j in range(num_ranks): for j in range(num_ranks):
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item() begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
if not round_scale: if not fp8_round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item() assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0 assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_fp8: if dispatch_use_fp8:
...@@ -147,7 +153,7 @@ def test_main(num_tokens: int, ...@@ -147,7 +153,7 @@ def test_main(num_tokens: int,
if do_check: if do_check:
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0 assert torch.isnan(combined_x).sum().item() == 0
# if not round_scale: # if not fp8_round_scale:
assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: diff={diff}, dispatch_use_fp8={dispatch_use_fp8}, zero_copy={zero_copy}' assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: diff={diff}, dispatch_use_fp8={dispatch_use_fp8}, zero_copy={zero_copy}'
hash_value ^= hash_tensor(combined_x) hash_value ^= hash_tensor(combined_x)
...@@ -162,7 +168,8 @@ def test_main(num_tokens: int, ...@@ -162,7 +168,8 @@ def test_main(num_tokens: int,
def test_func(return_recv_hook: bool): def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \ recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook) quant_type=2, quant_group_size=128,
async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None large_gemm_with_hook(hook) if return_recv_hook else None
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx, topk_idx,
......
...@@ -54,16 +54,16 @@ def test_main(num_tokens: int, ...@@ -54,16 +54,16 @@ def test_main(num_tokens: int,
do_check = True do_check = True
hash_value, num_times = 0, 0 hash_value, num_times = 0, 0
for current_x in x_list: for current_x in x_list:
for return_recv_hook in (False, ): for return_recv_hook in (False, True):
for dispatch_use_fp8 in (True, ): for quant_type in (1, ):
for round_scale in (False, ): for fp8_round_scale in (False, ):
for use_ue8m0 in (False, ): for quant_group_size in (0, ):
dispatch_use_fp8 = quant_type > 0
num_times += 1 num_times += 1
use_int8 = True
for _ in range(1): for _ in range(1):
packed_recv_x, packed_recv_count, handle, event, hook = \ packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0, use_int8=use_int8, quant_type=quant_type, quant_group_size=quant_group_size,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait() hook() if return_recv_hook else event.current_stream_wait()
...@@ -97,9 +97,7 @@ def test_main(num_tokens: int, ...@@ -97,9 +97,7 @@ def test_main(num_tokens: int,
assert torch.equal(recv_x_amin, recv_x_amax) assert torch.equal(recv_x_amin, recv_x_amax)
if round_scale: if quant_type == 1:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
elif use_int8:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.01 assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.01
else: else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0 assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
...@@ -131,7 +129,7 @@ def test_main(num_tokens: int, ...@@ -131,7 +129,7 @@ def test_main(num_tokens: int,
def test_func(return_recv_hook: bool): def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \ recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=True, round_scale=False, use_ue8m0=False, use_int8=True, quant_type=1, quant_group_size=0,
async_finish=False, return_recv_hook=return_recv_hook) async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None large_gemm_with_hook(hook) if return_recv_hook else None
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment